#!python

# Created on Tue Sep 8 20:45:18 2020
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, eaglec, os

# Version string of the installed eaglec package; reported by "-v/--version"
currentVersion = eaglec.__version__

def getargs():
    """Build the command-line interface and parse ``sys.argv``.

    If the program is invoked without any argument, "-h" is injected so
    that argparse prints the help message (and exits).

    Returns
    -------
    args : argparse.Namespace
        The parsed arguments.
    commands : list of str
        The argument strings that were parsed, i.e. ``sys.argv[1:]`` or
        ``['-h']`` when no argument was given.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='''Predict and combine SV predictions from contact maps
                                     at multiple resolutions.''')

    # ---- version ----
    cli.add_argument(
        '-v', '--version',
        action='version',
        version=' '.join(('%(prog)s', currentVersion)),
        help='Print version number and exit.')

    # ---- general options ----
    cli.add_argument(
        '--mcool',
        help='''Path to a mcool file containing contact matrices at multiple resolutions''')
    cli.add_argument(
        '--resolutions',
        default='5000,10000,25000',
        help='''Resolutions at which the contact matrices will be used for SV prediction.''')
    cli.add_argument(
        '--high-res',
        default='1000,2000',
        help='''Resolutions that are exclusively used to detect short-range SVs and
                        refine SV calls identified at coarser resolutions. If a resolution specified
                        by this parameter is not included in the "--resolutions" list, it will be skipped.''')
    cli.add_argument(
        '--balance-type',
        choices=['ICE', 'CNV', 'Raw'],
        default='ICE',
        help = '''Normalization method.''')
    cli.add_argument(
        '--model-path',
        default='EagleC2-models',
        help='''Path to the folder containing pretrained models''')
    cli.add_argument(
        '-O', '--output-prefix',
        help='''Prefix of the output file names''')
    cli.add_argument(
        '-g', '--genome',
        choices = ['hg38', 'hg19', 'chm13', 'other'],
        default='hg38',
        help='''Name of the reference genome.''')
    cli.add_argument(
        '-C', '--chroms',
        nargs = '*',
        default = ['#', 'X'],
        help = '''List of chromosome labels. Only Hi-C data within the specified
                        chromosomes will be included. Specially, "#" stands for chromosomes
                        with numerical labels. "--chroms" with zero argument will include
                        all chromosome data.''')
    cli.add_argument(
        '--entropy-cutoff',
        type=float,
        default=0.97,
        help='''Entropy cutoff.
                        We utilize a pre-filtering procedure based on Shannon entropy before
                        feeding the images into the CNN models. Tuning down this cutoff can
                        greatly accelerate the program but may reduce sensitivity.
                        Value Range: (0, 1]''')
    cli.add_argument(
        '--prob-cutoff-1',
        type=float,
        default=0.5,
        help='''Probability cutoff
                        for filtering the original SV calls. Value Range: [0, 1]''')
    cli.add_argument(
        '--prob-cutoff-2',
        type=float,
        default=0.5,
        help='''Probability cutoff
                        used to determine the more precise coordinates of original SV calls at finer
                        resolutions. Value Range: [0, 1]''')
    cli.add_argument(
        '--allowed-gaps',
        type=int,
        default=2,
        help='''The maximum number of gap bins allowed near the SV breakpoints.''')
    cli.add_argument(
        '-p', '--nproc',
        type=int,
        default=8,
        help='The maximum number of processes to be allocated.')
    cli.add_argument(
        '--cpu',
        action='store_true',
        help='''If this flag is set, the program will run entirely on the CPU,
                        even if a GPU is available.''')

    # ---- options specific to intra-chromosomal interactions ----
    intra_opts = cli.add_argument_group(title = 'Parameters for filtering intra-chromosomal candidates:')
    intra_opts.add_argument(
        '--intra-qvalue-cutoff',
        type=float,
        default=0.03,
        help='''Q-value cutoff
                        for pre-filtering intra-chromosomal pixels before feeding images into the
                        CNN models. Lowering this value can speed up the program but may reduce
                        sensitivity. Value Range: (0, 1]''')
    intra_opts.add_argument(
        '--intra-min',
        type=int,
        default=1,
        help='''The minimum intra-chromosomal interaction frequency. For intra-chromosomal
                        interactions, only those with the interaction strength greater than this value
                        will be considered. Increasing this value may speed up the program but could reduce
                        sensitivity.''')
    intra_opts.add_argument(
        '-k', '--max-dis',
        type=int,
        default=100,
        help='''The maximum genomic distance, measured in pixels, within which 
                        the expected interaction strength is calculated to identify intra-chromosomal
                        significant interactions. Interactions beyond this distance will be assigned
                        the same expected interaction strength as those at this genomic distance.
                        Decreasing this value may speed up the program but could compromise sensitivity.''')
    intra_opts.add_argument(
        '--intra-min-cluster-size',
        type=int,
        default=4,
        help='''The minimum size of clusters for intra-chromosomal interactions. The
                        value will be directly passed to the HDBSCAN clustering algorithm. Setting it
                        to 0 will disable the HDBSCAN clustering.''')
    intra_opts.add_argument(
        '--intra-min-samples',
        type=int,
        default=4,
        help='''The number of samples in a neighbourhood for a point to be considered
                        a core point. The value will be directly passed to the HDBSCAN clustering
                        algorithm. Setting it to 0 will disable the HDBSCAN clustering.''')
    intra_opts.add_argument(
        '--intra-max-cluster-size',
        type=int,
        default=250,
        help='''The maximum size of clusters for intra-chromosomal interactions.''')
    intra_opts.add_argument(
        '--intra-decay-rate',
        type=float,
        default=0.8,
        help='''The decay rate for intra-chromosomal clusters with a size greater
                        than the "--intra-max-cluster-size". Value Range: (0, 1]''')
    intra_opts.add_argument(
        '--intra-niter',
        type=int,
        default=3,
        help='''The number of HDBSCAN clustering iterations. During each iteration,
                        large clusters will shrink at a rate determined by "--intra-decay-rate" by
                        discarding interactions with weaker signals.''')
    intra_opts.add_argument(
        '--intra-extend-size',
        default='2,2,2',
        help='''The number of units extended along each axis for each intra-chromosomal
                        candidate. The values should be separated by commas, with each value corresponding
                        to a resolution specified by the "--resolutions" parameter. Increasing a value expands
                        the prediction to include more candidate pixels, potentially enhancing sensitivity but
                        significantly slowing down the program.''')

    # ---- options specific to inter-chromosomal interactions ----
    inter_opts = cli.add_argument_group(title = 'Parameters for filtering inter-chromosomal candidates:')
    inter_opts.add_argument(
        '--inter-qvalue-cutoff',
        type=float,
        default=0.03,
        help='''Q-value cutoff
                        for pre-filtering inter-chromosomal pixels before feeding images into the
                        CNN models. Value Range: (0, 1]''')
    inter_opts.add_argument(
        '--inter-min-per',
        type=float,
        default=50,
        help='''Percentile of the minimum inter-chromosomal interaction frequency.
                        Only inter-chromosomal interactions with an interaction strength higher than
                        this value will be considered. Value Range: [0, 100]''')
    inter_opts.add_argument(
        '--inter-min-cluster-size',
        type=int,
        default=4,
        help='''The minimum size of clusters for inter-chromosomal interactions. The
                        value will be directly passed to the HDBSCAN clustering algorithm. Setting it
                        to 0 will disable the HDBSCAN clustering.''')
    inter_opts.add_argument(
        '--inter-min-samples',
        type=int,
        default=4,
        help='''This parameter functions similarly to "--intra-min-samples", but
                        applies to inter-chromosomal interactions. Setting it to 0 will disable the
                        HDBSCAN clustering.''')
    inter_opts.add_argument(
        '--inter-max-cluster-size',
        type=int,
        default=250,
        help='''The maximum size of clusters for inter-chromosomal interactions.''')
    inter_opts.add_argument(
        '--inter-decay-rate',
        type=float,
        default=0.8,
        help='''The decay rate for inter-chromosomal clusters with a size greater
                        than the value determined by "--inter-max-cluster-size". Value Range: (0, 1]''')
    inter_opts.add_argument(
        '--inter-niter',
        type=int,
        default=3,
        help='''This parameter functions similarly to "--intra-niter", but
                        applies to inter-chromosomal interactions.''')
    inter_opts.add_argument(
        '--inter-extend-size',
        default='1,1,1',
        help='''This parameter functions similarly to "--intra-extend-size", but
                        applies to inter-chromosomal interactions.''')

    # Fall back to the help message when no argument was supplied
    commands = sys.argv[1:] or ['-h']

    return cli.parse_args(commands), commands


def run():
    """Command-line entry point: predict SVs from a multi-resolution mcool file.

    Steps, in order:

    1. parse the arguments and set up the root logger (console plus
       "<output-prefix>.log");
    2. calculate the expected values at every requested resolution;
    3. search for intra- and inter-chromosomal candidates;
    4. export the candidate images to a fresh cache folder under ".eaglec2"
       (relative to the current working directory);
    5. run the pretrained models on the images and fine-map the calls;
    6. write the final calls to "<output-prefix>.SV_calls.txt".

    Raises
    ------
    SystemExit
        If "--mcool" or "-O/--output-prefix" was not provided.
    ValueError
        If "--intra-extend-size" or "--inter-extend-size" does not provide a
        value for every resolution that needs one.
    """
    # Parse Arguments
    args, commands = getargs()
    # argparse exits by itself on the help/version flags; the guard is kept
    # as a cheap safety net so that nothing heavy is imported for them
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    # Both options are optional for argparse, but the workflow cannot run
    # without them (otherwise files literally named "None.log" are created)
    missing = [flag for flag, value in (('--mcool', args.mcool),
                                        ('-O/--output-prefix', args.output_prefix))
               if value is None]
    if missing:
        sys.exit('error: the following arguments are required: {0}'.format(', '.join(missing)))

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    logfil = '{0}.log'.format(args.output_prefix)
    filehandler = logging.FileHandler(logfil)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    ## Unified Formatter
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Path to mcool = {0}'.format(args.mcool),
               '# Resolutions = {0}'.format(args.resolutions),
               '# High Resolutions = {0}'.format(args.high_res),
               '# Balance Type = {0}'.format(args.balance_type),
               '# Path to CNN models = {0}'.format(args.model_path),
               '# Reference Genome = {0}'.format(args.genome),
               '# Included Chromosomes = {0}'.format(args.chroms),
               '# Entropy Cutoff = {0}'.format(args.entropy_cutoff),
               '# Probability Cutoff 1 = {0}'.format(args.prob_cutoff_1),
               '# Probability Cutoff 2 = {0}'.format(args.prob_cutoff_2),
               '# Allowed Number of Gap Bins = {0}'.format(args.allowed_gaps),
               '# Number of Allocated Processes = {0}'.format(args.nproc),
               '# Output File Prefix = {0}'.format(args.output_prefix),
               '# CPU Only = {0}'.format(args.cpu),
               '# Qvalue Cutoff (Cis) = {0}'.format(args.intra_qvalue_cutoff),
               '# Minimum Cis-contact Strength (Cis) = {0}'.format(args.intra_min),
               '# Maximum Genomic Distance (Cis) = {0}'.format(args.max_dis),
               '# Minimum Cluster Size (Cis) = {0}'.format(args.intra_min_cluster_size),
               '# Minimum Number of Samples (Cis) = {0}'.format(args.intra_min_samples),
               '# Maximum Cluster Size (Cis) = {0}'.format(args.intra_max_cluster_size),
               '# Decay Rate for Large Clusters (Cis) = {0}'.format(args.intra_decay_rate),
               '# HDBSCAN Iterations (Cis) = {0}'.format(args.intra_niter),
               '# Extend Size (Cis) = {0}'.format(args.intra_extend_size),
               '# Qvalue Cutoff (Trans) = {0}'.format(args.inter_qvalue_cutoff),
               '# Minimum Trans-contact Percentile = {0}'.format(args.inter_min_per),
               '# Minimum Cluster Size (Trans) = {0}'.format(args.inter_min_cluster_size),
               '# Minimum Number of Samples (Trans) = {0}'.format(args.inter_min_samples),
               '# Maximum Cluster Size (Trans) = {0}'.format(args.inter_max_cluster_size),
               '# Decay Rate for Large Clusters (Trans) = {0}'.format(args.inter_decay_rate),
               '# HDBSCAN Iterations (Trans) = {0}'.format(args.inter_niter),
               '# Extend Size (Trans) = {0}'.format(args.inter_extend_size)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # read resolutions; done before the heavy imports so that malformed
    # values are reported immediately
    resolutions = [int(r) for r in args.resolutions.split(',')]
    requested_high_res = [int(r) for r in args.high_res.split(',')]
    # high resolutions not listed in "--resolutions" are skipped
    high_res = [r for r in requested_high_res if r in resolutions]
    # inter-chromosomal candidates are only searched at the other resolutions
    inter_resolutions = [r for r in resolutions if r not in high_res]

    # extend sizes are matched positionally with "--resolutions"; verify now
    # that every resolution that will be looked up actually received a value
    intra_extend = _extend_size_map(args.intra_extend_size, resolutions,
                                    resolutions, '--intra-extend-size')
    inter_extend = _extend_size_map(args.inter_extend_size, resolutions,
                                    inter_resolutions, '--inter-extend-size')

    import tensorflow as tf

    if args.cpu:
        # Force TensorFlow to use CPU only
        tf.config.set_visible_devices([], 'GPU')
        logger.info('Running on CPU only.')

    import cooler, time, tempfile, joblib
    from eaglec.utilities import calculate_expected, load_gap
    from eaglec.searchCandidates import select_intra_candidate, select_inter_candidate
    from eaglec.extractMatrix import collect_images
    from eaglec.predictCore import load_models, predict, refine_predictions

    def open_cooler(res):
        # Open the contact matrix stored at the given resolution
        return cooler.Cooler('{0}::resolutions/{1}'.format(args.mcool, res))

    # read chromosomes
    chroms = []
    for c in open_cooler(resolutions[0]).chromnames:
        chromlabel = _strip_chr_prefix(c)
        if (not args.chroms) or (chromlabel.isdigit() and '#' in args.chroms) or (chromlabel in args.chroms):
            chroms.append(c)

    # cache folder
    cache_root = os.path.abspath(os.path.expanduser('.eaglec2'))
    # exist_ok avoids failing when another process creates the folder first
    os.makedirs(cache_root, exist_ok=True)

    tl = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
    cache_folder = tempfile.mkdtemp(suffix=tl, dir=cache_root)

    # balance setting passed to the eaglec routines: 'sweight' for CNV,
    # 'weight' for ICE, and False for raw contacts
    if args.balance_type == 'CNV':
        balance = 'sweight'
    elif args.balance_type == 'ICE':
        balance = 'weight'
    else:
        balance = False

    # expected values
    expected_values = {}
    logger.info('Calculating the expected values ...')
    for res in resolutions:
        logger.info('Resolution {0}'.format(res))
        max_bins = max(200, 2000000//res)
        clr = open_cooler(res)
        expected_values[res] = calculate_expected(clr, chroms, balance, max_bins, nproc=args.nproc)

    joblib.dump(expected_values, os.path.join(cache_folder, 'expected_values.pkl'))

    # intra candidates
    logger.info('Searching for intra-chromosomal candidates ...')
    intra = {}
    for res in resolutions:
        logger.info('Resolution {0}'.format(res))
        clr = open_cooler(res)
        # the call is the same for every resolution except for "highres"
        intra[res] = select_intra_candidate(clr, chroms, balance, expected_values[res], k=args.max_dis,
                                            q_thre=args.intra_qvalue_cutoff, minv=args.intra_min,
                                            min_cluster_size=args.intra_min_cluster_size,
                                            min_samples=args.intra_min_samples,
                                            max_cluster_size=args.intra_max_cluster_size,
                                            decay_rate=args.intra_decay_rate,
                                            niter=args.intra_niter, buff=intra_extend[res],
                                            nproc=args.nproc, highres=(res in high_res))

    intra_counts = count_candidates(intra)
    logger.info('Totally detected {0} intra-chromosomal candidates'.format(intra_counts[1]))
    logger.info('The number of intra-chromosomal candidates by resolution: {0}'.format(intra_counts[0]))

    # inter candidates
    logger.info('Searching for inter-chromosomal candidates ...')
    inter = {}
    for res in inter_resolutions:
        clr = open_cooler(res)
        logger.info('Resolution {0}'.format(res))
        inter[res] = select_inter_candidate(clr, chroms, balance, windows=[3,5], min_per=args.inter_min_per,
                                            q_thre=args.inter_qvalue_cutoff, min_cluster_size=args.inter_min_cluster_size,
                                            min_samples=args.inter_min_samples, max_cluster_size=args.inter_max_cluster_size,
                                            decay_rate=args.inter_decay_rate, niter=args.inter_niter,
                                            buff=inter_extend[res], nproc=args.nproc)

    inter_counts = count_candidates(inter)
    logger.info('Totally detected {0} inter-chromosomal candidates'.format(inter_counts[1]))
    logger.info('The number of inter-chromosomal candidates by resolution: {0}'.format(inter_counts[0]))

    logger.info('The extracted images will be exported to the folder {0}'.format(cache_folder))
    intra_n = collect_images(args.mcool, intra, expected_values, balance, cache_folder,
                             w=15, entropy_cutoff=args.entropy_cutoff, nproc=args.nproc)
    inter_n = collect_images(args.mcool, inter, expected_values, balance, cache_folder,
                             w=15, entropy_cutoff=args.entropy_cutoff, nproc=args.nproc)

    logger.info('Totally collected {0} images'.format(intra_n + inter_n))
    logger.info('Predicting SVs at each resolution ...')
    models = load_models(args.model_path)
    # load gap regions
    gaps = {}
    for res in resolutions:
        clr = open_cooler(res)
        gaps[res] = load_gap(clr, chroms, ref_genome=args.genome, balance=balance)

    original_predictions = predict(cache_folder, models, gaps, prob_cutoff=args.prob_cutoff_1,
                                   max_gap=args.allowed_gaps, batch_size=256)
    logger.info('Done')

    logger.info('Fine-map original SV calls at higher resolutions ...')
    # keep a copy of the inputs of the refinement step in the cache folder
    outfil = os.path.join(cache_folder, 'refine_input_cache.pkl')
    joblib.dump([original_predictions, resolutions, balance, expected_values, gaps, cache_folder],
                outfil, compress=('xz', 3))
    SVs = refine_predictions(original_predictions, resolutions, models, args.mcool, balance,
                             expected_values, gaps, cache_folder, max_gap=args.allowed_gaps,
                             w=15, baseline_prob=args.prob_cutoff_2)

    # one tab-delimited row per SV; the six probability columns use 4 significant digits
    row_template = '{0}\t{1}\t{2}\t{3}\t{4:.4g}\t{5:.4g}\t{6:.4g}\t{7:.4g}\t{8:.4g}\t{9:.4g}\t{10}\t{11}\t{12}\n'
    with open('{0}.SV_calls.txt'.format(args.output_prefix), 'w') as out:
        out.write('\t'.join(['chrom1', 'pos1', 'chrom2', 'pos2', '++', '+-', '-+', '--', '++/--', '+-/-+',
                             'original resolution', 'fine-mapped resolution', 'gap info'])+'\n')
        for line in SVs:
            c1, p1, c2, p2, prob1, prob2, prob3, prob4, prob5, prob6, res1, res2, ng = line
            out.write(row_template.format(c1, p1, c2, p2, prob1, prob2, prob3, prob4,
                                          prob5, prob6, res1, res2, ng))

    logger.info('Done')


def _strip_chr_prefix(name):
    # Remove one literal leading "chr". str.lstrip('chr') is not equivalent:
    # it strips any leading run of the characters 'c', 'h' and 'r'.
    return name[3:] if name.startswith('chr') else name


def _extend_size_map(spec, resolutions, needed, option):
    # Pair the comma-separated extend sizes in `spec` positionally with
    # `resolutions` and make sure every resolution in `needed` got a value.
    sizes = [int(i) for i in spec.split(',')]
    mapping = dict(zip(resolutions, sizes))
    lacking = [str(r) for r in needed if r not in mapping]
    if lacking:
        raise ValueError(
            '"{0}" values are matched positionally with "--resolutions", but no value '
            'was provided for resolution(s): {1}'.format(option, ', '.join(lacking)))

    return mapping

def count_candidates(by_res):
    """Count the candidates collected at each resolution.

    Parameters
    ----------
    by_res : dict
        Maps each resolution to a mapping whose values are sized
        collections of candidates.

    Returns
    -------
    counts : dict
        The number of candidates at each resolution, summed over all values
        of the inner mapping.
    total : int
        The sum of ``counts`` over all resolutions.
    """
    counts = {
        res: sum(len(groups[key]) for key in groups)
        for res, groups in by_res.items()
    }

    return counts, sum(counts.values())


# Start the workflow only when this file is executed as a script
if __name__ == '__main__':
    run()