#!python

import argparse, sys, eaglec, os

# Version string of the installed eaglec package; reported by `--version`.
currentVersion = eaglec.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When no arguments are given, ``-h`` is injected so that the help text is
    printed (argparse then exits).

    Returns:
        tuple: ``(args, commands)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``commands`` is the list of raw
        command-line tokens that were handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Batch run plot-interSVs/plot-intraSVs for
                                     all SVs outputed from predictSV or predictSV-single-resolution.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # The three inputs below are mandatory: without them the batch run fails
    # later with an obscure error (list_coolers(None), os.path.exists(None),
    # ...), so let argparse report them up front. Marking them required does
    # not affect -h/--version, which exit as soon as they are parsed.
    parser.add_argument('--mcool', required=True, help='''Path to an mcool file.''')
    parser.add_argument('--full-sv-file', required=True, help='''Path to an SV file outputed from predictSV
                        or predictSV-single-resolution in "full" format.''')
    parser.add_argument('-p', '--prob-cutoff', default=0.5, type=float,
                        help='''Only SVs with probability greater than this value will be plotted.''')
    parser.add_argument('--cnv-file', help='''Copy number profile in bigwig format. Optional.''')
    parser.add_argument('-O', '--output-folder', required=True, help='''Name of the output folder.''')
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'Raw', 'CNV'],
                        help='''Balance type forwarded to plot-interSVs/plot-intraSVs.''')
    parser.add_argument('--dpi', default=800, type=int,
                        help='''DPI forwarded to plot-interSVs/plot-intraSVs.''')
    parser.add_argument('--figure-format', default='png',
                        help='''File extension used for the output figures.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help text instead of an error.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

def find_best_resolution(resolutions, size_in_bp, nBins=100, min_bins=20):
    """Pick the resolution whose bin count over a region is closest to *nBins*.

    Parameters
    ----------
    resolutions : iterable of int
        Candidate bin sizes (bp).
    size_in_bp : int
        Length of the region to be plotted (bp).
    nBins : int, optional
        Target number of bins across the region.
    min_bins : int, optional
        Only resolutions that split the region into strictly more than this
        many bins are considered (formerly hard-coded to 20).

    Returns
    -------
    int or None
        The chosen resolution. Ties go to the earliest entry of
        *resolutions*. ``None`` when no resolution yields more than
        *min_bins* bins, so callers must handle that case.
    """
    candidates = [r for r in resolutions if size_in_bp // r > min_bins]
    # min() returns the first minimal element, matching np.argmin tie-breaking.
    return min(candidates, key=lambda r: abs(size_in_bp // r - nBins), default=None)

def _run_plot_command(command):
    """Execute one plotting command; report and skip on failure.

    A failing plot is reported and skipped so the batch continues with the
    remaining SVs. KeyboardInterrupt is deliberately *not* swallowed.
    """
    import subprocess

    try:
        # Pass an argument list without a shell so file paths or chromosome
        # names containing spaces / shell metacharacters are forwarded
        # verbatim instead of being split or interpreted by the shell.
        subprocess.check_call([str(token) for token in command])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the plotting tool exited with non-zero status.
        # OSError: the tool could not be launched (e.g. not found on PATH).
        print('Error occurred, skip')


def _plot_inter(args, chromsizes, resolutions, c1, c2):
    """Plot the contact map of an inter-chromosomal pair via plot-interSVs."""
    outfil = os.path.join(args.output_folder, '{0}_{1}.EagleC.{2}'.format(c1, c2, args.figure_format))
    if os.path.exists(outfil):
        # One figure per chromosome pair; already produced earlier.
        return
    size_in_bp = chromsizes[c1] + chromsizes[c2]
    best_res = find_best_resolution(resolutions, size_in_bp)
    if best_res is None:
        print(c1, c2, 'no suitable resolution found, skip')
        return
    print(c1, c2, 'plotting the contact matrix at resolution: {0}'.format(best_res))
    command = [
        'plot-interSVs', '--cool-uri', '{0}::resolutions/{1}'.format(args.mcool, best_res),
        '--full-sv-file', args.full_sv_file, '-p', str(args.prob_cutoff),
        '-O', outfil, '-C', c1, c2, '--balance-type', args.balance_type,
        '--dpi', str(args.dpi)
    ]
    _run_plot_command(command)


def _plot_intra(args, chromsizes, resolutions, c1, p1, c2, p2):
    """Plot a padded window around an intra-chromosomal SV via plot-intraSVs."""
    outfil = os.path.join(args.output_folder,
                          '{0}_{1}_{2}_{3}.EagleC.{4}'.format(c1, p1, c2, p2, args.figure_format))
    if os.path.exists(outfil):
        return
    # Pad the SV by its own size on each side, clipped to the chromosome.
    sv_size = abs(p2 - p1)
    start_loci = max(0, p1 - sv_size)
    end_loci = min(chromsizes[c1], p2 + sv_size)
    best_res = find_best_resolution(resolutions, end_loci - start_loci)
    if best_res is None:
        # Previously this crashed the whole batch with a TypeError below.
        print(c1, p1, c2, p2, 'no suitable resolution found, skip')
        return
    print(c1, p1, c2, p2, 'plotting the contact matrix at resolution: {0}'.format(best_res))

    # Snap the window boundaries down onto the bin grid of best_res.
    start_loci = start_loci // best_res * best_res
    end_loci = end_loci // best_res * best_res
    command = [
        'plot-intraSVs', '--cool-uri', '{0}::resolutions/{1}'.format(args.mcool, best_res),
        '--full-sv-file', args.full_sv_file, '-p', str(args.prob_cutoff),
    ]
    if args.cnv_file is not None:
        command += ['--cnv-file', args.cnv_file]
    command += [
        '--region', '{0}:{1}-{2}'.format(c1, start_loci, end_loci),
        '-O', outfil, '--balance-type', args.balance_type,
        '--coordinates-to-display', str(p1), str(p2),
        '--dpi', str(args.dpi)
    ]
    _run_plot_command(command)


def run():
    """Plot every SV whose max probability exceeds the cutoff.

    Inter-chromosomal SVs produce one figure per chromosome pair. Intra-
    chromosomal SVs produce one figure per SV. Figures that already exist in
    the output folder are not re-plotted.
    """
    # Parse Arguments
    args, commands = getargs()
    # getargs() already exits on -h/-v; this guard also keeps the heavy
    # imports below from running in that case.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    # Imported lazily so that --help/--version stay fast.
    from eaglec.visualize import load_sv_full
    from cooler.fileops import list_coolers
    import cooler

    resolutions = [int(r.replace('/resolutions/', '')) for r in list_coolers(args.mcool)]
    # Only chromosome sizes are needed from this handle.
    clr = cooler.Cooler('{0}::resolutions/{1}'.format(args.mcool, resolutions[-1]))
    chromsizes = clr.chromsizes

    SVs = load_sv_full(args.full_sv_file)
    # makedirs also creates missing parent directories of a nested path.
    os.makedirs(args.output_folder, exist_ok=True)

    for c1, p1, c2, p2, prob1, prob2, prob3, prob4 in SVs:
        if max(prob1, prob2, prob3, prob4) <= args.prob_cutoff:
            continue
        if c1 != c2:
            _plot_inter(args, chromsizes, resolutions, c1, c2)
        else:
            _plot_intra(args, chromsizes, resolutions, c1, p1, c2, p2)

# Run the batch plotter when this file is executed as a script.
if __name__ == '__main__':
    run()