#!python

# Created on Thu Aug 8 11:49:10 2021
# Author: XiaoTao Wang

## Required modules

import argparse, sys, eaglec, os

currentVersion = eaglec.__version__


def getargs():
    """Parse the command-line arguments of this script.

    If the script is invoked without any arguments, ``-h`` is injected so the
    usage message is printed (and the process exits) instead of failing on
    missing inputs.

    Returns:
        tuple: ``(args, commands)`` -- the parsed ``argparse.Namespace`` and
        the list of raw argument strings that were handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''This script takes the output from the predictSV
                                     command and reformats it into a format suitable for NeoLoopFinder.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Both files are mandatory: without them run() would call open(None) and
    # die with an opaque TypeError. argparse still honours -h/-v because those
    # actions exit before the required-argument check is performed.
    parser.add_argument('--input-file', required=True,
                        help='''Path to a file output by the predictSV command. The first line
                        is treated as a header and skipped. Each remaining line should contain
                        13 columns describing one SV: the breakpoint coordinates, the probability
                        of each fusion type (++, +-, -+, --, ++/--, and +-/-+), the resolution of
                        the contact matrix from which the SV was originally predicted, the finest
                        resolution at which the SV can be mapped, and the number of bad bins near
                        the SV breakpoints.''')
    parser.add_argument('-O', '--output-file', required=True, help='''Output file name.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # Show usage instead of erroring out on the missing required options.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


# Fusion orientations, in the same order as the six probability columns of the
# predictSV output. Combined types ("++/--", "+-/-+") are split into one output
# record per orientation.
_FUSION_TYPES = ('++', '+-', '-+', '--', '++/--', '+-/-+')

# SV type of an intra-chromosomal (c1 == c2) breakpoint pair, keyed by a single
# orientation. Inter-chromosomal pairs are always 'translocation'.
_INTRA_SV_TYPES = {
    '+-': 'deletion',
    '-+': 'duplication',
    '++': 'inversion',
    '--': 'inversion',
}

# Number of whitespace-separated columns per SV record in predictSV output.
_N_COLUMNS = 13


def _read_svs(path):
    """Parse a predictSV output file into a list of SV records.

    The first line is skipped as a header and blank lines are ignored. Every
    other line must have 13 whitespace-separated columns; only the first ten
    (chrom1, pos1, chrom2, pos2 and the six fusion-type probabilities) are
    used here.

    Returns:
        list of ``(c1, p1, c2, p2, probs)`` tuples, where ``p1``/``p2`` are
        ints and ``probs`` is a tuple of six floats ordered like _FUSION_TYPES.

    Raises:
        ValueError: if a line has the wrong number of columns or a position or
        probability column is not numeric.
    """
    svs = []
    with open(path, 'r') as source:
        source.readline()  # header line
        for lineno, line in enumerate(source, start=2):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            if len(fields) != _N_COLUMNS:
                raise ValueError('{0}: line {1}: expected {2} columns, got {3}'.format(
                    path, lineno, _N_COLUMNS, len(fields)))
            c1, p1, c2, p2 = fields[:4]
            probs = tuple(float(x) for x in fields[4:10])
            svs.append((c1, int(p1), c2, int(p2), probs))
    return svs


def _sv_type(c1, c2, strand):
    """Return the NeoLoopFinder SV type for one breakpoint pair and orientation."""
    if c1 != c2:
        return 'translocation'
    return _INTRA_SV_TYPES.get(strand, 'translocation')


def _neoloop_lines(svs):
    """Yield tab-separated NeoLoopFinder records for parsed SVs.

    For each SV the fusion type with the highest probability is selected
    (numpy.argmax picks the first one on ties). Combined types such as
    '++/--' yield one record per orientation. Each record is:
    chrom1, chrom2, orientation, pos1, pos2, SV type.
    """
    import numpy as np

    for c1, p1, c2, p2, probs in svs:
        best = _FUSION_TYPES[int(np.argmax(probs))]
        for strand in best.split('/'):
            annot = _sv_type(c1, c2, strand)
            yield '\t'.join([c1, c2, strand, str(p1), str(p2), annot]) + '\n'


def run():
    """Convert a predictSV output file into a NeoLoopFinder SV list.

    Reads ``--input-file``, keeps the most probable fusion orientation of each
    SV and writes one tab-separated line per orientation to ``--output-file``.
    """
    # getargs() exits on -h/-v itself, so no extra guard is needed here.
    args, _ = getargs()
    # The whole input is read before the output is opened, so passing the same
    # path for both does not truncate the input before it has been read.
    svs = _read_svs(args.input_file)
    with open(args.output_file, 'w') as out:
        out.writelines(_neoloop_lines(svs))

if __name__ == "__main__":
    # Executed directly as a script (not imported): perform the conversion.
    run()