#!python

# Created on Tue Sep 8 20:45:18 2020
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, logging.handlers, traceback, eaglec, os, subprocess, glob

# Package version string; shown by the -v/--version option built in getargs().
currentVersion = eaglec.__version__


def getargs():
    """Parse the command line of this script.

    If the script is invoked without any argument, ``-h`` is substituted so the
    usage text is printed instead.  ``-h``/``--help`` and ``-v``/``--version``
    make argparse print and exit while parsing, so this function only returns
    for an actual prediction run.

    Returns:
        tuple: ``(args, commands)`` -- the parsed ``argparse.Namespace`` and
        the list of raw argument strings that was parsed.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Predict SVs at single resolution.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # The input matrix and the output file are needed by every real run (the
    # output file is only opened after all predictions are done), so let
    # argparse reject a missing value up front instead of failing late.
    parser.add_argument('-H', '--hic', required=True,
                        help='''Path to a contact matrix in .cool format. Only matrices at <=50kb and 500kb are supported.''')
    parser.add_argument('-O', '--output-file', required=True, help='''Output file name''')
    parser.add_argument('-g', '--genome', default='hg38', choices=['hg38', 'hg19', 'chm13', 'other'],
                        help='''Reference genome name.''')
    parser.add_argument('-C', '--chroms', nargs='*', default=['#', 'X'],
                        help='List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'CNV', 'Raw'],
                        help='''Normalization method. If you choose ICE, make sure you have run
                        "cooler balance" on your Hi-C matrix before you run this command; If you
                        choose CNV, make sure you have run "correct-cnv" of the NeoLoopFinder toolkit
                        before you run this command.''')
    parser.add_argument('--low-resolution-breaks', help='''SV breakpoints detected by this program on lower-resolution
                        matrices (the "full" format is required, see "--output-format" below). By default, the program
                        searches for SVs throughout the whole genome; when this file is provided, the program will only search
                        for SVs within the local region (determined by the parameter "--region-size") of each input SV.''')
    parser.add_argument('--region-size', default=100000, type=int, help='''The extended genomic span of the input
                        SV breakpoints (bp). Ignore it if you want to search SVs throughout the whole genome.''')
    parser.add_argument('--output-format', default='full', choices=['full', 'NeoLoopFinder'],
                        help='''Format of the reported SVs. full: 8 columns will be reported for each
                        SV, information includes breakpoint coordinates and probability values of each
                        fusion type (++/+-/-+/--); NeoLoopFinder: 6-column format that can be directly
                        used as the NeoLoopFinder input.''')
    parser.add_argument('--maximum-size', type=int, help='''Maximum genomic distance (in base pairs) between
                        SV breakpoints. For example, if "--maximum-size 1000000" is specified on your command,
                        the program will only consider SVs with breakpoint distance less than 1Mb, and ignore
                        any longer-range SVs and inter-chromosomal translocations.''')
    parser.add_argument('--cache-folder', help='''Path to the folder containing the pre-calculated expected
                        interaction frequencies at each genomic distance for each chromosome.''')
    parser.add_argument('--prob-cutoff', default=0.8, type=float, help='''Probability threshold.''')
    parser.add_argument('--logFile', default='eaglec.log', help='''Logging file name.''')
    parser.add_argument('--add-log-header', action='store_true',
                        help='''Log the full list of argument settings before the run starts.''')

    ## Parse the command-line arguments; a bare invocation prints the help text.
    commands = sys.argv[1:] or ['-h']
    args = parser.parse_args(commands)

    return args, commands


# Model sub-folders keyed by sequencing depth.  Each tuple is
# (exclusive lower bound on depth, sub-folder name); tiers are scanned
# top-down and the first bound the depth exceeds wins, otherwise the
# fallback folder is used.
_BULK_MODEL_TIERS = (
    (300000000, '300M-800M'),
    (200000000, '200M-300M'),
    (100000000, '100M-200M'),
    (50000000, '50M-100M'),
    (10000000, '10M-50M'),
)
_BULK_MODEL_FALLBACK = '5M-10M'
# Weight files that must be present in a bulk (<=50kb) model folder.
_BULK_WEIGHT_FILES = ('CNN-weights.0.1.0.4.0.6.h5', 'CNN-weights.0.1.0.8.1.0.h5', 'CNN-weights.0.2.0.6.0.6.h5')

_SC_MODEL_TIERS = (
    (10000000, '10M-20M'),
    (5000000, '5M-10M'),
    (3000000, '3M-5M'),
    (1000000, '1M-3M'),
    (750000, '750K-1M'),
    (500000, '500K-750K'),
    (250000, '250K-500K'),
)
_SC_MODEL_FALLBACK = '100K-250K'
# Weight files that must be present in a single-cell (500kb) model folder.
_SC_WEIGHT_FILES = ('CNN-weights.0.1.0.2.h5', 'CNN-weights.0.6.0.6.h5', 'CNN-weights.1.0.1.0.h5')

# Cache-file labels whose predictions are reported; other labels are skipped.
_REPORTED_LABELS = frozenset(['0,0', '0,1', '1,0', '0,2', '2,0', '1,1'])
# Fusion orientations, in the same order as the four probability columns.
_STRANDS = ('++', '+-', '-+', '--')


def _configure_root_logger(log_file):
    """Attach a console handler and a file handler (*log_file*) to the root logger.

    The root logger is set to level 10 (DEBUG) and both handlers filter at
    INFO with a shared format.  Returns the root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(10)
    formatter = logging.Formatter(fmt='%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt='%m/%d/%y %H:%M:%S')
    for handler in (logging.StreamHandler(), logging.FileHandler(log_file)):
        handler.setLevel('INFO')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def _log_argument_header(logger, args):
    """Log every relevant argument setting as one multi-line INFO record."""
    arglist = ['# ARGUMENT LIST:',
               '# Cool URI = {0}'.format(args.hic),
               '# Balance Type = {0}'.format(args.balance_type),
               '# Reference Genome = {0}'.format(args.genome),
               '# Included Chromosomes = {0}'.format(args.chroms),
               '# Probability Cutoff = {0}'.format(args.prob_cutoff),
               '# Low-resolution SVs = {0}'.format(args.low_resolution_breaks),
               '# Extended Region Size = {0}bp'.format(args.region_size),
               '# Maximum breakpoint distance = {0}'.format(args.maximum_size),
               '# Output File Name = {0}'.format(args.output_file),
               '# Output Format = {0}'.format(args.output_format),
               '# Log file name = {0}'.format(args.logFile)
               ]
    logger.info('\n' + '\n'.join(arglist))


def _total_count(clr):
    """Return the matrix's total count: ``info['sum']`` if recorded, else the sum of all pixel counts."""
    if 'sum' in clr.info:
        return clr.info['sum']
    return clr.pixels()[:]['count'].values.sum()


def _bulk_depth(clr):
    """Return the sequencing depth of a <=50kb matrix rescaled to 303104 bins.

    The log message in run() describes the result as the matched depth of a
    human matrix at 10Kb.
    """
    if 'sum' not in clr.info and clr.info['nnz'] >= 100000000:
        # Avoid summing >=1e8 pixels; the coefficient was estimated from 90
        # cool files at 10kb using linear model.
        depth = clr.info['nnz'] * 2.14447683
    else:
        depth = _total_count(clr)
    return (303104 / len(clr.bins())) * depth


def _pick_tier(depth, tiers, fallback):
    """Return the sub-folder of the first tier whose lower bound *depth* exceeds, else *fallback*."""
    for lower_bound, folder in tiers:
        if depth > lower_bound:
            return folder
    return fallback


def _load_cnn_models(create_cnn, cnn_folder, preferred_files):
    """Load each CNN weight file in *cnn_folder* exactly once.

    The *preferred_files* come first, in the given order. Any further
    ``CNN-weights.*.h5`` files found in the folder are appended in sorted
    order. Previously the glob re-loaded the preferred files a second time,
    which duplicated those members of the ensemble.
    """
    paths = [os.path.join(cnn_folder, wf) for wf in preferred_files]
    seen = {os.path.normpath(p) for p in paths}
    for path in sorted(glob.glob(os.path.join(cnn_folder, 'CNN-weights.*.h5'))):
        key = os.path.normpath(path)
        if key not in seen:
            seen.add(key)
            paths.append(path)
    return [create_cnn(p) for p in paths]


def _select_chroms(chromnames, wanted):
    """Return chromosome names whose label (name minus a leading 'chr') is requested.

    An empty *wanted* selects everything; '#' selects all numeric labels.
    A true prefix removal is used: ``str.lstrip('chr')`` would strip any
    leading 'c'/'h'/'r' characters (e.g. 'hs1' -> 's1').
    """
    selected = []
    for name in chromnames:
        label = name[3:] if name.startswith('chr') else name
        if (not wanted) or (label.isdigit() and '#' in wanted) or (label in wanted):
            selected.append(name)
    return selected


def _sv_annotation(c1, c2, strand):
    """Name the SV type implied by the breakpoint chromosomes and fusion orientation."""
    if c1 != c2:
        return 'translocation'
    if strand == '+-':
        return 'deletion'
    if strand == '-+':
        return 'duplication'
    return 'inversion'  # '++' or '--'


def _format_records(c1, p1, c2, p2, probs, output_format):
    """Return the output lines for one SV call in the requested format.

    'full' yields one 8-column line. Otherwise one NeoLoopFinder line is
    yielded per orientation whose probability exceeds 0.5.
    """
    if output_format == 'full':
        return ['{0}\t{1}\t{2}\t{3}\t{4:.4g}\t{5:.4g}\t{6:.4g}\t{7:.4g}\n'.format(c1, p1, c2, p2, *probs)]
    return ['\t'.join([c1, c2, strand, p1, p2, _sv_annotation(c1, c2, strand)]) + '\n'
            for strand, prob in zip(_STRANDS, probs) if prob > 0.5]


def _write_predictions(out_path, cache_files, prob_cutoff, output_format):
    """Merge the cached prediction files into *out_path*.

    Cache files are read in sorted order. Lines with an unreported label, or
    whose best probability does not exceed *prob_cutoff*, are dropped. Blank
    lines are skipped. Any other line must have exactly 10 whitespace-separated
    fields.
    """
    with open(out_path, 'w') as out:
        if output_format == 'full':
            out.write('\t'.join(['chrom1', 'pos1', 'chrom2', 'pos2', '++', '+-', '-+', '--']) + '\n')
        for cache_fil in sorted(cache_files):
            with open(cache_fil, 'r') as source:
                for line in source:
                    fields = line.split()
                    if not fields:
                        continue
                    c1, p1, c2, p2, label, prob1, prob2, prob3, prob4, _ = fields
                    probs = [float(prob1), float(prob2), float(prob3), float(prob4)]
                    if label not in _REPORTED_LABELS:
                        continue
                    if max(probs) > prob_cutoff:
                        out.writelines(_format_records(c1, p1, c2, p2, probs, output_format))


def run():
    """Predict SVs on a single-resolution .cool matrix (script entry point).

    The steps are:
    1. Configure logging.
    2. Pick the CNN ensemble matching the matrix resolution and sequencing
       depth, and load it.
    3. Run intra- and (unless ``--maximum-size`` is given) inter-chromosomal
       prediction into a job-specific cache folder.
    4. Merge the cached results into ``--output-file`` once every expected
       cache file is present.

    Exits with status 1 if the matrix resolution is unsupported.
    """
    # argparse itself exits on -h/--help/-v/--version, so reaching the next
    # line always means a real prediction run.
    args, _ = getargs()

    logger = _configure_root_logger(args.logFile)
    if args.add_log_header:
        _log_argument_header(logger, args)

    from eaglec.scoreUtils import intraPredict, interPredict
    from eaglec.utilities import list_intra_cache, list_inter_cache
    from eaglec.CNN import create_cnn
    import cooler

    clr = cooler.Cooler(args.hic)
    cnn_root_folder = os.path.join(os.path.split(eaglec.__file__)[0], 'data')
    binsize = clr.binsize  # None for variable-width bins
    if binsize is not None and binsize <= 50000:
        seq_depth = _bulk_depth(clr)
        logger.info('matched sequencing depth in human at 10Kb: {0}'.format(seq_depth))
        cnn_folder = os.path.join(cnn_root_folder, 'bulk',
                                  _pick_tier(seq_depth, _BULK_MODEL_TIERS, _BULK_MODEL_FALLBACK))
        weight_files = _BULK_WEIGHT_FILES
    elif binsize == 500000:
        seq_depth = _total_count(clr)
        if args.genome == 'other':
            logger.info('matched sequencing depth in human at 500kb: {0}'.format(seq_depth))
        cnn_folder = os.path.join(cnn_root_folder, 'single-cells',
                                  _pick_tier(seq_depth, _SC_MODEL_TIERS, _SC_MODEL_FALLBACK))
        weight_files = _SC_WEIGHT_FILES
    else:
        # Without this check the models/depth would be unbound and the run
        # would crash later with an opaque UnboundLocalError.
        logger.error('Unsupported resolution: {0}. Only matrices at <=50kb and 500kb are supported.'.format(binsize))
        sys.exit(1)

    logger.info('Load CNN models from {0} ...'.format(cnn_folder))
    cnn_models = _load_cnn_models(create_cnn, cnn_folder, weight_files)
    logger.info('Done')

    chroms = _select_chroms(clr.chromnames, args.chroms)
    # Name of the cooler weight column used for balancing; False = raw counts.
    balance = {'CNV': 'sweight', 'ICE': 'weight'}.get(args.balance_type, False)

    # The cache folder name is derived from the job parameters (job identity).
    # NOTE(review): this presumably lets repeated or parallel runs of the same
    # job share intermediate results -- confirm in intraPredict/interPredict.
    cache_folder = '.{0}.{1}.{2}.{3}.{4}.{5}'.format(os.path.split(clr.filename)[1],
                                                     clr.info['nnz'],
                                                     args.balance_type,
                                                     os.path.split(str(args.low_resolution_breaks))[1],
                                                     args.region_size,
                                                     args.maximum_size)
    # exist_ok avoids the exists()/mkdir() race between concurrent runs.
    os.makedirs(cache_folder, exist_ok=True)

    if args.low_resolution_breaks is None:
        logger.info('Intermediate results at the {0}kb resolution will be cached to {1}'.format(binsize // 1000, cache_folder))

    # predict intra SVs
    intra_expected_count = intraPredict(clr, cnn_models, chroms, cache_folder, seq_depth,
                                        balance=balance, ref=args.genome, width=10,
                                        extended_size=args.region_size, low_res_fil=args.low_resolution_breaks,
                                        maxsize=args.maximum_size, expected_folder=args.cache_folder)

    # predict inter SVs; --maximum-size excludes inter-chromosomal SVs entirely
    if args.maximum_size is None:
        inter_expected_count = interPredict(clr, cnn_models, chroms, cache_folder, seq_depth, balance=balance,
                                            ref=args.genome, width=10, extended_size=args.region_size,
                                            low_res_fil=args.low_resolution_breaks)
    else:
        inter_expected_count = 0

    collect = list_inter_cache(cache_folder) + list_intra_cache(cache_folder)
    expected_count = inter_expected_count + intra_expected_count
    if expected_count == len(collect):
        _write_predictions(args.output_file, collect, args.prob_cutoff, args.output_format)
    else:
        # Not necessarily an error: another job sharing this cache folder may
        # still be running. Say so instead of silently skipping the output.
        logger.warning('Found {0} cached result files in {1} but expected {2}; {3} was not written.'.format(
            len(collect), cache_folder, expected_count, args.output_file))

    logger.info('Done')


if __name__ == "__main__":
    # Executed as a script (not imported as a module): start the CLI.
    run()