#!/usr/bin/env python
import os
import sys
import argparse
import subprocess
import math
import pandas as pd
from scipy import stats
import PanChIP.commands as commands
from PanChIP.version import __version__

# Version tag of the reference library: it names the split archives fetched by
# the `init` command and the library sub-directory that the `analysis` and
# `filter` commands hand to their shell scripts.
lib_name = 'v.3.0'

def remove_prefix(input_string, prefix):
    """Return *input_string* without a leading *prefix*.

    The string is returned unchanged when *prefix* is empty (or otherwise
    falsy) or when *input_string* does not start with it. At most one
    occurrence is removed.
    """
    if not prefix or not input_string.startswith(prefix):
        return input_string
    start = len(prefix)
    return input_string[start:]

def remove_suffix(input_string, suffix):
    """Return *input_string* without a trailing *suffix*.

    The string is returned unchanged when *suffix* is empty (or otherwise
    falsy) or when *input_string* does not end with it. At most one
    occurrence is removed.
    """
    if not suffix or not input_string.endswith(suffix):
        return input_string
    # endswith() succeeded, so the suffix is no longer than the string and
    # the cut position is never negative.
    cut = len(input_string) - len(suffix)
    return input_string[:cut]

class Panchip(object):
    """Command-line front end that dispatches to the PanChIP sub-commands.

    Instantiating the class parses ``sys.argv``, takes the first argument
    that does not start with ``-`` as the sub-command and calls the public
    method of the same name (``init``, ``analysis`` or ``filter``).
    """

    def __init__(self):
        """Parse the global command line and run the requested sub-command.

        Prints the parser help and exits with status 1 when no command was
        given or when it does not name a public, callable attribute.
        """
        parser = commands.panchip_parser()

        # Skip leading global options; the first argument that does not
        # start with '-' is treated as the sub-command.
        option_ix = 1
        while (option_ix < len(sys.argv) and
               sys.argv[option_ix].startswith('-')):
            option_ix += 1

        args = parser.parse_args(sys.argv[1:option_ix+1])

        # Private helpers (leading underscore) and non-callable attributes
        # must not be reachable from the command line.
        handler = None
        if args.command is not None and not args.command.startswith('_'):
            handler = getattr(self, args.command, None)
        if not callable(handler):
            print('Unrecognized command')
            parser.print_help()
            # sys.exit() rather than the site-provided exit() builtin, which
            # is missing when the interpreter runs without the site module.
            sys.exit(1)

        handler([sys.argv[0]] + sys.argv[option_ix:])

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _download(url, path):
        """Stream the resource at *url* into the local file *path*."""
        import shutil
        import urllib.request

        with urllib.request.urlopen(url) as response, open(path, "wb") as out:
            shutil.copyfileobj(response, out)

    @staticmethod
    def _remove_quietly(path):
        """Delete *path*; a missing file is fine, other failures only warn."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            print(f"Warning: could not remove {path}: {err}", file=sys.stderr)

    @staticmethod
    def _bed_stem(name):
        """Return *name* without one trailing ``.bed`` extension."""
        suffix = '.bed'
        return name[:-len(suffix)] if name.endswith(suffix) else name

    @classmethod
    def _input_names(cls, input_dir):
        """Return the space-separated, sorted entry names of *input_dir*.

        Hidden entries (leading dot) are skipped and a trailing ``.bed`` is
        stripped from each name.
        """
        # NOTE(review): every non-hidden entry is listed, not only *.bed
        # files - confirm whether the analysis script expects that.
        visible = sorted(
            entry for entry in os.listdir(input_dir)
            if not entry.startswith('.')
        )
        return ' '.join(cls._bed_stem(entry) for entry in visible)

    @staticmethod
    def _script_header(variables):
        """Build the bash preamble that assigns *variables*.

        *variables* is an iterable of ``(name, value)`` string pairs. Every
        value is shell-quoted so that spaces, ``$`` or quotes in a path
        cannot alter the generated script.
        """
        import shlex

        assignments = [
            f"{name}={shlex.quote(value)}\n" for name, value in variables
        ]
        return '#!/bin/bash\n\n' + ''.join(assignments)

    @staticmethod
    def _write_executable(output_dir, header, template_path):
        """Write ``executable.sh``: *header* followed by the template text."""
        # Open the template first so a missing template does not leave a
        # truncated executable.sh behind.
        with open(template_path) as template, \
                open(os.path.join(output_dir, 'executable.sh'), 'w') as out:
            out.write(header)
            out.writelines(template)

    @staticmethod
    def _run_executable(output_dir, wrapper_name):
        """Write the wrapper script *wrapper_name* and run it with ``sh``.

        The wrapper changes into *output_dir*, marks ``executable.sh`` as
        executable and runs it. Returns the exit status; a non-zero status
        is reported on stderr but does not abort the program.
        """
        import shlex

        wrapper_path = os.path.join(output_dir, wrapper_name)
        with open(wrapper_path, 'w') as wrapper:
            wrapper.write('#!/bin/bash\n\ncd ' + shlex.quote(output_dir)
                          + '\nchmod u+x ./executable.sh\n./executable.sh')
        status = subprocess.call(['sh', wrapper_path])
        if status != 0:
            print(f"Warning: {wrapper_path} exited with status {status}",
                  file=sys.stderr)
        return status

    # ------------------------------------------------------------------
    # Sub-commands
    # ------------------------------------------------------------------
    def init(self, argv):
        """Download and unpack the reference library and the PanChIP scripts.

        *argv* is the program name followed by the sub-command and its
        arguments; everything from index 2 on is handed to the ``init``
        parser. The three split library archives are fetched, concatenated
        and extracted into the library directory, then the PanChIP source
        archive of the installed version is fetched from GitHub and
        extracted next to it. Downloaded archives are removed afterwards,
        also when a step fails.
        """
        import shutil
        import tarfile
        import zipfile

        parser = commands.init_parser()
        args = parser.parse_args(argv[2:])

        lib_dir = os.path.abspath(args.library_directory)
        os.makedirs(lib_dir, exist_ok=True)

        # Base URL where the split files are hosted (no trailing slash).
        base_url = "https://hanjun.group/wp-content/uploads/2025/12"

        # Local names of the pieces: index 0 is the final ".zip" part,
        # 1 and 2 are the ".z01" / ".z02" parts.
        split_files = [
            os.path.join(lib_dir, f"{lib_name}_split{idx}.zip")
            for idx in range(3)
        ]
        combined_zip_path = os.path.join(lib_dir, lib_name + ".zip")

        try:
            try:
                for idx, path in enumerate(split_files):
                    url = f"{base_url}/{lib_name}_split{idx}.zip"
                    print(f"Downloading {url} -> {path}")
                    self._download(url, path)

                # Concatenation order must be .z01, .z02, .zip -> [1, 2, 0].
                print(f"Combining parts -> {combined_zip_path}")
                with open(combined_zip_path, "wb") as combined:
                    for idx in [1, 2, 0]:
                        with open(split_files[idx], "rb") as part:
                            shutil.copyfileobj(part, combined)
            finally:
                for path in split_files:
                    self._remove_quietly(path)

            print("Extracting ZIP...")
            with zipfile.ZipFile(combined_zip_path, "r") as archive:
                archive.extractall(lib_dir)
        finally:
            self._remove_quietly(combined_zip_path)

        tar_url = f"https://github.com/hanjunlee21/PanChIP/archive/refs/tags/v.{__version__}.tar.gz"
        tar_path = os.path.join(lib_dir, f"v.{__version__}.tar.gz")

        try:
            print(f"Downloading {tar_url} -> {tar_path}")
            self._download(tar_url, tar_path)

            print("Extracting tar.gz...")
            with tarfile.open(tar_path, "r:gz") as tar:
                # Use the 'data' extraction filter where the running Python
                # provides it; it rejects members that would escape lib_dir.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(lib_dir, filter="data")
                else:
                    tar.extractall(lib_dir)
        finally:
            self._remove_quietly(tar_path)

    def analysis(self, argv):
        """Generate and run the analysis script for a directory of inputs.

        *argv* is the program name followed by the sub-command and its
        arguments; everything from index 2 on is handed to the ``analysis``
        parser. ``executable.sh`` (variable preamble plus the library's
        ``analysis.sh``) and the wrapper ``analysis.sh`` are written to the
        output directory and the wrapper is run. Exits with an error message
        when the input directory does not exist.
        """
        parser = commands.analysis_parser()
        args = parser.parse_args(argv[2:])

        library_root = os.path.abspath(args.library_directory)
        input_dir = os.path.abspath(args.input_directory)
        output_dir = os.path.abspath(args.output_directory)

        if not os.path.isdir(input_dir):
            sys.exit(f"Input directory not found: {input_dir}")

        os.makedirs(output_dir, exist_ok=True)

        # The trailing newline keeps a blank line between the assignments
        # and the template body.
        header = self._script_header([
            ('inputfiles', self._input_names(input_dir)),
            ('input', input_dir),
            ('output', output_dir),
            ('lib', library_root + '/' + lib_name + '/Analysis'),
            ('threads', str(args.threads)),
            ('repeat', str(args.repeats)),
        ]) + '\n'

        template_path = (library_root + '/PanChIP-v.' + __version__
                         + '/PanChIP/analysis.sh')
        self._write_executable(output_dir, header, template_path)
        self._run_executable(output_dir, 'analysis.sh')

    def filter(self, argv):
        """Generate and run the filter script, then add adjusted P-values.

        *argv* is the program name followed by the sub-command and its
        arguments; everything from index 2 on is handed to the ``filter``
        parser. After the generated script has run, ``primary.output.tsv``
        and ``statistics.tsv`` are read from the output directory. For every
        TF a Welch t-test compares its ``Input`` values with all ``Input``
        values; the P-value is multiplied by the square root of the number
        of TFs and capped at 1. ``statistics.tsv`` is then rewritten with
        the adjusted P-value and a PASS/FAIL filter column.
        """
        import shlex

        parser = commands.filter_parser()
        args = parser.parse_args(argv[2:])

        library_root = os.path.abspath(args.library_directory)
        input_path = os.path.abspath(args.input_file)
        output_dir = os.path.abspath(args.output_directory)

        os.makedirs(output_dir, exist_ok=True)

        header = self._script_header([
            ('inputfiles', self._bed_stem(os.path.basename(input_path))),
            ('input', os.path.dirname(input_path)),
            ('output', output_dir),
            ('lib', library_root + '/' + lib_name + '/Filter'),
            ('threads', str(args.threads)),
        ])

        template_path = (library_root + '/PanChIP-v.' + __version__
                         + '/PanChIP/filter.sh')
        self._write_executable(output_dir, header, template_path)
        self._run_executable(output_dir, 'filter.sh')

        df = pd.read_csv(os.path.join(output_dir, 'primary.output.tsv'),
                         sep='\t', header=0,
                         names=["TF", "Experiment", "Input"])
        st = pd.read_csv(os.path.join(output_dir, 'statistics.tsv'),
                         sep='\t', header=0,
                         names=["TF", "Mean", "Standard Deviation",
                                "Signal-to-noise Ratio", "Filter"])

        # The correction factor is the same for every TF.
        correction = math.sqrt(len(st['TF']))

        with open(os.path.join(output_dir, 'adjusted_P.txt'), 'w') as out:
            out.write('Adjusted P\n')
            for tf in st['TF']:
                tfdf = df[(df['TF'] == tf)]
                welch = stats.ttest_ind(tfdf['Input'], df['Input'],
                                        equal_var=False)
                adjp = min(welch.pvalue * correction, 1)
                # Undefined P-values are written as NA, which the awk
                # filter below treats as a failing value.
                out.write(('NA' if math.isnan(adjp) else str(adjp)) + '\n')

        # Append the adjusted P-value column to statistics.tsv, drop rows
        # whose second column is negative and recompute the filter column.
        merge_script = (
            '#!/bin/bash\n\ncd ' + shlex.quote(output_dir) + '\n'
            'paste statistics.tsv adjusted_P.txt | awk -F \'\t\' \''
            '{if(NR==1) {printf "%s\\t%s\\t%s\\t%s\\t%s\\t%s\\n",$1,$2,$3,$4,$6,$5} '
            'else if(NR>1&&$2>=0) {if($4>2&&$4!="NA"&&$6<0.05) {filter="PASS"} else {filter="FAIL"}; '
            'printf "%s\\t%s\\t%s\\t%s\\t%s\\t%s\\n",$1,$2,$3,$4,$6,filter}}\'> statistics.tmp\n'
            'rm statistics.tsv adjusted_P.txt\n'
            'mv statistics.tmp statistics.tsv\n'
        )
        merge_path = os.path.join(output_dir, 'adjusted_P.sh')
        with open(merge_path, 'w') as script:
            script.write(merge_script)
        try:
            status = subprocess.call(['sh', merge_path])
            if status != 0:
                print(f"Warning: {merge_path} exited with status {status}",
                      file=sys.stderr)
        finally:
            self._remove_quietly(merge_path)

# Script entry point: constructing Panchip parses sys.argv and runs the
# requested sub-command.
if __name__ == '__main__':
    Panchip()
