#!/usr/bin/env python
import os
import sys
import argparse
import subprocess
import math
import pandas as pd
from scipy import stats
import PanChIP.commands as commands
from PanChIP.version import __version__

# Name of the reference-library release. `init` uses it to build the names
# of the split archive parts it downloads, and `analysis`/`filter` use it to
# locate <library_directory>/<lib_name>/Analysis and .../Filter.
lib_name = "v.3.0"

def remove_prefix(input_string, prefix):
    """Return ``input_string`` with a leading ``prefix`` removed.

    The string is returned unchanged when ``prefix`` is empty or when
    ``input_string`` does not start with it.
    """
    has_prefix = bool(prefix) and input_string.startswith(prefix)
    return input_string[len(prefix):] if has_prefix else input_string

def remove_suffix(input_string, suffix):
    """Return ``input_string`` with a trailing ``suffix`` removed.

    The string is returned unchanged when ``suffix`` is empty or when
    ``input_string`` does not end with it.
    """
    # Guard first: an empty suffix must not reach the slice below, because
    # ``[:-0]`` would yield an empty string instead of the input.
    if not suffix or not input_string.endswith(suffix):
        return input_string
    return input_string[:-len(suffix)]

class Panchip(object):
    """Command-line entry point that dispatches PanChIP sub-commands.

    Instantiating the class parses ``sys.argv``, resolves the sub-command
    name and calls the public method of the same name (``init``,
    ``analysis`` or ``filter``) with the remaining arguments.
    """

    def __init__(self):
        """Parse the global options in ``sys.argv`` and run the sub-command.

        Prints the help text and exits with status 1 when no sub-command is
        given or when its name does not match a public method.
        """
        parser = commands.panchip_parser()

        # Global options come first; the first argument that does not start
        # with '-' is taken to be the sub-command name.
        option_ix = 1
        while (option_ix < len(sys.argv) and
               sys.argv[option_ix].startswith('-')):
            option_ix += 1

        args = parser.parse_args(sys.argv[1:option_ix+1])

        # Names starting with '_' are internal helpers, never sub-commands.
        if (args.command is None
                or args.command.startswith('_')
                or not hasattr(self, args.command)):
            print('Unrecognized command')
            parser.print_help()
            # sys.exit rather than the site-provided exit(), which is not
            # guaranteed to exist (e.g. under ``python -S``).
            sys.exit(1)

        # Sub-command methods receive [program, command, *options] and parse
        # everything from index 2 onwards.
        getattr(self, args.command)([sys.argv[0]] + sys.argv[option_ix:])

    @staticmethod
    def _format_p(value):
        """Return ``value`` as text, writing ``NA`` for NaN."""
        return 'NA' if math.isnan(value) else str(value)

    @staticmethod
    def _list_sample_names(input_dir):
        """Return the sorted, non-hidden entries of ``input_dir`` without a trailing ``.bed``."""
        entries = sorted(
            entry for entry in os.listdir(input_dir)
            if not entry.startswith('.')
        )
        return [remove_suffix(entry, '.bed') for entry in entries]

    @staticmethod
    def _write_executable(output_dir, variables, template_path):
        """Write ``<output_dir>/executable.sh`` from shell variables plus a template.

        Args:
            output_dir: directory that receives ``executable.sh``.
            variables: iterable of ``(name, value)`` string pairs emitted as
                shell assignments ahead of the template body.
            template_path: shell script whose content is appended verbatim.
        """
        import shlex

        header = ['#!/bin/bash', '']
        # shlex.quote keeps values intact even when they contain spaces,
        # quotes, '$' or backticks.
        header.extend(
            name + '=' + shlex.quote(value) for name, value in variables
        )
        # The template is opened first so that a missing template does not
        # leave a half-written executable.sh behind.
        with open(template_path) as template, \
                open(os.path.join(output_dir, 'executable.sh'), 'w') as script:
            script.write('\n'.join(header) + '\n\n')
            script.writelines(template)

    @staticmethod
    def _run_script(script_path):
        """Run ``script_path`` with ``sh``; warn on stderr if it fails and return its status."""
        status = subprocess.call(['sh', script_path])
        if status != 0:
            print(
                'Warning: ' + script_path + ' exited with status ' + str(status),
                file=sys.stderr,
            )
        return status

    @staticmethod
    def _run_executable(output_dir, wrapper_name):
        """Write the wrapper script that runs ``executable.sh`` in ``output_dir`` and run it."""
        import shlex

        wrapper_path = os.path.join(output_dir, wrapper_name)
        with open(wrapper_path, 'w') as wrapper:
            wrapper.write(
                '#!/bin/bash\n\ncd ' + shlex.quote(output_dir)
                + '\nchmod u+x ./executable.sh\n./executable.sh'
            )
        return Panchip._run_script(wrapper_path)

    def init(self, argv):
        """Download and unpack the reference library and the PanChIP scripts.

        Fetches the three split ZIP parts of the ``lib_name`` library,
        recombines them with the ``zip`` command, extracts the result with
        ``unzip``, then downloads and extracts the PanChIP source tarball
        for the installed version. Everything is placed in the library
        directory given on the command line.

        Args:
            argv: ``[program, 'init', *options]``; ``argv[2:]`` is parsed
                with ``commands.init_parser()``.

        Raises:
            FileNotFoundError: if ``zip`` or ``unzip`` is not on ``PATH``.
            subprocess.CalledProcessError: if ``zip`` or ``unzip`` fails.
        """
        import urllib.request
        import shutil
        import tarfile

        parser = commands.init_parser()
        args = parser.parse_args(argv[2:])

        # Fail before the large downloads if a required tool is missing.
        for tool in ("zip", "unzip"):
            if shutil.which(tool) is None:
                raise FileNotFoundError(
                    f"Required command '{tool}' was not found on PATH"
                )

        lib_dir = os.path.abspath(args.library_directory)
        os.makedirs(lib_dir, exist_ok=True)

        # -------------------------------------------------------------
        # 1. Download split ZIP parts using Python
        # -------------------------------------------------------------
        # Base URL where split files are hosted
        base_url = "https://hanjun.group/wp-content/uploads/2025/12"

        # Map remote filenames (WordPress-safe) to local "real" split names
        # Assumes the split archive was created with something like:
        #   zip -s 1500m -r v.3.0_split.zip v.3.0/
        parts_mapping = [
            # remote_name                 local_name (real split archive name)
            (f"panchip_{lib_name}_split0.zip", f"{lib_name}_split.zip"),   # last part
            (f"panchip_{lib_name}_split1.zip", f"{lib_name}_split.z01"),   # first part
            (f"panchip_{lib_name}_split2.zip", f"{lib_name}_split.z02"),   # second part
        ]

        local_paths = []
        for remote_name, local_name in parts_mapping:
            url = f"{base_url}/{remote_name}"
            local_path = os.path.join(lib_dir, local_name)
            local_paths.append(local_path)

            print(f"Downloading {url} -> {local_path}")
            with urllib.request.urlopen(url) as r, open(local_path, "wb") as f:
                shutil.copyfileobj(r, f)

        # -------------------------------------------------------------
        # 2. Recombine the split ZIP into a single normal ZIP
        #    using the `zip` command: zip -s 0 base_split --out full.zip
        # -------------------------------------------------------------
        split_base = os.path.join(lib_dir, f"{lib_name}_split.zip")
        combined_zip_path = os.path.join(lib_dir, f"{lib_name}_full.zip")

        print(f"Recombining split archive → {combined_zip_path}")
        subprocess.run(
            ["zip", "-s", "0", split_base, "--out", combined_zip_path],
            check=True
        )

        # Best-effort removal of the split parts once recombined
        for p in local_paths:
            try:
                os.remove(p)
            except OSError:
                pass

        # -------------------------------------------------------------
        # 3. Extract the recombined ZIP using system `unzip`
        #    This archive now directly contains the v.3.0/ directory.
        # -------------------------------------------------------------
        print("Extracting full library ZIP...")
        subprocess.run(
            ["unzip", "-o", combined_zip_path, "-d", lib_dir],
            check=True
        )

        try:
            os.remove(combined_zip_path)
        except OSError:
            pass

        # After this, the library lives in <lib_dir>/<lib_name>/..., which
        # is where analysis/filter look for it:
        #   lib = <library_directory>/<lib_name>/Analysis
        #   lib = <library_directory>/<lib_name>/Filter

        # -------------------------------------------------------------
        # 4. Download & extract PanChIP tar.gz using Python
        # -------------------------------------------------------------
        tar_url = f"https://github.com/hanjunlee21/PanChIP/archive/refs/tags/v.{__version__}.tar.gz"
        tar_path = os.path.join(lib_dir, f"v.{__version__}.tar.gz")

        print(f"Downloading {tar_url} → {tar_path}")
        with urllib.request.urlopen(tar_url) as r, open(tar_path, "wb") as f:
            shutil.copyfileobj(r, f)

        print("Extracting tar.gz...")
        with tarfile.open(tar_path, "r:gz") as tar:
            # Use the "data" extraction filter when this Python provides it,
            # so archive members cannot be written outside lib_dir.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(lib_dir, filter="data")
            else:
                tar.extractall(lib_dir)

        try:
            os.remove(tar_path)
        except OSError:
            pass

    def analysis(self, argv):
        """Run the PanChIP analysis script on every file of an input directory.

        Writes ``executable.sh`` (shell variables followed by the content
        of ``<library_directory>/PanChIP-v.<version>/PanChIP/analysis.sh``)
        and an ``analysis.sh`` wrapper into the output directory, then runs
        the wrapper with ``sh``.

        Args:
            argv: ``[program, 'analysis', *options]``; ``argv[2:]`` is
                parsed with ``commands.analysis_parser()``.
        """
        parser = commands.analysis_parser()
        args = parser.parse_args(argv[2:])

        library_root = os.path.abspath(args.library_directory)
        input_dir = os.path.abspath(args.input_directory)
        output_dir = os.path.abspath(args.output_directory)

        os.makedirs(output_dir, exist_ok=True)

        # Listing in Python avoids a shell pipeline whose `sed 's/.bed//g'`
        # removed every "<any char>bed" occurrence from the names instead of
        # only the trailing extension.
        input_list = ' '.join(self._list_sample_names(input_dir))

        self._write_executable(
            output_dir,
            [
                ('inputfiles', input_list),
                ('input', input_dir),
                ('output', output_dir),
                ('lib', os.path.join(library_root, lib_name, 'Analysis')),
                ('threads', str(args.threads)),
                ('repeat', str(args.repeats)),
            ],
            os.path.join(library_root, 'PanChIP-v.' + __version__,
                         'PanChIP', 'analysis.sh'),
        )
        self._run_executable(output_dir, 'analysis.sh')

    def filter(self, argv):
        """Run the PanChIP filter script on one BED file and add adjusted P values.

        Writes ``executable.sh`` (shell variables followed by the content
        of ``<library_directory>/PanChIP-v.<version>/PanChIP/filter.sh``)
        and a ``filter.sh`` wrapper into the output directory and runs the
        wrapper. Afterwards ``primary.output.tsv`` and ``statistics.tsv``
        are read from the output directory, a Welch t-test is computed per
        TF, and ``statistics.tsv`` is rewritten with an ``Adjusted P``
        column and a recomputed PASS/FAIL filter column.

        Args:
            argv: ``[program, 'filter', *options]``; ``argv[2:]`` is parsed
                with ``commands.filter_parser()``.
        """
        import shlex

        parser = commands.filter_parser()
        args = parser.parse_args(argv[2:])

        library_root = os.path.abspath(args.library_directory)
        input_file = os.path.abspath(args.input_file)
        output_dir = os.path.abspath(args.output_directory)

        os.makedirs(output_dir, exist_ok=True)

        self._write_executable(
            output_dir,
            [
                # basename also handles a file located directly under "/"
                ('inputfiles',
                 remove_suffix(os.path.basename(input_file), '.bed')),
                ('input', os.path.dirname(input_file)),
                ('output', output_dir),
                ('lib', os.path.join(library_root, lib_name, 'Filter')),
                ('threads', str(args.threads)),
            ],
            os.path.join(library_root, 'PanChIP-v.' + __version__,
                         'PanChIP', 'filter.sh'),
        )
        self._run_executable(output_dir, 'filter.sh')

        df = pd.read_csv(os.path.join(output_dir, 'primary.output.tsv'),
                         sep='\t', header=0,
                         names=["TF", "Experiment", "Input"])
        st = pd.read_csv(os.path.join(output_dir, 'statistics.tsv'),
                         sep='\t', header=0,
                         names=["TF", "Mean", "Standard Deviation",
                                "Signal-to-noise Ratio", "Filter"])

        # Each P value is scaled by sqrt(number of TFs) and capped at 1.
        correction = math.sqrt(len(st['TF']))
        with open(os.path.join(output_dir, 'adjusted_P.txt'), 'w') as out:
            out.write('Adjusted P\n')
            for tf in st['TF']:
                tfdf = df[(df['TF'] == tf)]
                welch = stats.ttest_ind(tfdf['Input'], df['Input'],
                                        equal_var=False)
                adjp = min(welch.pvalue * correction, 1)
                # NaN is written as NA here, so no `sed -i` call is needed
                # (that form is not portable to BSD/macOS sed).
                out.write(self._format_p(adjp) + '\n')

        # awk program: the header row moves the new column before Filter;
        # data rows with $2 >= 0 are kept and get PASS only when
        # $4 > 2, $4 is not NA and $6 < 0.05.
        awk_program = (
            r'{if(NR==1) {printf "%s\t%s\t%s\t%s\t%s\t%s\n",$1,$2,$3,$4,$6,$5} '
            r'else if(NR>1&&$2>=0) {if($4>2&&$4!="NA"&&$6<0.05) {filter="PASS"} '
            r'else {filter="FAIL"}; '
            r'printf "%s\t%s\t%s\t%s\t%s\t%s\n",$1,$2,$3,$4,$6,filter}}'
        )
        script_lines = [
            '#!/bin/bash',
            '',
            'cd ' + shlex.quote(output_dir),
            "paste statistics.tsv adjusted_P.txt | awk -F '\t' '"
            + awk_program + "'> statistics.tmp",
            'rm statistics.tsv adjusted_P.txt',
            'mv statistics.tmp statistics.tsv',
            '',
        ]
        script_path = os.path.join(output_dir, 'adjusted_P.sh')
        with open(script_path, 'w') as script:
            script.write('\n'.join(script_lines))

        self._run_script(script_path)
        try:
            os.remove(script_path)
        except OSError:
            pass

# Run the command-line interface only when executed as a script;
# constructing Panchip parses sys.argv and dispatches the sub-command.
if __name__ == "__main__":
    Panchip()
