#!python

# Created on Thu Aug 8 11:49:10 2021
# Author: XiaoTao Wang

## Required modules

import argparse, sys, eaglec, os

# Version string of the installed eaglec package, reported by -v/--version.
currentVersion = eaglec.__version__


def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the script is invoked without any arguments, ``-h`` is substituted
    so the help text is printed instead of failing on missing options.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` when no arguments were given).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Merge multiple SV calls from the same sample.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Inputs and output are mandatory: without them the merge step would fail
    # later with an opaque TypeError (iterating None / open(None)). argparse
    # still honours -h/-v, which exit before required options are checked.
    parser.add_argument('--full-sv-files', nargs='+', required=True,
                        help='''Path to the input SV files in full format
                       (see "--output-format" below). If redundant SVs (the coordinates are close to each
                       other, see "--buff-size" below) are detected in two input files, files that appear
                       first have priority over files that appear later.''')
    parser.add_argument('--buff-size', default=50000, type=int, help='''Two SVs are determined as redundant
                        if the genomic distances between their corresponding breakpoints (on both sides)
                        are less than this span. (bp)''')
    parser.add_argument('--output-format', default='full', choices=['full', 'NeoLoopFinder'],
                        help='''Format of the reported SVs. full: 8 columns will be reported for each
                        SV, information includes breakpoint coordinates and probability values of each
                        fusion type (++/+-/-+/--); NeoLoopFinder: 6-column format that can be directly
                        used as the NeoLoopFinder input.''')
    parser.add_argument('-O', '--output-file', required=True, help='''Output file name''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)
    
    return args, commands


def run():
    """Merge the SV calls given on the command line and write the result.

    All SVs from the first ``--full-sv-files`` file are kept. An SV from a
    later file is appended only when ``check_in`` finds no already-kept SV
    whose breakpoints are both within ``--buff-size`` bp of it, so files
    listed first take priority. The merged list is written to
    ``--output-file`` in the format selected by ``--output-format``.
    """
    # Parse Arguments
    args, commands = getargs()
    # argparse's help/version actions print and exit on their own; this check
    # only guarantees that no merge work is attempted for those flags.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    SVs = _merge_sv_files(args.full_sv_files, args.buff_size)
    with open(args.output_file, 'w') as out:
        if args.output_format == 'full':
            _write_full(out, SVs)
        else:
            _write_neoloopfinder(out, SVs)


# Fusion orientations, in the same order as the four probability columns.
_STRANDS = ('++', '+-', '-+', '--')

# SV type for an intra-chromosomal call, keyed by orientation; any call whose
# breakpoints are on different chromosomes is reported as a translocation.
_INTRA_SV_TYPES = {
    '+-': 'deletion',
    '-+': 'duplication',
    '++': 'inversion',
    '--': 'inversion',
}


def _merge_sv_files(paths, buff):
    """Load every file in *paths* and merge them, earlier files taking priority."""
    SV_pool = [load_sv_full(fil) for fil in paths]
    merged = SV_pool[0]
    for query_SVs in SV_pool[1:]:
        for sv in query_SVs:
            if not check_in(merged, sv, buff_size=buff):
                merged.append(sv)
    return merged


def _write_full(out, SVs):
    """Write *SVs* in the 8-column "full" format, preceded by a header line."""
    out.write('\t'.join(['chrom1', 'pos1', 'chrom2', 'pos2', '++', '+-', '-+', '--'])+'\n')
    for c1, p1, c2, p2, prob1, prob2, prob3, prob4 in SVs:
        out.write('{0}\t{1}\t{2}\t{3}\t{4:.4g}\t{5:.4g}\t{6:.4g}\t{7:.4g}\n'.format(c1, p1, c2, p2, prob1, prob2, prob3, prob4))


def _write_neoloopfinder(out, SVs):
    """Write one 6-column NeoLoopFinder line per orientation with probability > 0.5."""
    for c1, p1, c2, p2, prob1, prob2, prob3, prob4 in SVs:
        for strand, prob in zip(_STRANDS, (prob1, prob2, prob3, prob4)):
            # Written as "> 0.5" (not "<= 0.5 -> skip") so NaN is skipped,
            # matching the previous numpy-based filter.
            if prob > 0.5:
                annot = _INTRA_SV_TYPES[strand] if c1 == c2 else 'translocation'
                out.write('\t'.join([c1, c2, strand, str(p1), str(p2), annot])+'\n')


def _ordered_breakpoints(sv):
    """Return ``sv[:4]`` as (chrom1, pos1, chrom2, pos2) with the smaller
    (chrom, pos) end first, so both orientations of the same pair compare equal."""
    c1, p1, c2, p2 = sv[:4]
    # Comparing (chrom, pos) tuples also orders intra-chromosomal calls by
    # position; comparing chromosomes alone left those in input order.
    if (c1, p1) > (c2, p2):
        return c2, p2, c1, p1
    return c1, p1, c2, p2


def check_in(pool, sv, buff_size=100000):
    """Return True if *pool* already holds an SV redundant with *sv*.

    Two SVs are redundant when both breakpoint ends lie on the same
    chromosomes and each pair of corresponding positions differs by strictly
    less than *buff_size*. The order in which the two ends are listed does not
    matter. Only the first four fields of each SV are read.

    Parameters
    ----------
    pool : iterable of sequence
        Already-accepted SVs, each starting with (chrom1, pos1, chrom2, pos2).
    sv : sequence
        The candidate SV, in the same layout.
    buff_size : int
        Maximum (exclusive) distance in bp between corresponding breakpoints.
        NOTE(review): this default (100000) differs from the CLI default
        --buff-size of 50000; the script always passes the value explicitly.

    Returns
    -------
    bool
    """
    c1, p1, c2, p2 = _ordered_breakpoints(sv)
    for ref_sv in pool:
        ref_c1, ref_p1, ref_c2, ref_p2 = _ordered_breakpoints(ref_sv)
        if (c1 == ref_c1 and c2 == ref_c2
                and abs(p1 - ref_p1) < buff_size
                and abs(p2 - ref_p2) < buff_size):
            return True
    return False

def load_sv_full(fil):
    """Load SVs from a file written in the "full" output format.

    The first line is treated as a header and skipped. Every other non-blank
    line must hold eight whitespace-separated fields: chrom1, pos1, chrom2,
    pos2 and the probabilities of the ++, +-, -+ and -- fusion types.

    Parameters
    ----------
    fil : str
        Path to the SV file.

    Returns
    -------
    list of tuple
        ``(chrom1, pos1, chrom2, pos2, prob1, prob2, prob3, prob4)`` with
        integer positions and float probabilities, in file order.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly eight fields, or a
        position/probability cannot be converted.
    """
    SVs = []
    with open(fil, 'r') as source:
        source.readline()  # skip the header line
        for line in source:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing on the 8-field unpacking below.
                continue
            c1, p1, c2, p2, prob1, prob2, prob3, prob4 = fields
            SVs.append((c1, int(p1), c2, int(p2),
                        float(prob1), float(prob2), float(prob3), float(prob4)))
    
    return SVs


# Script entry point: parse the command line and run the merge.
if __name__ == '__main__':
    run()