#!python

# Created on Thu Aug 8 11:49:10 2021
# Author: XiaoTao Wang

## Required modules

import argparse, sys, eaglec, os

# Version string of the installed eaglec package; shown by the "--version" option.
currentVersion = eaglec.__version__


def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the script is invoked without any argument, ``-h`` is injected so
    that the help message is printed (argparse then exits the process).

    The Hi-C matrices and the input SV files are declared as required:
    ``run()`` dereferences them unconditionally, so a missing value is
    reported here as a regular usage error rather than as a failure deep
    inside the pipeline. ``-h``/``-v`` still work without them because these
    actions exit while parsing, before required options are checked.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or ``['-h']``
        when no argument was given).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Merge SV calls from different resolutions.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    parser.add_argument('--hic-10k', required=True,
                        help='''Path to a 10kb matrix in .cool format''')
    parser.add_argument('--hic-5k', required=True,
                        help='''Path to a 5kb matrix in .cool format''')
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'CNV', 'Raw'],
                        help = '''Normalization method. If you choose ICE, make sure you have run
                        "cooler balance" on your Hi-C matrix before you run this command; If you
                        choose CNV, make sure you have run "correct-cnv" of the NeoLoopFinder toolkit
                        before you run this command.''')
    parser.add_argument('-C', '--chroms', nargs = '*', default = ['#', 'X'],
                        help = 'List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')
    parser.add_argument('--full-sv-files', nargs='+', required=True, help='''Path to the input SV files in full format
                       (see "--output-format" below). If redundant SVs (the coordinates are close each
                       other, see "--buff-size" below) are detected in two input files, files that appear
                       first have priority over files that appear later.''')
    parser.add_argument('--buff-size', default=50000, type=int, help='''Two SVs are determined as redundant
                        if the genomic distances between both side of the breakpoints are less than this span.
                        (bp)''')
    parser.add_argument('--cache-10k', help='''Path to the folder containing the pre-calculated expected
                        interaction frequencies at the 10kb resolution.''')
    parser.add_argument('--cache-5k', help='''Path to the folder containing the pre-calculated expected
                        interaction frequencies at the 5kb resolution.''')
    parser.add_argument('--output-format', default='full', choices=['full', 'NeoLoopFinder'],
                        help='''Format of the reported SVs. full: 8 columns will be reported for each
                        SV, information includes breakpoint coordinates and probability values of each
                        fusion type (++/+-/-+/--); NeoLoopFinder: 6-column format that can be directly
                        used as the NeoLoopFinder input.''')
    parser.add_argument('-O', '--output-file', help='''Output file name''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: behave as if the user asked for help.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


# Exclusive lower bounds on the normalized sequencing depth, checked from the
# deepest tier down, paired with the sub-folder (under "<data>/bulk") holding
# the CNN weights used above that bound.
_DEPTH_TIERS = (
    (300000000, '300M-800M'),
    (200000000, '200M-300M'),
    (100000000, '100M-200M'),
    (50000000, '50M-100M'),
    (10000000, '10M-50M'),
)
# Folder used when the depth does not exceed any bound in _DEPTH_TIERS.
_LOWEST_DEPTH_TIER = '5M-10M'

# Weight files that are always loaded first from the selected folder.
_FIXED_WEIGHT_FILES = (
    'CNN-weights.0.1.0.4.0.6.h5',
    'CNN-weights.0.1.0.8.1.0.h5',
    'CNN-weights.0.2.0.6.0.6.h5',
)

# Annotation reported for an intra-chromosomal SV, keyed by strand combination.
_INTRA_ANNOTATIONS = {
    '+-': 'deletion',
    '-+': 'duplication',
    '++': 'inversion',
    '--': 'inversion',
}


def _estimate_seq_depth(clr):
    """Return the bin-count-normalized total contact count of a cooler matrix."""
    info = clr.info
    if 'sum' in info:
        depth = info['sum']
    elif info['nnz'] < 100000000:
        # Small enough to sum the pixel table directly.
        depth = clr.pixels()[:]['count'].values.sum()
    else:
        # the coefficient was estimated from 90 cool files at 10kb using linear model
        # (it is applied as-is to whichever matrix is passed in)
        depth = info['nnz'] * 2.14447683

    # Rescale by the ratio between a fixed reference bin count (303104) and
    # the number of bins of this matrix.
    return (303104 / len(clr.bins())) * depth


def _select_cnn_folder(cnn_root_folder, seq_depth):
    """Return the folder of bulk CNN weights matching a normalized depth."""
    for lower_bound, tier in _DEPTH_TIERS:
        if seq_depth > lower_bound:
            return os.path.join(cnn_root_folder, 'bulk', tier)
    return os.path.join(cnn_root_folder, 'bulk', _LOWEST_DEPTH_TIER)


def _load_cnn_models(cnn_folder, create_cnn):
    """Build the list of CNN models from the weight files in *cnn_folder*."""
    import glob

    models = [create_cnn(os.path.join(cnn_folder, wf)) for wf in _FIXED_WEIGHT_FILES]
    # NOTE(review): this pattern also matches the fixed files above, so they
    # are loaded a second time. Kept as in the original because the intended
    # ensemble weighting cannot be confirmed from this file.
    models.extend(create_cnn(wf)
                  for wf in glob.glob(os.path.join(cnn_folder, 'CNN-weights.*.h5')))
    return models


def _chrom_label(name):
    """Return *name* without a leading 'chr' prefix (unchanged if absent)."""
    # str.lstrip('chr') would strip any run of the characters c/h/r, not the prefix.
    return name[3:] if name.startswith('chr') else name


def _write_sv_calls(output_file, cache_files, output_format):
    """Concatenate the cached per-job results into the final output file."""
    strands = ['++', '+-', '-+', '--']
    with open(output_file, 'w') as out:
        if output_format == 'full':
            out.write('\t'.join(['chrom1', 'pos1', 'chrom2', 'pos2', '++', '+-', '-+', '--'])+'\n')
        for cache_fil in sorted(cache_files):
            with open(cache_fil, 'r') as source:
                for line in source:
                    c1, p1, c2, p2, prob1, prob2, prob3, prob4 = line.rstrip().split()
                    probs = [float(prob1), float(prob2), float(prob3), float(prob4)]
                    if output_format == 'full':
                        out.write('{0}\t{1}\t{2}\t{3}\t{4:.4g}\t{5:.4g}\t{6:.4g}\t{7:.4g}\n'.format(
                            c1, p1, c2, p2, *probs))
                        continue
                    # NeoLoopFinder format: one row per strand combination
                    # whose probability exceeds 0.5.
                    for strand, prob in zip(strands, probs):
                        if not prob > 0.5:
                            continue
                        if c1 == c2:
                            annot = _INTRA_ANNOTATIONS[strand]
                        else:
                            annot = 'translocation'
                        out.write('\t'.join([c1, c2, strand, p1, p2, annot])+'\n')


def run():
    """Merge SV calls from several input files and re-score them.

    Steps, in order:

    1. Parse the command line (``getargs``).
    2. Open the 10kb and 5kb cooler matrices, estimate their normalized
       sequencing depths and load the matching CNN models.
    3. Merge the input SV lists; an SV from a later file is dropped when
       ``check_in`` finds it within ``--buff-size`` of an SV already kept.
    4. Run ``intraFiltering`` and ``interFiltering`` from ``eaglec.scoreUtils``,
       which use a job-specific cache folder created in the working directory.
    5. If the number of cache files listed afterwards equals the sum of the
       two values returned by the filtering calls, write the output file in
       the requested format; otherwise nothing is written.
    """
    # Parse Arguments
    args, commands = getargs()
    # Improve the performance if you don't want to run it
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    # Heavy dependencies are imported lazily, only when real work is done.
    import cooler
    from eaglec.CNN import create_cnn
    from eaglec.scoreUtils import interFiltering, intraFiltering
    from eaglec.utilities import list_intra_cache, list_inter_cache

    clr_10k = cooler.Cooler(args.hic_10k)
    clr_5k = cooler.Cooler(args.hic_5k)
    cnn_root_folder = os.path.join(os.path.split(eaglec.__file__)[0], 'data')

    # load CNN models for 10k matrices
    seq_depth_10k = _estimate_seq_depth(clr_10k)
    cnn_models_10k = _load_cnn_models(
        _select_cnn_folder(cnn_root_folder, seq_depth_10k), create_cnn)

    # load CNN models for 5k matrices
    seq_depth_5k = _estimate_seq_depth(clr_5k)
    cnn_models_5k = _load_cnn_models(
        _select_cnn_folder(cnn_root_folder, seq_depth_5k), create_cnn)

    # cache folder; its name serves as the job identifier
    sv_file_names = '.'.join([os.path.split(f)[1] for f in args.full_sv_files])
    cache_folder = '.{0}.{1}.{2}.{3}.{4}'.format(clr_10k.info['nnz'],
                                                clr_5k.info['nnz'],
                                                args.balance_type,
                                                sv_file_names,
                                                args.buff_size)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(cache_folder, exist_ok=True)

    # Merge the input files; earlier files take priority over later ones.
    SV_pool = [load_sv_full(fil) for fil in args.full_sv_files]
    SVs = SV_pool[0]
    for query_SVs in SV_pool[1:]:
        for sv in query_SVs:
            if not check_in(SVs, sv, buff_size=args.buff_size):
                SVs.append(sv)

    # Group the merged SVs by chromosome pair.
    bychroms = {}
    for c1, p1, c2, p2, prob1, prob2, prob3, prob4 in SVs:
        bychroms.setdefault((c1, c2), []).append((p1, p2, prob1, prob2, prob3, prob4))

    # Name of the balancing-weight column passed to the filters (False for 'Raw').
    balance = {'CNV': 'sweight', 'ICE': 'weight'}.get(args.balance_type, False)

    # Chromosomes to include: all of them when "--chroms" is empty, otherwise
    # numeric labels when '#' is listed, plus any label listed explicitly.
    chroms = []
    for c in clr_10k.chromnames:
        chromlabel = _chrom_label(c)
        if (not args.chroms) or (chromlabel.isdigit() and '#' in args.chroms) or (chromlabel in args.chroms):
            chroms.append(c)

    # filtering intra- and inter-chromosomal SVs
    intra_expected_count = intraFiltering(clr_10k, clr_5k, cnn_models_10k, cnn_models_5k,
                                          seq_depth_10k, seq_depth_5k, chroms, cache_folder,
                                          bychroms, balance=balance, width=10, cache_5k=args.cache_5k,
                                          cache_10k=args.cache_10k)
    inter_expected_count = interFiltering(clr_10k, clr_5k, cnn_models_10k, cnn_models_5k,
                                          seq_depth_10k, seq_depth_5k, chroms, cache_folder,
                                          bychroms, balance=balance, width=10)
    inter_collect = list_inter_cache(cache_folder)
    intra_collect = list_intra_cache(cache_folder)
    collect = inter_collect + intra_collect

    # Only write the final file once every expected cache file is present.
    if (inter_expected_count + intra_expected_count) == len(collect):
        _write_sv_calls(args.output_file, collect, args.output_format)


def check_in(pool, sv, buff_size=50000):
    """Tell whether *sv* is redundant with any SV already in *pool*.

    Only the first four fields of each record are used: (chrom1, pos1,
    chrom2, pos2). Before comparison, both sides of a record are swapped
    when chrom1 sorts after chrom2, so the two breakpoints are compared in
    a consistent order.

    Parameters
    ----------
    pool : iterable of sequences
        Reference SV records.
    sv : sequence
        Query SV record.
    buff_size : int
        Two SVs match when they involve the same chromosome pair and both
        breakpoint distances are strictly smaller than this value.

    Returns
    -------
    bool
        True if a matching record exists in *pool*, False otherwise.
    """
    def _canonical(record):
        # Order the two breakpoints by chromosome name.
        chrom_a, pos_a, chrom_b, pos_b = record[:4]
        if chrom_a > chrom_b:
            return chrom_b, pos_b, chrom_a, pos_a
        return chrom_a, pos_a, chrom_b, pos_b

    q_c1, q_p1, q_c2, q_p2 = _canonical(sv)
    for candidate in pool:
        r_c1, r_p1, r_c2, r_p2 = _canonical(candidate)
        if (q_c1, q_c2) != (r_c1, r_c2):
            continue
        if abs(q_p1 - r_p1) < buff_size and abs(q_p2 - r_p2) < buff_size:
            return True

    return False

def load_sv_full(fil):
    """Read an SV file in the 8-column "full" format.

    The first line is treated as a header and skipped. Every following
    non-blank line must hold eight whitespace-separated fields:
    chrom1, pos1, chrom2, pos2 and four probability values. Blank lines
    (for example a trailing empty line) are ignored.

    Parameters
    ----------
    fil : str
        Path to the input file.

    Returns
    -------
    list of tuple
        One ``(c1, p1, c2, p2, prob1, prob2, prob3, prob4)`` tuple per record,
        with positions converted to int and probabilities to float.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly eight fields, or a field
        cannot be converted to a number.
    """
    SVs = []
    with open(fil, 'r') as source:
        source.readline()  # skip the header line
        # Data starts on the second line of the file.
        for lineno, line in enumerate(source, start=2):
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 8:
                raise ValueError(
                    '{0}, line {1}: expected 8 columns, found {2}'.format(fil, lineno, len(fields)))
            c1, p1, c2, p2 = fields[:4]
            probs = tuple(float(value) for value in fields[4:])
            SVs.append((c1, int(p1), c2, int(p2)) + probs)

    return SVs


# Script entry point: run the pipeline only when this file is executed directly.
if __name__ == '__main__':
    run()