#!python

# Created on Wed Mar 16 23:57:18 2022
# Author: XiaoTao Wang

## Required modules
import argparse, sys, eaglec, os

# Version string of the installed eaglec package; reported by the
# -v/--version command-line option.
currentVersion = eaglec.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the program is invoked without any arguments, ``-h`` is injected so
    that the help message is printed (argparse then exits).

    Returns
    -------
    args : argparse.Namespace
        Parsed options (``sv_file``, ``output_file``, ``buff_size``,
        ``skip_rows``, ``ensembl_release``, ``species``).
    commands : list of str
        The raw argument list that was handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Annotate gene fusion events based on a list
                                     of SV breakpoints''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Both file paths are mandatory: without them the run would only fail
    # later (after the Ensembl data has been downloaded and indexed) with an
    # opaque error from open(None), so let argparse reject the call up front.
    parser.add_argument('--sv-file', required=True,
                        help='''Path to a TXT file containing breakpoint coordinates,
                        with chrom1, pos1, chrom2, and pos2 in the first four columns. Columns must
                        be separated by tabs or spaces.''')
    parser.add_argument('--output-file', required=True,
                        help='''Output file name. A new column containing detected gene
                        fusions will be appended as the last column of each SV record. Multiple fusion
                        events will be separated by commas. SVs without any fusion events will be excluded
                        from the output.''')
    parser.add_argument('--buff-size', type=int, default=10000, help='''Genomic span (in base pairs)
                        to extend around each SV breakpoint. A gene will be considered at the breakpoint
                        if its interval overlaps with the extended region.''')
    parser.add_argument('--skip-rows', default=1, type=int,
                        help='''Number of initial lines to skip in the input file.''')
    parser.add_argument('--ensembl-release', default=93, type=int,
                        help='''Ensembl release number. Refer to https://github.com/openvax/pyensembl
                        for details.''')
    parser.add_argument('--species', default='human',
                        help='''Species name of your sample. Refer to https://github.com/openvax/pyensembl
                        for details.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help message instead of an error.
        commands.append('-h')
    args = parser.parse_args(commands)
    
    return args, commands

def _strip_chr_prefix(chrom):
    # Remove a literal leading "chr" from a chromosome label, if present.
    # (str.lstrip('chr') would instead strip any leading run of the
    # characters 'c', 'h' and 'r', mangling names such as "hs37d5".)
    return chrom[3:] if chrom.startswith('chr') else chrom


def _protein_coding_names(db, chrom, pos, buff):
    # Names of protein-coding genes overlapping [pos - buff, pos + buff].
    genes = db.genes_at_locus(_strip_chr_prefix(chrom), pos - buff, pos + buff)
    return {g.gene_name for g in genes if g.biotype == 'protein_coding'}


def _fusion_labels(genes_1, genes_2):
    # "geneA-geneB" labels for every pair of distinct genes, in sorted order
    # so that the output is reproducible from run to run.
    partners = sorted(genes_2)
    return ['-'.join([g1, g2])
            for g1 in sorted(genes_1)
            for g2 in partners
            if g1 != g2]


def run():
    """Annotate candidate gene fusions for each SV breakpoint pair.

    Reads the SV file given on the command line, looks up protein-coding
    genes within ``--buff-size`` bp of each of the two breakpoints, and
    writes every record that has at least one pair of distinct genes
    (one per breakpoint) to the output file. The original columns are
    re-joined with tabs and a comma-separated list of ``geneA-geneB``
    labels is appended as the last column.

    Raises
    ------
    ValueError
        If a non-blank data line has fewer than four columns or its
        position columns are not integers.
    """
    # Parse Arguments
    args, commands = getargs()
    # Defensive guard: skip the expensive import/database setup when only
    # help or version output was requested.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    # Imported lazily so that argument parsing stays fast.
    from pyensembl import EnsemblRelease

    # Create/Load Ensembl Release cache ...
    db = EnsemblRelease(args.ensembl_release, species=args.species)
    db.download()
    db.index()

    buff = args.buff_size
    # Open the input first so that a missing/unreadable SV file does not
    # create or truncate the output file.
    with open(args.sv_file, 'r') as source, open(args.output_file, 'w') as out:
        for i, line in enumerate(source):
            if i < args.skip_rows:
                continue
            parse = line.rstrip().split()
            if not parse:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            try:
                c1, p1, c2, p2 = parse[:4]
                p1, p2 = int(p1), int(p2)
            except ValueError as err:
                raise ValueError(
                    'Line {0} of {1} is malformed; expected chrom1, pos1, '
                    'chrom2, pos2 in the first four columns: {2!r}'.format(
                        i + 1, args.sv_file, line.rstrip())) from err

            gene_set_1 = _protein_coding_names(db, c1, p1, buff)  # breakpoint 1
            gene_set_2 = _protein_coding_names(db, c2, p2, buff)  # breakpoint 2

            fusions = _fusion_labels(gene_set_1, gene_set_2)
            if fusions:
                parse.append(','.join(fusions))
                out.write('\t'.join(parse) + '\n')

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    run()