#!python

import argparse
import datetime
import os
import re
import shutil
import string
import sys
import tempfile

from subprocess import call

from BCBio import GFF

from Bio.SeqFeature import FeatureLocation, SeqFeature

# TODO handle (rare) cases where the id prefix varies depending on the feature type (e.g. GSSPFG for genes, GSSPFT for transcripts, GSSPFP for proteins)


class OgsMerger():

    def run_cmd(self, cmd):
        """Run ``cmd`` through the shell, echoing it on stdout first.

        A child terminated by a signal is reported on stderr. When the command
        cannot be launched at all (``OSError``), the error is reported on stderr
        and the program exits with status 1. Nothing is returned.
        """

        print("Executing command:")
        print("{}".format(cmd))

        try:
            retcode = call(cmd, shell=True)
        except OSError as e:
            # The exception must be converted explicitly: "str" + OSError raises
            # TypeError, which would hide the real failure and skip the exit.
            print("Execution failed:" + str(e), file=sys.stderr)
            sys.exit(1)

        if retcode < 0:
            print("Child was terminated by signal " + str(-retcode), file=sys.stderr)

    # Parse a gene feature and create a hash representing its content (structure and attributes)
    # we can not always trust date_last_modified as some fix scripts don't change its value
    # instead, build a hash representing the gene content (exons, attributes, ...)
    def make_content_hash(self, f, check_name=False):
        blacklist_attr = ['ID', 'source', 'Parent', 'owner', 'Alias', 'date_creation', 'date_last_modified']
        if not check_name:
            blacklist_attr += ['Name']

        fhash = str(f.location) + "__"
        for qk in sorted(f.qualifiers):
            if qk not in blacklist_attr:
                fhash += qk + "->" + str(sorted(f.qualifiers[qk], key=str.lower)) + "__"

        # we need to sort subfeatures to have comparable hashes
        subs = sorted(f.sub_features, key=lambda e: (e.type, e.location.start, e.location.end))
        for sub in subs:
            # We don't pass the check_name as we don't want to check names on sub features (they are most of the time automatic)
            fhash += "__SUB[" + self.make_content_hash(sub) + "]"

        return fhash

    # Parse a gene feature and create a hash representing its content (structure only)
    def make_structure_hash(self, f):
        fhash = str(f.location) + "__"

        # we need to sort subfeatures to have comparable hashes
        subs = sorted(f.sub_features, key=lambda e: (e.type, e.location.start, e.location.end))
        for sub in subs:
            if sub.type in ['mRNA', 'CDS']:
                fhash += "__SUB[" + self.make_structure_hash(sub) + "]"

        return fhash

    # Clean a gene (and its children) attributes to output in the final GFF file
    def clean_feature(self, f):
        gene_id = f.qualifiers['ID'][0]
        f.qualifiers['ID'][0] = gene_id

        f.qualifiers['source'][0] = self.source
        if 'filtertag' in f.qualifiers:
            del f.qualifiers['filtertag']
        if 'owner' in f.qualifiers:
            del f.qualifiers['owner']

        # First letter in capital is not valid for GFF3 (no longer a problem in apollo 2.0.7, kept for compatibility with older versions)
        if 'Allele' in f.qualifiers:
            f.qualifiers['allele'] = f.qualifiers['Allele']
            del f.qualifiers['Allele']
        if 'Part' in f.qualifiers:
            f.qualifiers['part'] = f.qualifiers['Part']
            del f.qualifiers['Part']
        if 'Synonym' in f.qualifiers:
            f.qualifiers['synonym'] = f.qualifiers['Synonym']
            del f.qualifiers['Synonym']

        # Some qualifiers are not needed outside apollo
        if 'status' in f.qualifiers:
            del f.qualifiers['status']
        if 'annotGroup' in f.qualifiers:
            del f.qualifiers['annotGroup']

        # Cleanup some attributes
        if 'symbol' in f.qualifiers:
            f.qualifiers['symbol'] = [x.strip() for x in f.qualifiers['symbol']]
        if 'description' in f.qualifiers:
            f.qualifiers['description'] = [x.strip() for x in f.qualifiers['description']]
        if 'Note' in f.qualifiers:
            f.qualifiers['Note'] = [x.strip() for x in f.qualifiers['Note']]
        if 'Dbxref' in f.qualifiers:
            f.qualifiers['Dbxref'] = [x.strip() for x in f.qualifiers['Dbxref']]
        if 'allele' in f.qualifiers:
            f.qualifiers['allele'] = [x.strip() for x in f.qualifiers['allele']]
        if 'part' in f.qualifiers:
            f.qualifiers['part'] = [x.strip() for x in f.qualifiers['part']]
        if 'Name' in f.qualifiers:
            f.qualifiers['full_name'] = [x.strip() for x in f.qualifiers['Name']]
        if 'synonym' in f.qualifiers:
            f.qualifiers['synonym'] = [x.strip() for x in f.qualifiers['synonym']]

        mrna_count = 0
        mrna_count_cycle = 1

        gene_start = f.location.start
        gene_end = f.location.end

        # Add previous versions as aliases
        gene_id_splitted = gene_id.split(".")
        if len(gene_id_splitted) > 1:
            gene_id_no_version = gene_id_splitted[0]
            gene_version = int(gene_id_splitted[1])
            if 'Alias' not in f.qualifiers:
                f.qualifiers['Alias'] = []
            f.qualifiers['Alias'].append(gene_id_no_version)
            for old_version in range(1, gene_version):
                f.qualifiers['Alias'].append(gene_id_no_version + '.' + str(old_version))

        has_multiple_isoforms = len(f.sub_features) > 1

        for child in f.sub_features:  # mRNA

            rna_start = f.location.start
            rna_end = f.location.end

            child.qualifiers['source'][0] = self.source
            if 'filtertag' in child.qualifiers:
                del child.qualifiers['filtertag']
            if 'owner' in child.qualifiers:
                del child.qualifiers['owner']
            mrna_id = gene_id + "-" + "R" + (string.ascii_uppercase[mrna_count] * mrna_count_cycle)
            if 'Name' not in child.qualifiers:
                child.qualifiers['Name'] = [mrna_id]
            else:
                child.qualifiers['Name'][0] = mrna_id

            # put apollo id to aliases and write a good ID
            current_id = child.qualifiers['ID'][0]
            if re.match("^[A-F0-9]{32}$", current_id) or re.match("^[a-f0-9-]{36}$", current_id):
                if 'Alias' not in child.qualifiers:
                    child.qualifiers['Alias'] = []
                child.qualifiers['Alias'].append(current_id)

            if 'ID' not in child.qualifiers:
                child.qualifiers['ID'] = [mrna_id]
            else:
                child.qualifiers['ID'][0] = mrna_id

            child.qualifiers['Parent'][0] = gene_id

            # Some qualifiers are not needed outside apollo
            if 'status' in child.qualifiers:
                del child.qualifiers['status']
            if 'annotGroup' in child.qualifiers:
                del child.qualifiers['annotGroup']

            # Transfer attributes to mRNA level
            if 'symbol' in f.qualifiers:
                child.qualifiers['symbol'] = f.qualifiers['symbol']
            if 'allele' in f.qualifiers:
                child.qualifiers['allele'] = f.qualifiers['allele']
            if 'part' in f.qualifiers:
                child.qualifiers['part'] = f.qualifiers['part']
            if 'synonym' in f.qualifiers:
                if 'synonym' not in child.qualifiers:
                    child.qualifiers['synonym'] = []
                child.qualifiers['synonym'] += f.qualifiers['synonym']
            if 'Note' in f.qualifiers:
                if 'Note' not in child.qualifiers:
                    child.qualifiers['Note'] = []
                child.qualifiers['Note'] += f.qualifiers['Note']
            if 'Dbxref' in f.qualifiers:
                if 'Dbxref' not in child.qualifiers:
                    child.qualifiers['Dbxref'] = []
                child.qualifiers['Dbxref'] += f.qualifiers['Dbxref']

            # Transfer only if not multiple isoforms
            if not has_multiple_isoforms:
                if 'Name' in f.qualifiers:
                    child.qualifiers['full_name'] = f.qualifiers['Name']

            # Transfer without replacing isoform information
            if not has_multiple_isoforms or 'description' not in child.qualifiers or not child.qualifiers['description']:
                if 'description' in f.qualifiers:
                    child.qualifiers['description'] = f.qualifiers['description']

            # Add previous versions as aliases
            if len(gene_id_splitted) > 1:
                if 'Alias' not in child.qualifiers:
                    child.qualifiers['Alias'] = []
                child.qualifiers['Alias'].append(gene_id_no_version + "-" + "R" + (string.ascii_uppercase[mrna_count] * mrna_count_cycle))
                for old_version in range(1, gene_version):
                    child.qualifiers['Alias'].append(gene_id_no_version + '.' + str(old_version) + "-" + "R" + (string.ascii_uppercase[mrna_count] * mrna_count_cycle))

            # Remove uppercase variants if any
            if 'Allele' in child.qualifiers:
                del child.qualifiers['Allele']
            if 'Part' in child.qualifiers:
                del child.qualifiers['Part']
            if 'Synonym' in child.qualifiers:
                del child.qualifiers['Synonym']

            id_count = 1
            for gchild in child.sub_features:  # exons, cds, ...
                gchild.qualifiers['source'][0] = self.source
                if 'filtertag' in gchild.qualifiers:
                    del gchild.qualifiers['filtertag']
                if 'ID' in gchild.qualifiers:
                    gchild.qualifiers['ID'][0] = mrna_id + "-" + gchild.type + "-" + str(id_count)
                if 'owner' in gchild.qualifiers:
                    del gchild.qualifiers['owner']
                gchild.qualifiers['Name'] = [mrna_id + "-" + gchild.type]  # Will create the Name array if not yet present
                gchild.qualifiers['Parent'][0] = mrna_id
                id_count_gg = 1
                for ggchild in gchild.sub_features:  # exotic stuff (non_canonical_five_prime_splice_site non_canonical_three_prime_splice_site stop_codon_read_through)
                    ggchild.qualifiers['source'][0] = self.source
                    if 'filtertag' in ggchild.qualifiers:
                        del ggchild.qualifiers['filtertag']
                    if 'owner' in ggchild.qualifiers:
                        del ggchild.qualifiers['owner']
                    ggchild.qualifiers['ID'][0] = mrna_id + "-" + gchild.type + "-" + str(id_count) + "-" + ggchild.type + "-" + str(id_count_gg)
                    ggchild.qualifiers['Name'][0] = mrna_id + "-" + gchild.type + "-" + ggchild.type
                    ggchild.qualifiers['Parent'][0] = mrna_id + "-" + gchild.type + "-" + str(id_count)
                    id_count_gg += 1
                id_count += 1

                # Check exon/utrs are not outside the gene/mrna coordinates
                if gchild.location.start < rna_start or gchild.location.end < rna_start:
                    rna_start = min(gchild.location.start, gchild.location.end)
                if gchild.location.start > rna_end or gchild.location.end > rna_end:
                    rna_end = max(gchild.location.start, gchild.location.end)

            if rna_start != child.location.start:
                print("Fixing start coordinate for mRNA {} from {} to {}".format(child.qualifiers['ID'][0], child.location.start, rna_start))
                child.location = FeatureLocation(rna_start, child.location.end, strand=child.location.strand)
            if rna_end != child.location.end:
                print("Fixing end coordinate for mRNA {} from {} to {}".format(child.qualifiers['ID'][0], child.location.end, rna_end))
                child.location = FeatureLocation(child.location.start, rna_end, strand=child.location.strand)

            # Check mrna are not outside the gene coordinates
            if gchild.location.start < gene_start or gchild.location.end < gene_start:
                gene_start = min(gchild.location.start, gchild.location.end)
            if gchild.location.start > gene_end or gchild.location.end > gene_end:
                gene_end = max(gchild.location.start, gchild.location.end)

            mrna_count += 1
            if mrna_count >= 25:
                mrna_count = 0
                mrna_count_cycle += 1

        if gene_start != f.location.start:
            print("Fixing start coordinate for gene {} from {} to {}".format(f.qualifiers['ID'][0], f.location.start, gene_start))
            f.location = FeatureLocation(gene_start, f.location.end, strand=f.location.strand)
        if gene_end != f.location.end:
            print("Fixing end coordinate for gene {} from {} to {}".format(f.qualifiers['ID'][0], f.location.end, gene_end))
            f.location = FeatureLocation(f.location.start, gene_end, strand=f.location.strand)

        return f

    # Add exons subfeatures to a feature, guessing from UTR and CDS subfeatures
    def guess_exons(self, cleaned_f):

        for child in cleaned_f.sub_features:  # mRNA
            cds_coords = {}
            utr_coords = {}
            exon_coords = {}
            seen_exons = False
            for gchild in child.sub_features:  # exons, cds, ...
                if gchild.type == "CDS":
                    cds_coords[gchild.location.start] = gchild.location.end
                elif gchild.type in ["UTR", "five_prime_UTR", "three_prime_UTR"]:
                    utr_coords[gchild.location.start] = gchild.location.end
                elif gchild.type == "exon":
                    # We found some exons, it means we don't need to guess them
                    seen_exons = True

            if seen_exons:
                continue

            cds_sorted = sorted(cds_coords)
            for cds_start in cds_sorted:
                cds_end = cds_coords[cds_start]
                utr_len_init = len(utr_coords)

                if (cds_end) in utr_coords:
                    # 3' UTR
                    exon_coords[cds_start] = utr_coords[cds_end]
                    del utr_coords[cds_end]

                if cds_start in utr_coords.values():
                    # 5' UTR
                    found_start = None
                    for start, end in utr_coords.items():
                        if found_start is None and end == cds_start:
                            found_start = start
                    if cds_start in exon_coords:
                        # we found utr on both sides
                        cds_end = exon_coords[cds_start]
                        del exon_coords[cds_start]

                    exon_coords[found_start] = cds_end

                    del utr_coords[found_start]

                if len(utr_coords) == utr_len_init:
                    exon_coords[cds_start] = cds_end

            # Add all remaining utr = utr that are not adjacent to cds
            for utr_start in utr_coords:
                exon_coords[utr_start] = utr_coords[utr_start]

            count_id = 1
            for exon_start in exon_coords:
                new_subf = SeqFeature(FeatureLocation(exon_start, exon_coords[exon_start]), type="exon", strand=child.location.strand, qualifiers={"source": [self.source], 'ID': [child.qualifiers['Name'][0] + "-exon-" + str(count_id)], 'Name': [child.qualifiers['Name'][0] + "-exon"]})
                child.sub_features.append(new_subf)
                count_id += 1

        return cleaned_f

    # Add UTR subfeatures to a feature, guessing from CDS and exon subfeatures
    def guess_utrs(self, cleaned_f):

        for child in cleaned_f.sub_features:  # mRNA
            cds_coords = {}
            exon_coords = {}
            utr_coords = {}
            for gchild in child.sub_features:  # exons, cds, ...
                if gchild.type == "CDS":
                    cds_coords[gchild.location.start] = gchild.location.end
                elif gchild.type == "exon":
                    exon_coords[gchild.location.start] = gchild.location.end

            exon_sorted = sorted(exon_coords)
            for exon_start in exon_sorted:
                exon_end = exon_coords[exon_start]

                found_cds = False
                for cds_start in cds_coords:
                    cds_end = cds_coords[cds_start]

                    if cds_start == exon_start and cds_end == exon_end:
                        # whole exon is a CDS
                        found_cds = True

                    else:
                        if cds_start > exon_start and cds_start < exon_end:
                            # 5'UTR
                            found_cds = True
                            utr_coords[exon_start] = cds_start

                        if cds_end > exon_start and cds_end < exon_end:
                            # 3' UTR
                            found_cds = True
                            utr_coords[cds_end] = exon_end

                if not found_cds:
                    # No CDS on the exon, create a UTR on the whole exon
                    utr_coords[exon_start] = exon_end

            count_id = 1
            for utr_start in utr_coords:
                if 'Name' in child.qualifiers:
                    name_prefix = child.qualifiers['Name'][0]
                else:
                    name_prefix = child.qualifiers['ID'][0]
                new_subf = SeqFeature(FeatureLocation(utr_start, utr_coords[utr_start]), type="UTR", strand=child.location.strand, qualifiers={"source": [self.source], 'ID': [name_prefix + "-UTR-" + str(count_id)], 'Name': [name_prefix + "-UTR"]})
                new_subf.sub_features = []  # See https://github.com/biopython/biopython/issues/928
                child.sub_features.append(new_subf)
                count_id += 1

        return cleaned_f

    def renumber_exons(self, cleaned_f):

        for child in cleaned_f.sub_features:  # mRNA
            exon_coords = {}
            cds_coords = {}
            utr_coords = {}
            strand = child.location.strand

            # Find positions
            for gchild in child.sub_features:  # exons, cds, ...
                if gchild.type == "exon":
                    exon_id = gchild.qualifiers['ID'][0]
                    start = gchild.location.start
                    exon_coords[start] = exon_id
                elif gchild.type == "CDS":
                    if 'ID' in gchild.qualifiers:
                        cds_id = gchild.qualifiers['ID'][0]
                    else:
                        cds_id = gchild.qualifiers['Name'][0]
                    start = gchild.location.start
                    cds_coords[start] = cds_id
                elif gchild.type in ["UTR", "five_prime_UTR", "three_prime_UTR"]:
                    if 'ID' in gchild.qualifiers:
                        utr_id = gchild.qualifiers['ID'][0]
                    else:
                        utr_id = gchild.qualifiers['Name'][0]
                    start = gchild.location.start
                    utr_coords[start] = utr_id

            # Sort by position
            if strand > 0:
                exons_sorted = sorted(exon_coords)
                cds_sorted = sorted(cds_coords)
                utr_sorted = sorted(utr_coords)
            else:
                exons_sorted = sorted(exon_coords, reverse=True)
                cds_sorted = sorted(cds_coords, reverse=True)
                utr_sorted = sorted(utr_coords, reverse=True)

            # Reassign ids
            exon_num = 1
            exon_subs = {}
            for exon_s in exons_sorted:
                exon_subs[exon_coords[exon_s]] = re.sub('-exon(-[0-9]+)?$', '-exon-' + str(exon_num), exon_coords[exon_s])
                exon_num += 1

            cds_num = 1
            cds_subs = {}
            for cds_s in cds_sorted:
                cds_subs[cds_coords[cds_s] + str(cds_s)] = re.sub('-CDS(-[0-9]+)?$', '-CDS-' + str(cds_num), cds_coords[cds_s])
                cds_num += 1

            utr_num = 1
            utr_subs = {}
            for utr_s in utr_sorted:
                utr_subs[utr_coords[utr_s] + str(utr_s)] = re.sub('-UTR(-[0-9]+)?$', '-UTR-' + str(utr_num), utr_coords[utr_s])
                utr_num += 1

            for gchild in child.sub_features:  # exons, cds, ...
                if gchild.type == "exon":
                    if gchild.qualifiers['ID'][0] in exon_subs:
                        # The gchild id can be absent from exon_subs if we are treating an exon shared between multiple isoforms,
                        # and it was already renamed while treating previous isoform
                        gchild.qualifiers['ID'][0] = exon_subs[gchild.qualifiers['ID'][0]]
                elif gchild.type == "CDS":
                    if 'ID' in gchild.qualifiers:
                        cds_id = gchild.qualifiers['ID'][0]
                    else:
                        cds_id = gchild.qualifiers['Name'][0]

                    if cds_id + str(gchild.location.start) in cds_subs:
                        # The gchild id can be absent from cds_subs if we are treating an exon shared between multiple isoforms,
                        # and it was already renamed while treating previous isoform
                        gchild.qualifiers['ID'] = [cds_subs[cds_id + str(gchild.location.start)]]

                        # exotic stuff (stop_codon_read_through) needs to be readopted by the correct parent
                        id_count_gg = 1
                        for ggchild in gchild.sub_features:
                            ggchild.qualifiers['ID'][0] = gchild.qualifiers['ID'][0] + "-" + ggchild.type + "-" + str(id_count_gg)
                            ggchild.qualifiers['Parent'][0] = gchild.qualifiers['ID'][0]
                            id_count_gg += 1
                elif gchild.type in ["UTR", "five_prime_UTR", "three_prime_UTR"]:
                    if 'ID' in gchild.qualifiers:
                        utr_id = gchild.qualifiers['ID'][0]
                    else:
                        utr_id = gchild.qualifiers['Name'][0]

                    if utr_id + str(gchild.location.start) in utr_subs:
                        # The gchild id can be absent from utr_subs if we are treating an exon shared between multiple isoforms,
                        # and it was already renamed while treating previous isoform
                        gchild.qualifiers['ID'] = [utr_subs[utr_id + str(gchild.location.start)]]

        return cleaned_f

    def match_apollo_against_base(self):
        """Match apollo genes against base genes using bedtools and record the result.

        Exon lines of ``self.filtered_base_gff`` are rewritten (ID derived from
        their Parent) into ``<tmpdir>/base_cds.gff``; the apollo genes and these
        exons are converted to BED and intersected on the same strand. For each
        apollo gene, the base gene with the largest summed overlap is stored in
        ``self.primary_matches`` and every other overlapping base gene in
        ``self.secondary_matches``. Exits with status 1 when a result line is
        malformed.
        """

        print("Converting GFF to BED...")

        with open(self.filtered_base_gff, 'r') as base_gff_in, open(self.tmpdir + '/base_cds.gff', 'w+') as base_gff_out:
            for li in base_gff_in:
                if li.startswith("#"):
                    continue
                # GFF columns are tab separated: splitting on any whitespace would
                # cut attributes containing spaces into extra columns
                cols = li.strip().split("\t")
                # FIXME CDS could be more appropriate (or maybe not...)
                if len(cols) < 9 or cols[2] != 'exon':
                    continue
                cols[8] = re.sub(r'ID=([a-zA-Z0-9]+)', r'exID=\1', cols[8])  # remove already set id
                cols[8] = re.sub(r'Parent=([a-zA-Z0-9]+)([\.0-9]+)?([-_]R[A-Z]+)?(,[a-zA-Z0-9\.\-_]*)?', r'ID=\1', cols[8])  # generate a fake id based on Parent + remove multiple parents (ie when an exon is part of multiple isoforms)
                cols[8] = cols[8].rstrip(";")  # gff2bed doesn't like trailing ;
                print('\t'.join(cols), file=base_gff_out)

        self.run_cmd("awk '{if ($3 == \"gene\") {print}}' FS='\t' OFS='\t' " + self.apollo_gff + " | gff2bed --do-not-sort | awk '{ if ($2 <= 0) {$2++}; print }' FS='\t' OFS='\t' | sort-bed - > " + self.tmpdir + "/apollo.bed")

        self.run_cmd("cat " + self.tmpdir + "/base_cds.gff | gff2bed > " + self.tmpdir + "/base_cds.bed")

        print("Computing intersect between new Apollo annotation and base annotation...")
        in_map = self.tmpdir + "/wa_to_all_base.tsv"

        self.run_cmd("intersectBed -s -wao -a " + self.tmpdir + "/apollo.bed -b " + self.tmpdir + "/base_cds.bed | cut -f4,14,21 | awk '{ if ($2 != \".\") {print} }'  | awk '{a[$1\"\t\"$2]+=$3;} END {for(i in a) print i\"\t\"a[i];}' | sort > " + in_map)

        # Load the bedtools intersect result
        with open(in_map) as in_map_h:
            for gene in in_map_h:
                cols = gene.strip().split("\t")
                # Three columns are read below: apollo id, base id, overlap length
                if len(cols) < 3:
                    print("Failed loading bedtools result: " + gene)
                    sys.exit(1)

                cur_waid = cols[0]
                match = {'gid': cols[1], 'len': int(cols[2])}  # gid is base, len is exon length coverage

                if cur_waid not in self.primary_matches:
                    self.primary_matches[cur_waid] = match  # key is wa id
                    continue

                # found another match for current apollo gene
                if match['len'] > self.primary_matches[cur_waid]['len']:
                    # found a better match: the former primary becomes secondary
                    demoted = self.primary_matches[cur_waid]
                    self.primary_matches[cur_waid] = match
                else:
                    # found a worse match
                    demoted = match

                if cur_waid not in self.secondary_matches:
                    self.secondary_matches[cur_waid] = []
                self.secondary_matches[cur_waid].append(demoted)

    # Parse the original base annotation file
    def parse_base_annotation(self):
        """Load the genes of ``self.filtered_base_gff``.

        For each gene whose ID matches ``self.id_regex``, records its version
        (when the second regex group matched), structure hash and position
        under the unversioned id built from ``self.id_syntax``, and keeps track
        of the highest id number seen together with its zero-padded width.
        Genes whose ID cannot be parsed are reported and skipped.
        """

        with open(self.filtered_base_gff) as agff_handle:
            for rec in GFF.parse(agff_handle):
                rec.annotations = {}
                for f in rec.features:  # gene
                    if f.type != "gene":
                        continue

                    gid = f.qualifiers['ID'][0]

                    search = re.search(self.id_regex, gid)
                    if not search:
                        # Nothing can be recorded without a parsed id; going on would
                        # only fail on search.group() right below
                        print("ERROR: could not parse ID from base annotation: " + gid)
                        continue

                    pid_number = int(search.group(1))
                    if pid_number > self.highest_id:
                        self.highest_id = pid_number
                        self.padding_length = len(search.group(1))

                    gid = self.id_syntax.replace('{id}', search.group(1))  # Remove version suffix
                    if search.group(2):
                        self.base_genes_version[gid] = int(search.group(2).strip("."))
                    self.base_genes_structure[gid] = self.make_structure_hash(f)
                    self.base_genes_positions[gid] = {'start': f.location.start, 'end': f.location.end, 'strand': f.location.strand}

    # Load the previous annotation
    def parse_previous_annotation(self):
        """Load the genes of ``self.previous_gff``.

        Every gene is counted. For genes carrying an ``Alias`` attribute and an
        ID matching ``self.id_regex``, the first alias that looks like an apollo
        id is mapped to the id previously assigned to the gene, to the gene
        content hash and to the full previous ID. Genes whose ID cannot be
        parsed are reported and skipped.
        """

        with open(self.previous_gff) as ogff_handle:
            for rec in GFF.parse(ogff_handle):
                rec.annotations = {}
                for f in rec.features:  # gene
                    if f.type != "gene":
                        continue

                    self.num_total_genes_in_previous += 1
                    if 'Alias' not in f.qualifiers:
                        continue
                    # An Alias suggests the gene was already coming from apollo in the previous annotation

                    pid = f.qualifiers['ID'][0]
                    search = re.search(self.id_regex, pid)
                    if not search:
                        # Without a parsed number the id below would be built from an
                        # undefined value, or from the number of the previous gene
                        print("ERROR: could not parse previous ID: " + pid)
                        continue

                    pid_number = int(search.group(1))
                    if pid_number > self.highest_id:
                        self.highest_id = pid_number

                    previous_assigned_id = self.id_syntax.replace('{id}', str(pid_number).zfill(self.padding_length))

                    for alias in f.qualifiers['Alias']:
                        # Alias format is different in apollo 1.x and apollo 2.x, try to support both
                        if re.match("^[A-F0-9]{32}$", alias) or re.match("^[a-f0-9-]{36}$", alias):

                            # Getting there means that the gene was in the previous annotation and that the gene is no longer present in the new version,
                            # OR that it is there but with no bedtools result
                            # OR that it is there with the same or a different bedtools result

                            # Store the information just in case the same alias is in the new annotation
                            self.name_map[alias] = previous_assigned_id
                            self.previous_genes_content[alias] = self.make_content_hash(f, check_name=True)
                            self.previous_genes_name[alias] = pid
                            break  # we use the first alias matching a apollo id, ignoring any subsequent

    # When a gene has been split in apollo, keep the base id on the apollo gene with the largest coverage
    def handle_splitted(self):
        assigned_gids = {}
        removed_wa = []
        for wa in self.primary_matches:
            gid = self.primary_matches[wa]['gid']
            if gid in assigned_gids:
                self.num_splitted += 1
                if assigned_gids[gid]['len'] > self.primary_matches[wa]['len']:
                    # a previous apollo gene is more covered by the base gene than the current one
                    # move the current match from the primary to the secondary list
                    if wa not in self.secondary_matches:
                        self.secondary_matches[wa] = []
                    self.secondary_matches[wa].append(self.primary_matches[wa])
                    removed_wa.append(wa)
                else:
                    # a new best match
                    # move the previous best from the primary to the secondary list
                    if assigned_gids[gid]['waid'] not in self.secondary_matches:
                        self.secondary_matches[assigned_gids[gid]['waid']] = []
                    self.secondary_matches[assigned_gids[gid]['waid']].append(self.primary_matches[assigned_gids[gid]['waid']])
                    removed_wa.append(assigned_gids[gid]['waid'])

                    # keep current in the primary list
                    assigned_gids[gid] = {'waid': wa, 'len': self.primary_matches[wa]['len']}
            else:
                assigned_gids[gid] = {'waid': wa, 'len': self.primary_matches[wa]['len']}

        for wa in removed_wa:
            del self.primary_matches[wa]

    # When an apollo gene covers several base genes completely, make sure we remove all 100% covered base genes
    # => 3 genes completely covered by 1 apollo gene will give just 1 gene in output
    def handle_merged(self):

        for wa in self.primary_matches:
            if wa in self.secondary_matches:
                apollo_pos = self.apollo_genes_positions[wa]
                for secs in self.secondary_matches[wa]:
                    base_pos = self.base_genes_positions[secs['gid']]

                    search = re.search(self.id_regex, secs['gid'])
                    raw_merged_gene_id = self.id_syntax.replace('{id}', search.group(1))

                    # Check that the gene is not already blacklisted (ie, the primary match of another gene?)
                    if raw_merged_gene_id not in self.blacklist_base:
                        if base_pos['strand'] == apollo_pos['strand'] and (base_pos['start'] >= apollo_pos['start'] and base_pos['start'] <= apollo_pos['end']) or (base_pos['end'] >= apollo_pos['start'] and base_pos['end'] <= apollo_pos['end']):
                            # base gene is somewhat included in the apollo gene, delete it
                            # we could trust secondary_matches and don't check the position, but you never know...
                            if wa not in self.blacklist_merged:
                                self.blacklist_merged[wa] = []
                            self.blacklist_merged[wa].append(raw_merged_gene_id)
                        else:
                            print('ERROR: Strange merge case for ' + str(wa) + ' -> ' + str(raw_merged_gene_id))

    def check_id_compatibility(self):
        for wa in self.primary_matches:

            if wa in self.name_map:
                if self.primary_matches[wa]['gid'] != self.name_map[wa]:
                    self.kept_compatibility += 1
                    print("WARNING: In the previous annotation, gene '" + wa + "' was assigned id '" + self.name_map[wa] + "' but in the new one it would be assigned id '" + self.primary_matches[wa]['gid'] + "'. Keeping '" + self.name_map[wa] + "' for compatibility.")
                # else: nothing to change, the previous id and the one found with bedtools is the same
            else:
                if self.primary_matches[wa]['gid'] not in self.name_map.values():
                    self.name_map[wa] = self.primary_matches[wa]['gid']
                else:
                    already_assigned = ""
                    for w, g in self.name_map.items():
                        if g == self.primary_matches[wa]['gid']:
                            already_assigned = w

                    if already_assigned not in self.apollo_ids_in_latest:
                        print("WARNING: Gene '" + wa + "' will be assigned id '" + self.primary_matches[wa]['gid'] + "' but it was already used by another no-more-existing gene '" + already_assigned + "' in previous annotation.")
                        self.name_map[wa] = self.primary_matches[wa]['gid']
                    else:
                        # The id was already used for another gene, don't store any id mapping for this gene, we will generate a new id later
                        # This happens when a gene was splitted by annotators (but not only)
                        print("WARNING: Gene '" + wa + "' should be assigned id '" + self.primary_matches[wa]['gid'] + "' but it is already used by gene '" + already_assigned + "'. A new id will be created.")

    def parse_apollo_annotation(self):
        """Read the new Apollo GFF and settle the versioned id of each gene.

        For every ``gene`` feature carrying an ``ID`` qualifier this method:

        * appends the Apollo id to ``self.apollo_ids_in_latest``;
        * sets ``self.name_map[apollo_id]`` to the id to use in the output
          (brand new id, id kept from the previous annotation, or an id with
          a bumped version suffix);
        * updates the ``num_new_id``, ``num_unchanged_from_previous``,
          ``num_version_bumped`` and ``num_unchanged_from_base`` counters;
        * stores the gene coordinates in ``self.apollo_genes_positions``.
        """
        # Keep the handle open for the whole iteration over GFF.parse()
        with open(self.apollo_gff) as wgff_handle:
            for rec in GFF.parse(wgff_handle):
                rec.annotations = {}
                for f in rec.features:  # gene
                    if 'ID' not in f.qualifiers or f.type != "gene":
                        continue

                    wa_id = f.qualifiers['ID'][0]
                    self.apollo_ids_in_latest.append(wa_id)

                    base_id = ""

                    if wa_id not in self.name_map:
                        # A new gene that has no bedtools result and that wasn't found in the previous annotation version
                        # Give it a completely new id
                        base_id = self.id_syntax.replace('{id}', str(self.highest_id).zfill(self.padding_length))
                        self.name_map[wa_id] = base_id + ".1"
                        self.highest_id += 1
                        self.num_new_id += 1
                    elif wa_id in self.previous_genes_content:
                        # Found in the previous and the new annotation => compare both versions to choose the right suffix
                        f = self.guess_utrs(f)  # apollo doesn't generate UTRs lines in the gff output, guess them before comparing
                        new_hash = self.make_content_hash(f, check_name=True)
                        previous_name = self.previous_genes_name[wa_id]
                        search = re.search(self.id_regex, previous_name)
                        if self.previous_genes_content[wa_id] == new_hash:
                            # Gene was the same in the previous annotation, keeping its previous ID
                            self.num_unchanged_from_previous += 1
                            self.name_map[wa_id] = previous_name
                            if search:
                                base_id = self.id_syntax.replace("{id}", search.group(1))
                            else:
                                print("ERROR: could not parse previous ID: " + previous_name)
                        else:
                            # Gene changed, updating its suffix
                            if search and search.group(2):
                                # The previous gene normally has a version suffix (added when creating the previous OGS)
                                suffix = int(search.group(2).strip(".")) + 1
                            else:
                                # Unparsable id or missing suffix: report it and start at version 1,
                                # as done below for base genes without a version suffix
                                print("ERROR: could not parse previous ID: " + previous_name)
                                suffix = 1
                            base_id = self.name_map[wa_id]
                            self.name_map[wa_id] = base_id + "." + str(suffix)
                            self.num_version_bumped += 1
                    else:
                        # The gene was not seen in previous annotation and we have a bedtools result (=found in base annotation)
                        base_id = self.name_map[wa_id]
                        search = re.search(self.id_regex, self.name_map[wa_id])
                        if base_id not in self.base_genes_version:
                            # No version suffix in base annotation, add one
                            self.name_map[wa_id] = self.name_map[wa_id] + ".1"
                        else:
                            # There was already a version suffix, increment it
                            suffix = self.base_genes_version[base_id] + 1
                            self.name_map[wa_id] = self.id_syntax.replace('{id}', search.group(1)) + "." + str(suffix)
                        self.num_new_id += 1

                    # See if the gene had exactly the same structure in the base annotation
                    if base_id in self.base_genes_structure and self.base_genes_structure[base_id] == self.make_structure_hash(f):
                        self.num_unchanged_from_base += 1

                    self.apollo_genes_positions[wa_id] = {'start': f.location.start, 'end': f.location.end, 'strand': f.location.strand}

    # Generate a list of genes to delete from the base annotation
    def prepare_base_blacklist(self):
        # It also contains ids that were generated for new genes (will not be found in base annotation file)
        self.blacklist_base = []
        for waid in self.name_map:
            # Only keep in the blacklist the genes that are still in the latest apollo annotation
            # We do this to avoid removing a gene that was replaced in the previous annotation, but is no longer replaced in the new
            # The version suffix is not present in blastlist_base
            if waid in self.apollo_ids_in_latest:
                search = re.search(self.id_regex, self.name_map[waid])
                if search:
                    self.blacklist_base.append(self.id_syntax.replace('{id}', search.group(1)))
                else:
                    print("ERROR: could not parse name mapping ID: " + self.name_map[waid])

    # Remove all isoforms marked as deleted by annotators
    # Write a filtered gff into the tmpdir
    def filter_deleted_isoforms(self):
        """Write ``self.filtered_base_gff``: the base GFF without deleted isoforms.

        The ids to remove are read from ``self.deleted`` (one per line) when
        that option is set.  Every direct sub-feature of a top-level feature
        whose ``ID`` is in that list is dropped and counted in
        ``self.num_deleted_isoforms``.  A gene left without any sub-feature is
        dropped entirely and its id is appended to
        ``self.whole_genes_deleted``.  Records left without features are not
        written.
        """
        print("Filtering deleted isoforms...")

        # Check the mRNAs that were specifically deleted by annotators
        # (a set, as it is queried once per isoform of the base annotation)
        mrnas_to_delete = set()
        if self.deleted:
            with open(self.deleted, 'r') as f:
                mrnas_to_delete = {line.strip() for line in f}

        recs = []
        # Keep the handle open for the whole iteration over GFF.parse()
        with open(self.base_gff) as in_handle:
            for rec in GFF.parse(in_handle):
                rec.annotations = {}
                rec.seq = ""
                new_feats = []

                for gene in rec.features:  # gene
                    new_subs = []

                    for mRNA in gene.sub_features:
                        if mRNA.qualifiers['ID'][0] not in mrnas_to_delete:
                            new_subs.append(mRNA)
                        else:
                            self.num_deleted_isoforms += 1

                    if not new_subs:
                        # No more isoforms in the gene, delete the whole gene
                        self.whole_genes_deleted.append(gene.qualifiers['ID'][0])
                        continue

                    gene.sub_features = new_subs
                    new_feats.append(gene)

                rec.features = new_feats

                if rec.features:
                    recs.append(rec)

        with open(self.filtered_base_gff, "w") as out_handle:
            GFF.write(recs, out_handle)

    # Write final GFF
    def write_gff(self):
        """Write the final merged annotation to ``self.out_gff``.

        Genes of ``self.filtered_base_gff`` are written first, except those
        whose version-less id is in ``self.blacklist_base`` or in any list of
        ``self.blacklist_merged``; the former are counted in
        ``self.num_kept_base``, the latter in ``self.num_replaced_base``.
        Then the features of ``self.apollo_gff`` whose id is a key of
        ``self.name_map`` are written under their mapped id, the Apollo id
        being kept in the ``Alias`` qualifier.  ``self.num_merged`` and
        ``self.num_total_genes`` are updated along the way.
        """
        all_merged = [j for i in self.blacklist_merged.values() for j in i]
        self.num_merged = len(all_merged)

        # One set for the per-gene lookups; the two blacklists themselves are left untouched
        replaced_ids = set(self.blacklist_base)
        replaced_ids.update(all_merged)

        recs = []

        # First write genes kept from the base annotation...
        with open(self.filtered_base_gff) as in_handle:
            for rec in GFF.parse(in_handle):
                rec.annotations = {}
                rec.seq = ""
                new_feats = []

                for f in rec.features:  # gene
                    gene_id = f.qualifiers['ID'][0]

                    search = re.search(self.id_regex, gene_id)
                    if search:
                        raw_gene_id = self.id_syntax.replace('{id}', search.group(1))
                    else:
                        # Report it like the other parsing failures and compare the id as is
                        print("ERROR: could not parse base gene ID: " + gene_id)
                        raw_gene_id = gene_id

                    if raw_gene_id not in replaced_ids:
                        cleaned_f = self.clean_feature(f)
                        cleaned_f = self.guess_exons(cleaned_f)
                        cleaned_f = self.renumber_exons(cleaned_f)
                        new_feats.append(cleaned_f)
                        self.num_kept_base += 1
                        self.num_total_genes += 1
                    else:
                        self.num_replaced_base += 1

                rec.features = new_feats

                if rec.features:
                    recs.append(rec)

        # ... then write genes models from apollo
        with open(self.apollo_gff) as in_handle:
            for rec in GFF.parse(in_handle):
                rec.annotations = {}
                rec.seq = ""
                new_feats = []

                for f in rec.features:  # gene
                    gene_id = f.qualifiers['ID'][0]

                    if gene_id in self.name_map:
                        # Keep the Apollo id as an alias of the final id
                        f.qualifiers['Alias'] = [gene_id]
                        f.qualifiers['ID'][0] = self.name_map[gene_id]
                        nfeat = self.clean_feature(f)
                        nfeat = self.guess_utrs(nfeat)
                        nfeat = self.renumber_exons(nfeat)
                        new_feats.append(nfeat)
                        self.num_total_genes += 1

                rec.features = new_feats

                if rec.features:
                    recs.append(rec)

        with open(self.out_gff, "w") as out_handle:
            GFF.write(recs, out_handle)

    def print_statistics(self, file):
        num_apollo_genes = self.num_new_id + self.num_unchanged_from_previous + self.num_version_bumped

        print("Starting numbering of new genes from: {}".format(self.highest_id), file=file)

        print("    {} genes in the base annotation".format(len(self.base_genes_structure) + len(self.whole_genes_deleted)), file=file)
        if self.previous_gff:
            print("    {} genes in the previous annotation, including {} Apollo genes".format(self.num_total_genes_in_previous, len(self.previous_genes_name)), file=file)
        print("    {} genes in the new Apollo annotation".format(num_apollo_genes), file=file)
        print("", file=file)
        print("    {} genes in the new final gff file".format(self.num_total_genes), file=file)
        print("        {} genes from the base annotation were kept untouched".format(self.num_kept_base), file=file)
        print("        {} genes from apollo have no equivalent in base annotation".format(self.num_total_genes - len(self.base_genes_structure) + self.num_merged), file=file)
        print("        {} genes from the base annotation were replaced by an apollo gene".format(self.num_replaced_base), file=file)
        print("            {} base genes were merged into {} Apollo genes".format(self.num_merged, len(self.blacklist_merged)), file=file)
        print("        {} isoforms ({} whole genes) were deleted from the base annotation (by annotator request)".format(self.num_deleted_isoforms, len(self.whole_genes_deleted)), file=file)
        print("", file=file)
        if self.previous_gff:
            print("    {} ids compared to previous annotation: kept {} unchanged, updated {} ids, and {} are new.".format(num_apollo_genes, self.num_unchanged_from_previous, self.num_version_bumped, self.num_new_id), file=file)
            print("    {} genes have an id that was kept from previous annotation for compatibility reason".format(self.kept_compatibility), file=file)
            print("", file=file)
        print("    {} apollo genes have the exact same structure as in the base annotation".format(self.num_unchanged_from_base), file=file)
        print("    {} genes were splitted by annotators".format(self.num_splitted), file=file)
        print("", file=file)

    def print_convertion(self, file):

        print("# Id conversion table from {} to {}".format(os.path.basename(self.base_gff), self.source), file=file)
        if self.previous_gff:
            print("# {}\t{}\t{}\tapollo_gene_id".format(self.source, os.path.basename(self.base_gff), os.path.basename(self.base_gff)), file=file)
        else:
            print("# {}\t{}\tapollo_gene_id".format(self.source, os.path.basename(self.base_gff)), file=file)

        for waid in self.name_map:
            old_id = '-'
            search = re.search(self.id_regex, self.name_map[waid])
            if search:
                old_id = self.id_syntax.replace('{id}', search.group(1))
                if old_id not in self.base_genes_structure:
                    # Ahah, the id in name_map is a new one
                    old_id = '-'

            if self.previous_gff:
                previous_id = '-'
                if waid in self.previous_genes_name:
                    previous_id = self.previous_genes_name[waid]

                print("{}\t{}\t{}\t{}".format(self.name_map[waid], old_id, previous_id, waid), file=file)
            else:
                print("{}\t{}\t{}".format(self.name_map[waid], old_id, waid), file=file)

        for d in self.whole_genes_deleted:
            if self.previous_gff:
                print("deleted\t{}\t-\t-".format(d), file=file)
            else:
                print("deleted\t{}\t-".format(d), file=file)

    # Create fasta output using gffread
    def write_fasta(self):
        """Generate the transcript, CDS and protein fasta files with gffread.

        gffread is run on ``self.out_gff`` and ``self.genome``; transcripts go
        to ``self.out_transcript`` and CDS to ``self.out_cds``.  Proteins are
        first written to the temp dir, then copied to ``self.out_protein``
        with ``-R<letters>`` replaced by ``-P<letters>`` in the header lines.
        """
        import shlex  # local import: only needed to quote paths for the shell command

        print("Generating fasta files using gffread...")

        raw_proteins = os.path.join(self.tmpdir, 'proteins.fa')

        # run_cmd goes through a shell, so quote every path (spaces, special characters)
        self.run_cmd("gffread {} -g {} -w {} -x {} -y {}".format(
            shlex.quote(self.out_gff),
            shlex.quote(self.genome),
            shlex.quote(self.out_transcript),
            shlex.quote(self.out_cds),
            shlex.quote(raw_proteins),
        ))

        # Protein fasta file needs to have modified ids
        with open(raw_proteins, 'r') as prot_in, open(self.out_protein, 'w') as prot_out:
            for prot in prot_in:
                prot = prot.strip()
                if prot.startswith(">"):
                    prot = re.sub(r'-R([A-Z]+)', r'-P\1', prot)
                print(prot, file=prot_out)

    def parse_args(self):

        parser = argparse.ArgumentParser()
        parser.add_argument("genome", help="Genome file (fasta)")
        parser.add_argument("ogs_name", help="Name of the new OGS")
        parser.add_argument("id_regex", help="Regex with a capturing group around the incremental part of gene ids, and a second one around the version suffix (e.g. 'GSSPF[GPT]([0-9]{8})[0-9]{3}(\\.[0-9]+)?')")
        parser.add_argument("id_syntax", help="String representing a gene id, with {id} where the incremental part of the id should be placed (e.g. 'GSSPFG{id}001')")
        parser.add_argument("base_gff", help="The gff from the base annotation (usually automatic annotation)")
        parser.add_argument("apollo_gff", help="The gff from the new Apollo valid annotation")
        # In --previous_gff, the genes should have an Alias attribute containing the Apollo ID
        parser.add_argument("-p", "--previous_gff", help="The gff from the previous annotation version (if different than <base_gff>)")
        parser.add_argument("-d", "--deleted", help="File containing a list of mRNAs to remove")
        parser.add_argument("-o", "--out_prefix", help="Prefix for output files (default=<ogs_name>_<today's date>)")
        args = parser.parse_args()

        self.base_gff = args.base_gff
        self.filtered_base_gff = self.tmpdir + '/filtered_base.gff'  # base_gff without the deleted isoforms/genes
        self.previous_gff = args.previous_gff
        self.apollo_gff = args.apollo_gff
        self.genome = args.genome
        self.source = args.ogs_name
        self.deleted = args.deleted

        self.id_regex = args.id_regex
        self.id_syntax = args.id_syntax

        self.out_prefix = args.out_prefix
        if not self.out_prefix:
            self.out_prefix = self.source + '_' + datetime.date.today().strftime('%Y%m%d')

        self.out_gff = self.out_prefix + '.gff3'
        self.out_transcript = self.out_prefix + '_transcript.fasta'
        self.out_cds = self.out_prefix + '_cds.fasta'
        self.out_protein = self.out_prefix + '_protein.fasta'
        self.out_report = self.out_prefix + '_merge_report.txt'

        self.highest_id = 0
        self.padding_length = 0

        # Intersect results
        self.primary_matches = {}
        self.secondary_matches = {}

        # Information from base annotation
        self.base_genes_structure = {}
        self.base_genes_version = {}  # Store the original version of the genes
        self.base_genes_positions = {}  # Store positions of all base genes

        # Information from previous annotation
        self.previous_genes_content = {}
        self.previous_genes_name = {}
        self.num_total_genes_in_previous = 0
        self.kept_compatibility = 0  # Number of genes which id was kept from previous annotation for compatibility

        # Information from apollo annotation
        self.apollo_ids_in_latest = []
        self.apollo_genes_positions = {}  # Store positions of all apollo genes

        # Name mapping
        self.name_map = {}  # keys = Apollo id, values = best corresponding base annotation id
        self.blacklist_merged = {}

        # Statistics
        self.num_unchanged_from_base = 0
        self.num_unchanged_from_previous = 0
        self.num_version_bumped = 0
        self.num_new_id = 0
        self.num_kept_base = 0
        self.num_total_genes = 0
        self.num_replaced_base = 0
        self.num_splitted = 0
        self.num_merged = 0
        self.num_deleted_isoforms = 0
        self.whole_genes_deleted = []

    def main(self):
        """Run the whole merge pipeline.

        Creates a temp directory, parses the command line, then runs each
        step in order and writes the gff, fasta, statistics and id conversion
        files.  The temp directory is removed on exit, including when a step
        fails or argument parsing exits.
        """
        # Create a temp directory to launch bedtools
        self.tmpdir = tempfile.mkdtemp()

        try:
            self.parse_args()

            # Delete isoforms (when asked by annotators)
            self.filter_deleted_isoforms()

            # Launch external tools to compute intersection between apollo genes and genes from base annotation
            self.match_apollo_against_base()

            # Parse the base annotation file
            self.parse_base_annotation()

            # Load the previous annotation (if any)
            if self.previous_gff:
                self.parse_previous_annotation()

            # Assign IDs to genes that got split in apollo
            self.handle_splitted()

            # Ensure we assign the same ids as in previous annotation
            self.check_id_compatibility()

            # Increase a little to be sure we don't reuse the id of a gene that was removed
            self.highest_id += 100

            # Parse latest apollo annotation
            self.parse_apollo_annotation()

            # Blacklist genes that were replaced
            self.prepare_base_blacklist()

            # Find genes that got merged in apollo
            self.handle_merged()

            # Print output gff
            self.write_gff()

            # Print output fasta files
            self.write_fasta()

            # Print some general statistics, on stdout and in a file
            self.print_statistics(sys.stdout)
            with open(self.out_prefix + '_merge_stats.txt', 'w') as stats_file:
                self.print_statistics(stats_file)

            # Print ID conversion table from base annotation to new one
            with open(self.out_prefix + '_merge_ids_conversion.tsv', 'w') as conversion_file:
                self.print_convertion(conversion_file)
        finally:
            # Remove temp dir, whatever happened above
            shutil.rmtree(self.tmpdir)


if __name__ == '__main__':
    # Script entry point: build the merger and run the full pipeline
    OgsMerger().main()
