#!python

# Created on Tue Sep 8 20:45:18 2020
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, logging.handlers, eaglec, os, subprocess, glob, time

# Version string of the installed eaglec package; reported by ``--version`` in getargs().
currentVersion = eaglec.__version__


def getargs():
    """Build the command-line interface and parse ``sys.argv``.

    ``--hic-5k``, ``--hic-10k``, ``--hic-50k`` and ``--output-prefix`` are
    required: the pipeline in ``run()`` uses all of them unconditionally, and
    a missing prefix would otherwise silently produce files named ``None.*``.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The argument list that was parsed (a copy of ``sys.argv[1:]``). When
        the script is called without arguments, ``-h`` is parsed instead, so
        argparse prints the help message and exits.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Predict and combine SV predictions from contact maps
                                     at 5kb, 10kb, and 50kb resolutions.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Inputs/outputs used by every pipeline step: mandatory. argparse only
    # enforces ``required`` after parsing, so -h/-v still work on their own.
    parser.add_argument('--hic-5k', required=True, help='''Path to a 5kb matrix in .cool format''')
    parser.add_argument('--hic-10k', required=True, help='''Path to a 10kb matrix in .cool format''')
    parser.add_argument('--hic-50k', required=True, help='''Path to a 50kb matrix in .cool format''')
    parser.add_argument('-O', '--output-prefix', required=True, help='''Prefix of the output file names''')
    parser.add_argument('-g', '--genome', default='hg38', choices=['hg38', 'hg19', 'chm13', 'other'],
                        help='''Reference genome name.''')
    parser.add_argument('-C', '--chroms', nargs='*', default=['#', 'X'],
                        help='List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'CNV', 'Raw'],
                        help='''Normalization method. If you choose ICE, make sure you have run
                        "cooler balance" on your Hi-C matrix before you run this command; If you
                        choose CNV, make sure you have run "correct-cnv" of the NeoLoopFinder toolkit
                        before you run this command.''')
    parser.add_argument('--output-format', default='full', choices=['full', 'NeoLoopFinder'],
                        help='''Format of the reported SVs. full: 8 columns will be reported for each
                        SV, information includes breakpoint coordinates and probability values of each
                        fusion type (++/+-/-+/--); NeoLoopFinder: 6-column format that can be directly
                        used as the NeoLoopFinder input.''')
    parser.add_argument('--prob-cutoff-5k', default=0.8, type=float, help='''Probability threshold
                        for filtering 5kb SVs.''')
    parser.add_argument('--prob-cutoff-10k', default=0.8, type=float, help='''Probability threshold
                        for filtering 10kb SVs.''')
    parser.add_argument('--prob-cutoff-50k', default=0.99999, type=float, help='''Probability threshold
                        for filtering 50kb SVs.''')
    parser.add_argument('--logFile', default='eaglec.log', help='''Logging file name.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help instead of a "required" error.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


def _sv_path(prefix, tag):
    """Return the SV table path ``<prefix>.CNN_SVs.<tag>.txt`` shared by all pipeline steps."""
    return '{0}.CNN_SVs.{1}.txt'.format(prefix, tag)


def _wait_for_files(paths, interval=30, max_checks=100):
    """Poll until every file in *paths* exists.

    Sleeps *interval* seconds between checks and gives up after
    ``max_checks + 1`` sleeps. Returns True if all files were found,
    False on timeout.
    """
    for _ in range(max_checks + 1):
        if all(os.path.exists(p) for p in paths):
            return True
        time.sleep(interval)
    return all(os.path.exists(p) for p in paths)


def _setup_root_logger(log_file):
    """Attach an INFO-level console handler and an INFO-level *log_file* handler to the root logger."""
    logger = logging.getLogger()
    logger.setLevel(10)
    formatter = logging.Formatter(fmt='%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt='%m/%d/%y %H:%M:%S')
    for handler in (logging.StreamHandler(), logging.FileHandler(log_file)):
        handler.setLevel('INFO')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def _log_arguments(logger, args):
    """Log every parsed command-line setting as one multi-line INFO record."""
    arglist = ['# ARGUMENT LIST:',
               '# Cool URI at 5kb = {0}'.format(args.hic_5k),
               '# Cool URI at 10kb = {0}'.format(args.hic_10k),
               '# Cool URI at 50kb = {0}'.format(args.hic_50k),
               '# Balance Type = {0}'.format(args.balance_type),
               '# Reference Genome = {0}'.format(args.genome),
               '# Included Chromosomes = {0}'.format(args.chroms),
               '# Probability Cutoff for 5kb SVs = {0}'.format(args.prob_cutoff_5k),
               '# Probability Cutoff for 10kb SVs = {0}'.format(args.prob_cutoff_10k),
               '# Probability Cutoff for 50kb SVs = {0}'.format(args.prob_cutoff_50k),
               '# Output File Prefix = {0}'.format(args.output_prefix),
               '# Output Format = {0}'.format(args.output_format),
               '# Log file name = {0}'.format(args.logFile)
               ]
    logger.info('\n' + '\n'.join(arglist))


def run():
    """Run the full multi-resolution SV prediction pipeline.

    Each step is executed as a separate command-line tool:

    1. ``predictSV-single-resolution`` on the 5kb, 10kb and 50kb matrices,
       writing ``<prefix>.CNN_SVs.{5K,10K,50K}.txt``.
    2. The 10kb and 50kb calls are re-located on the 5kb matrix
       (``--low-resolution-breaks``), writing ``<prefix>.CNN_SVs.{10K,50K}_highres.txt``.
    3. ``merge-multiple-resolutions`` combines the 5kb and high-res tables into
       ``<prefix>.CNN_SVs.5K_combined.txt``.

    Commands are passed as argument lists (no shell), so paths containing
    spaces or shell metacharacters are forwarded verbatim.

    Raises
    ------
    subprocess.CalledProcessError
        If any step exits with a non-zero status.
    RuntimeError
        If the 10kb or 5kb cache folder cannot be found in the log file.
    """
    args, commands = getargs()
    # Help/version requests need no logging setup (and must not create a log file).
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    logger = _setup_root_logger(args.logFile)
    _log_arguments(logger, args)

    prefix = args.output_prefix
    # Each chromosome label is its own argv item; an empty list yields a bare
    # "-C", which asks the sub-tool to include all chromosomes.
    chrom_args = ['-C'] + list(args.chroms)

    # Step 1: independent predictions at each resolution.
    for res, uri, cutoff in zip([5000, 10000, 50000],
                                [args.hic_5k, args.hic_10k, args.hic_50k],
                                [args.prob_cutoff_5k, args.prob_cutoff_10k, args.prob_cutoff_50k]):
        logger.info('Predict SVs at {0}kb resolution ...'.format(res // 1000))
        command = (['predictSV-single-resolution', '-H', uri, '--balance-type', args.balance_type,
                    '-O', _sv_path(prefix, '{0}K'.format(res // 1000)), '--genome', args.genome,
                    '--output-format', 'full']
                   + chrom_args
                   + ['--prob-cutoff', str(cutoff), '--logFile', args.logFile])
        subprocess.check_call(command)

    # The single-resolution runs record their cache folders in the shared log file.
    cache_10k = extract_cache_folder(args.logFile, res=10000)
    cache_5k = extract_cache_folder(args.logFile, res=5000)
    for res, folder in ((10000, cache_10k), (5000, cache_5k)):
        if folder is None:
            raise RuntimeError('Could not find the {0}kb cache folder in log file {1!r}'
                               .format(res // 1000, args.logFile))

    # Step 2: refine low-resolution breakpoints on the 5kb matrix.
    refinements = (('10K', '25000', 'Locate 10kb SV coordinates on the 5kb matrix ...'),
                   ('50K', '110000', 'Locate 50kb SV coordinates on the 5kb matrix ...'))
    for tag, region_size, message in refinements:
        logger.info(message)
        low_res = _sv_path(prefix, tag)
        needed = [low_res, _sv_path(prefix, '5K')]
        if not _wait_for_files(needed):
            logger.warning('Timed out waiting for: {0}'.format(', '.join(needed)))
        command = (['predictSV-single-resolution', '-H', args.hic_5k, '--balance-type', args.balance_type,
                    '-O', _sv_path(prefix, tag + '_highres'), '--genome', args.genome,
                    '--low-resolution-breaks', low_res, '--region-size', region_size]
                   + chrom_args
                   + ['--output-format', 'full', '--prob-cutoff', '0', '--logFile', args.logFile,
                      '--cache-folder', cache_5k])
        subprocess.check_call(command)

    # Step 3: merge all resolutions.
    logger.info('Merge SVs from 5K, 10K, and 50K ...')
    highres = [_sv_path(prefix, '10K_highres'), _sv_path(prefix, '50K_highres')]
    if not _wait_for_files(highres):
        logger.warning('Timed out waiting for: {0}'.format(', '.join(highres)))
    command = (['merge-multiple-resolutions', '--hic-10k', args.hic_10k, '--hic-5k', args.hic_5k,
                '--balance-type', args.balance_type]
               + chrom_args
               + ['--full-sv-files', _sv_path(prefix, '5K')] + highres
               + ['-O', _sv_path(prefix, '5K_combined'), '--buff-size', '50000',
                  '--output-format', args.output_format,
                  '--cache-10k', cache_10k, '--cache-5k', cache_5k])
    subprocess.check_call(command)

    logger.info('Done')

def extract_cache_folder(logfil, res=10000):
    """Find the cache folder announced in a log file for one resolution.

    Scans *logfil* for lines containing the message that announces where the
    intermediate results at ``res // 1000`` kb are cached. For each matching
    line, the text after that message is taken, with surrounding whitespace
    stripped.

    Parameters
    ----------
    logfil : str
        Path of the log file to scan.
    res : int
        Resolution in base pairs (e.g. 5000, 10000).

    Returns
    -------
    str or None
        The folder from the last matching line in the file, or None if no
        line matches.
    """
    # The marker text, including its spelling, must match the message written
    # to the log by the sub-tools (presumably predictSV-single-resolution).
    marker = 'Interemediate results at the {0}kb resolution will be cached to'.format(res // 1000)
    with open(logfil, 'r') as handle:
        hits = [ln.rstrip().split(marker)[1].strip() for ln in handle if marker in ln]
    return hits[-1] if hits else None
            

# Entry point when this file is executed directly as a script.
if __name__ == '__main__':
    run()