#!python

# Created on Wed Mar 16 23:57:18 2022
# Author: XiaoTao Wang

## Required modules
import argparse, sys, eaglec, os

# Package version string; reported by the -v/--version option built in getargs().
currentVersion = eaglec.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv[1:]``.

    When the program is invoked without any arguments, ``-h`` is injected so
    that the help text is printed instead of an error about missing options.

    Returns
    -------
    args : argparse.Namespace
        Parsed options: ``sv_file``, ``output_file``, ``buff_size``,
        ``skip_rows``, ``ensembl_release`` and ``species``.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` when that was empty).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Annotate gene fusion events for a list of
                                     SV breakpoints''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Both paths are opened unconditionally by run(); making them required
    # gives a clear argparse error instead of a TypeError from open(None).
    # argparse still processes -h/-v (and exits) before checking required options.
    parser.add_argument('--sv-file', required=True, help='''Path to a TXT file containing breakpoint coordinate
                        information (chrom1, pos1, chrom2, pos2) in the first four columns. Different
                        columns in the file must be separated by tab or space.''')
    parser.add_argument('--output-file', required=True, help='''Output file name. The program will append a column
                        containing detected gene fusions to the input SV file. If multiple fusion events
                        are detected, they will be separated by ",". If no fusion events are detected for
                        an SV, that row will not be reported in the output file.''')
    parser.add_argument('--buff-size', type=int, default=10000, help='''Genomic span (in base pair)
                        of the breakpoints for each SV. A gene will be considered at the breakpoint
                        if its interval is overlapped with the extended breakpoint region.''')
    parser.add_argument('--skip-rows', default=0, type=int,
                        help='''Number of leading lines in the SV file to skip.''')
    parser.add_argument('--ensembl-release', default=93, type=int,
                        help='''Ensembl release number. Refer to https://github.com/openvax/pyensembl
                        for details.''')
    parser.add_argument('--species', default='human',
                        help='''Species name of your sample. Refer to https://github.com/openvax/pyensembl
                        for details.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)

    # A negative span would make the query window's start exceed its end.
    if args.buff_size < 0:
        parser.error('--buff-size must be a non-negative integer')
    
    return args, commands

def _normalize_contig(chrom):
    """Return *chrom* without a literal leading ``chr`` prefix (``chr1`` -> ``1``).

    ``str.lstrip('chr')`` is deliberately avoided: it strips any leading run of
    the characters c/h/r, not the three-character prefix.
    """
    return chrom[3:] if chrom.startswith('chr') else chrom


def _coding_genes_near(db, chrom, pos, buff):
    """Return sorted names of protein-coding genes overlapping ``pos +/- buff``.

    The window start is clamped at 1 so that breakpoints closer than *buff* to
    the chromosome start never yield a coordinate below 1 (Ensembl coordinates
    are 1-based). Sorting makes the caller's output order deterministic.
    """
    genes = db.genes_at_locus(_normalize_contig(chrom), max(1, pos - buff), pos + buff)
    return sorted({g.gene_name for g in genes if g.biotype == 'protein_coding'})


def _fusion_names(genes_1, genes_2):
    """Return ``GENE1-GENE2`` labels for every cross pair of distinct gene names."""
    return ['-'.join([g1, g2]) for g1 in genes_1 for g2 in genes_2 if g1 != g2]


def run():
    """Annotate each SV in ``--sv-file`` with candidate gene fusions.

    For every row after the first ``--skip-rows`` lines, the first four
    whitespace-separated columns are read as chrom1, pos1, chrom2, pos2. The
    protein-coding genes overlapping ``pos +/- --buff-size`` are collected at
    both breakpoints. Every cross pair of distinct genes is reported as
    ``GENE1-GENE2``. Rows with at least one fusion are written to
    ``--output-file``: the original columns joined by tabs, plus a
    comma-separated fusion column. Rows without fusions are dropped.
    """
    # Parse Arguments
    args, commands = getargs()
    # Help/version requests normally make argparse exit inside getargs().
    # The guard also guarantees the slow pyensembl import below is never paid
    # for them.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    from pyensembl import EnsemblRelease

    # Create/Load Ensembl Release cache ...
    db = EnsemblRelease(args.ensembl_release, species=args.species)
    db.download()
    db.index()

    buff = args.buff_size
    # The input is opened first, so a missing/unreadable SV file no longer
    # creates or truncates the output file.
    with open(args.sv_file, 'r') as source, open(args.output_file, 'w') as out:
        for i, line in enumerate(source):
            if i < args.skip_rows:
                continue
            parse = line.rstrip().split()
            # Skip blank/whitespace-only lines (e.g. a trailing empty line)
            # instead of failing on the four-column unpack below.
            if not parse:
                continue
            c1, p1, c2, p2 = parse[:4]
            p1, p2 = int(p1), int(p2)
            genes_1 = _coding_genes_near(db, c1, p1, buff)  # breakpoint 1
            genes_2 = _coding_genes_near(db, c2, p2, buff)  # breakpoint 2
            fusions = _fusion_names(genes_1, genes_2)
            if fusions:
                parse.append(','.join(fusions))
                out.write('\t'.join(parse) + '\n')

# Run the command-line tool when this file is executed as a script.
if __name__ == '__main__':
    run()