#!python

#################
# Frameshift_deletions_checks
#
# Description:
# Detect frameshifting indels and stop codons (gained or lost) in the provided viral consensus sequences and generate a report of these events
#
# Authors:
# Lara Fuhrmann, Ivan Topolsky (CBG ETHZ)- frameshift indels detection and reporting
# Matteo Carrara (NEXUS Personalized Health Technologies) - stop codons detection and reporting
#################


import pysam
import pysamstats
import numpy as np
import pandas as pd
from io import StringIO
from Bio import SeqIO
from Bio import AlignIO
from Bio.Align.Applications import MafftCommandline
from BCBio import GFF
import operator
import argparse
import re
import os
import tempfile
import sys

from smallgenomeutilities._version import __version__



def str2bool(v):
    """
    argparse 'type' helper turning a yes/no-like string into a bool.

    Real booleans are handed back untouched. Strings are compared
    case-insensitively against a fixed list of accepted spellings;
    anything else raises argparse.ArgumentTypeError.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_args():
    """
    Set up the parsing of command-line arguments.

    Returns the argparse.Namespace produced by parsing sys.argv. The
    attributes set are: bamfile, ref_majority_dels, reference, genes_gff,
    orf1ab, chain, based (0 or 1), outfile and english.
    """

    # keep lines in epilog ; but keep the defaults in the list
    class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): pass

    parser = argparse.ArgumentParser(
        description="Produces a report about frameshifting indels and stops in a consensus sequences",
        epilog=u"""columns signification:
	[ref_id/cons_id]: name of the sequence in the reference and consensus
	[start_position/length]: location of the variant
	[VARIANT]: one of: "insertion", "deletion", "stopgain" or "stoploss"
	[gene_region]: Gene in which the deletion is found according to -g argument;
	[reads_all]: Total number of reads covering the indel;
	[reads_fwd]: Total nubmer of forward reads covering the indel;
	[reads_rev]: Total nubmer of reverse reads covering the indel;
	[deletions/insertions/stops]: Number of reads supporting the deletion/insertion/stop;
	[freq_del/freq_insert/freq_stop]: Fraction of reads supporting the deletion/insertion/stop;
	[matches_ref]: number of reads that matche with the reference base;
	[pos_critical_inserts]: Start positions of insertions in the same gene_region that occur in > 40% of reads;
	[pos_critical_dels]: Start positions of deletions in the same gene_region that occur in > 40% of reads;
	[homopolymeric]: True if either around the start or end position of the deletion three bases are the same, which may have caused the polymerase to skip during reverse transcription of viral RNA to cDNA, e.g. AATAG;
	[ref_base]: base in the reference genome;
	[variant_position_english]: english sentence describing the indel or stop;
	[variant_diagnosis]: english sentence with the indel diagnosis""",
        formatter_class=Formatter)

    # arguments that must always be given
    requiredNamed = parser.add_argument_group('required named arguments')
    requiredNamed.add_argument(
        "-i", "--input", required=True, metavar='BAM', dest='bamfile',
        help="Input BAM file, aligned against the reference"
    )
    requiredNamed.add_argument(
        "-c", "--consensus", required=True, metavar='FASTA', dest='ref_majority_dels', type=str,
        help="Fasta file containing the ref_majority_dels consensus sequence"
    )
    requiredNamed.add_argument(
        "-f", "--reference", required=True, metavar='FASTA', dest='reference', type=str,
        help="Fasta file containing the reference sequence to compare against"
    )
    requiredNamed.add_argument(
        "-g", "--genes", required=True, metavar='GFF', dest='genes_gff', type=str,
        help="GFF file listing genes positions on the reference sequence"
    )

    # optional arguments
    # NOTE -O has a default and is not required, so it is listed with the
    # optional arguments rather than under 'required named arguments'
    parser.add_argument(
        "-O", "--orf1ab", required=False, dest='orf1ab', type=str, default = 'cds-YP_009724389.1',
        help="CDS ID for the full Orf1ab CDS, comprising the ribosomal shift. In the GFF this CDS should consist of 2 entries with the same CDS ID due to the partial overlap caused by the ribosomal shift at translation time"
    )
    parser.add_argument(
        "-s", "--chain", required=False, dest='chain', type=str,
        help="Chain file describing how the consensus is aligned to the reference (e.g. `bcftools consensus --chain ...`); If not provided, mafft will be used to align the consensus to the reference."
    )
    parser.add_argument(
        "-0", "--zero-based", required=False, dest='based', action='store_const', const=0, default=1,
        help="Use 0-based (python) instead of 1-based (standard) seq positions"
    )
    parser.add_argument(
        "-o", "--output", required=False, default=os.path.join(os.getcwd(), 'frameshift_deletions_check.tsv'),
        metavar='TSV', dest='outfile', help="Output file"
    )
    # -E and -e write to the same destination; -E is the plain "off" switch
    parser.add_argument(
        "-E", "--no-english", dest='english', action='store_false',
        help="Do not write the english summary diagnosis."
    )
    parser.add_argument(
        "-e", "--english", required=False, type=str2bool, nargs='?', const=True, dest='english',
        help="If True writes english summary diagnosis."
    )
    parser.add_argument(
        "-v", "--version", action='version', version='%(prog)s {version}'.format(version=__version__))
    # the english summary is on unless explicitly switched off
    parser.set_defaults(english=True)
    return parser.parse_args()

def check_homopolymeric(variation_info, position, gap_length, indel_type, ref_id=None):
    '''
    return homopolymeric == 1 if, around either the start_position or the
    end_position of the indel, at least 3 of the (up to) 5 bases in the
    window [p-2, p+2] are the same reference base, eg. AATAG. Returns 0 otherwise.

    variation_info: record array with (at least) the fields `pos` and `ref`
    position:       start position of the indel
    gap_length:     length of the indel
    indel_type:     'insertion' or anything else (handled as a deletion)
    ref_id:         optional sequence name, only used in the warning message
    '''

    # NOTE insertions are in reference space, so the end of the insertion is at [position + 1]
    # whereas deletions' end is at [ position + length -1 ]
    end_offset = 1 if indel_type == 'insertion' else gap_length
    for p in (position - 1, position + end_offset - 1):
        # list index for [p +/- 2]
        # also only pick covered positions (e.g.: "nnnnATTCG-Cnnn" - 'C' is alone, there are no reads on position+2)
        idx_pos = np.where(np.logical_and(variation_info.pos >= p-2, variation_info.pos <= p+2))[0]
        if idx_pos.size == 0:
            # really broken alignment
            # (ref_id is not otherwise available in this function: only print it when the caller supplies it)
            prefix = f"{ref_id}:" if ref_id is not None else ""
            print(f"Warning no reads around {prefix}{p-2}:{p+2} for homopolymeric ?!", file=sys.stderr)
            continue

        # count how often each reference base occurs in the window
        nuc_list = [variation_info[i].ref[0] for i in idx_pos]
        _, counts = np.unique(nuc_list, return_counts=True)
        if np.max(counts) > 2:
            return 1

    return 0

def parse_gff(genes_gff, featuretype):
    """
    Read the 'gene' features out of a GFF file.

    featuretype 'gene'/'Gene'/'GENE': returns a gene_list suitable for
    get_gene_at_position, i.e. tuples (record id, start, end, name) where
    name is the 'Name' qualifier, falling back to the feature id.
    featuretype 'cds'/'CDS': returns the 'gene' feature objects themselves.
    Any other featuretype terminates the program.
    Returns None when no GFF file is given.
    """
    if not genes_gff:
        return None

    with open(genes_gff) as handle:
        if featuretype in ('gene', 'Gene', 'GENE'):
            gene_list = []
            for record in GFF.parse(handle):
                for feature in record.features:
                    if feature.type != 'gene':
                        continue
                    name = feature.qualifiers.get('Name', [feature.id])[0]
                    gene_list.append((record.id, int(feature.location.start), int(feature.location.end), name))
            return gene_list

        if featuretype in ('cds', "CDS"):
            gene_features = []
            for record in GFF.parse(handle):
                gene_features.extend(feature for feature in record.features if feature.type == 'gene')
            return gene_features

        sys.exit("Unrecognized featuretype for parsing GFF. Accepted values are 'gene' and 'CDS'")

def get_gene_at_position(gene_list, ref_id, position):
    """
    Look up the gene entry (ref_id, start, end, name) that covers `position`
    on sequence `ref_id`; the interval is start-inclusive, end-exclusive.
    The first matching entry of gene_list wins.

    When no gene covers the position, a placeholder entry named '-' spanning
    10 positions on either side is returned instead.
    """
    hit = next(
        (entry for entry in gene_list
         if entry[0] == ref_id and position in range(entry[1], entry[2])),
        None,
    )
    if hit is not None:
        return hit

    return (ref_id, position-10, position+10, '-')

def check_indels_gene_region(gene_reg_load_variation, curr_pos, gap_length, indel_type, region_start, region_end, based):
    """
    Look through the gene-region-interval(*) for indels supported by more
    than 40% of the reads, and return their positions (offset by `based`)
    as two lists: (critical_inserts, critical_dels).
    ---
    (*): or at least the part that got coverage

    The indel currently under scrutiny is left out of its own list:
    - when looking at an insertion, an insert at curr_pos is not listed,
      but deletions overlapping it on other reads are;
    - when looking at a deletion, deletions inside
      [curr_pos, curr_pos+gap_length) are not listed, but inserts are.
    """

    # clip the request to positions for which we have information.
    # Worst-case scenario: begins or ends with the indel
    covered = gene_reg_load_variation.pos
    first = np.flatnonzero(covered >= region_start)[0]
    last = np.flatnonzero(covered <= region_end)[-1]
    # NOTE(review): the slice end is exclusive, so the entry at index `last` itself is not scanned
    window = slice(first, last)

    # NOTE all four views share the same index,
    # BUT that index does not correspond to the actual position
    site_pos = covered[window]
    depth = gene_reg_load_variation.reads_all[window]
    ins_counts = gene_reg_load_variation.insertions[window]
    del_counts = gene_reg_load_variation.deletions[window]

    critical_inserts = []
    critical_dels = []
    for where, total, n_ins, n_del in zip(site_pos, depth, ins_counts, del_counts):
        if n_ins > 0.4*total and (indel_type == 'deletion' or where != curr_pos):
            critical_inserts.append(where + based)
        if n_del > 0.4*total and (indel_type == 'insertion' or where not in range(curr_pos, curr_pos+gap_length)):
            critical_dels.append(where + based)
    return critical_inserts, critical_dels

def ranges(nums):
    """
    auxiliary function for list_all_inserts() and list_all_dels().
    Collapse a list of numbers into the runs of consecutive values it
    contains; each run is returned as an inclusive (first, last) tuple,
    e.g. [1,2,3,10]--> [(1,3),(10,10)]
    Duplicates and input order are ignored; an empty input gives [].
    """
    ordered = sorted(set(nums))
    if not ordered:
        return []

    runs = []
    run_start = previous = ordered[0]
    for current in ordered[1:]:
        if previous + 1 < current:
            # a hole: close the running interval and open a new one
            runs.append((run_start, previous))
            run_start = current
        previous = current
    runs.append((run_start, previous))
    return runs

def len_del(item_range):
    """
    auxiliary function for list_all_inserts() and list_all_dels().
    Number of positions covered by an inclusive (first, last) interval.
    """
    first, last = item_range[0], item_range[1]
    return last - first + 1

def list_all_inserts(aligned_seqs):
    '''
    List _all_ insertions found in the alignment.

    aligned_seqs is read as consecutive (reference, consensus) pairs.
    A run of "-" in the aligned *reference* is an insertion in the consensus.
    Returns one list per insertion:
    [start index in the alignment, length, reference id, consensus id, 'insertion']
    '''
    found = []
    records = iter(aligned_seqs)
    # pulling twice from the same iterator walks the list two records at a time
    for reference_seq, consensus_seq in zip(records, records):
        gap_columns = [column for column, nt in enumerate(reference_seq.seq) if nt == "-"]
        found.extend(
            [span[0], len_del(span), reference_seq.id, consensus_seq.id, 'insertion']
            for span in ranges(gap_columns)
        )

    return found

def list_all_dels(aligned_seqs):
    '''
    List _all_ deletions found in the alignment.

    aligned_seqs is read as consecutive (reference, consensus) pairs.
    A run of "-" in the aligned *consensus* is a deletion.
    Returns one list per deletion:
    [start index in the alignment, length, reference id, consensus id, 'deletion']
    '''
    found = []
    records = iter(aligned_seqs)
    # pulling twice from the same iterator walks the list two records at a time
    for reference_seq, consensus_seq in zip(records, records):
        gap_columns = [column for column, nt in enumerate(consensus_seq.seq) if nt == "-"]
        found.extend(
            [span[0], len_del(span), reference_seq.id, consensus_seq.id, 'deletion']
            for span in ranges(gap_columns)
        )

    return found

def extract_cds_range(cds_list):
    """
    Flatten the sub-features of every gene feature into a list of
    [gene id, sub-feature id, start, end] entries (one per sub-feature).

    Start and end are taken from the sub-feature location and converted
    with int(), the same way parse_gff() reads gene locations, instead of
    relying on the legacy `.position` attribute of the position objects.
    """
    return [
        [gene.id, cds.id, int(cds.location.start), int(cds.location.end)]
        for gene in cds_list
        for cds in gene.sub_features
    ]

def correct_cds_positions(cds_positions, list_inserts):
    """
    Move the CDS coordinates from reference space (CDS position in the GFF)
    to sequence space, in place, and return the same list.

    - an insert at or before the CDS start pushes both start and end forward
    - an insert after the start, up to and including the end, pushes only the end
    - deletions need no shift: they are positions substituted with "-"

    cds_positions entries are [gene id, cds id, start, end];
    list_inserts entries begin with [start, length, ...].
    """
    for ins_start, ins_length in ((entry[0], entry[1]) for entry in list_inserts):
        for cds in cds_positions:
            if ins_start <= cds[2]:
                cds[2] += ins_length
                cds[3] += ins_length
            # evaluated against the (possibly just updated) boundaries
            if cds[2] < ins_start <= cds[3]:
                cds[3] += ins_length
    return cds_positions

def list_all_stops(aligned_seqs, cds_positions, orf1ab_name):
    '''
    returns list with starting position of _all_ stop codons for each CDS. The script also takes into account if there are no new stops and the expected stop is lost (stop-loss)
    The function has an exception for Orf1ab full CDS: Because of the gene being translated with a ribosome shift, there is a 1-nucleotide overlap between the region before and after the shift. This results in 2 separate CDS entries with identical gene- and CDS-ID. The exception identifies the two entries and merges their nucleotide sequences together before translation
    The list of stops has the following structure: [ start, length, ref_id, id, variant_type, gene_id, cds_id, aminoacid_position, mismatches, first_nucleotide]
    More specifically: variant_type can be either "stopgain", "stoploss"; first_nucleotide is used only for stoplosses, it stores first nucleotide of the codon and it is used to retrieve the coverage of the actual nucleotide in the consensus, even in case of mixed coverage

    aligned_seqs:  [aligned reference record, aligned consensus record]
    cds_positions: list of [gene id, cds id, start, end], in sequence space
    orf1ab_name:   CDS ID of the full Orf1ab CDS (must match exactly 2 entries, otherwise the program exits)
    '''
    stop_list = []
    orf1ab_processed = False
    for cds in cds_positions:
        stoptype = "stopgain"
        if cds[1] == orf1ab_name:
            # both entries are handled together the first time the ID is met
            if orf1ab_processed:
                continue
            orf1ab_positions = [ entry for entry in cds_positions if entry[1] == orf1ab_name ]
            if len(orf1ab_positions) == 2:
                # Store the number of deletions within the CDS, as they will shift the relative position after the ungap
                orf1ab_sequence = aligned_seqs[1][orf1ab_positions[0][2]:orf1ab_positions[0][3]].seq
                cds_sequence = orf1ab_sequence + aligned_seqs[1][orf1ab_positions[1][2]:orf1ab_positions[1][3]].seq
                cds_del_number = cds_sequence.count("-")
                cds_sequence = cds_sequence.ungap("-")
                # pad with as many 'n' as deletions were removed
                cds_sequence_for_translation = cds_sequence + ('n' * cds_del_number)
                orf1ab_processed = True
            else:
                sys.exit("Orf1ab full CDS is expected to have exactly 2 separate entries. Found " + str(len(orf1ab_positions)) + " entries matching CDS ID " + orf1ab_name + ". Please check you GFF and the provided CDS ID")
        else:
            # Remove the deletions from the CDS region before translating
            cds_sequence = aligned_seqs[1][cds[2]:cds[3]].seq
            cds_del_number = cds_sequence.count("-")
            cds_sequence = cds_sequence.ungap("-")
            cds_sequence_for_translation = cds_sequence + ('n' * cds_del_number)
        #TODO: add the option to provide a different stop codon table: https://biopython.org/docs/1.75/api/Bio.Seq.html#Bio.Seq.Seq.translate. See below for a second step that needs adaptation
        translated_cds = cds_sequence_for_translation.translate()
        # We are interested in the all stops we find. If it's before the last codon, it's a stopgain. If there is no stop codon at all, it's a stoploss
        stop_positions = [ m.start() for m in re.finditer(r'\*', str(translated_cds)) ]
        if len(stop_positions) == 0:
            potential_stoploss_seq = cds_sequence[len(cds_sequence)-3:len(cds_sequence)]
            if potential_stoploss_seq.find("n") == -1:
                stoptype = "stoploss"
                stop_pos_aa = [ len(translated_cds) ]
            else:
                # Ignore stoplosses that are detected because of n in the sequence
                continue
        else:
            # The positions in stop_positions are 0 based because they are positions in a python string. Converting back immediately to 1 based. No need to do that for the stoplosses because they are bound to the length of the string, not to the position
            stop_pos_aa = [ sp+1 for sp in stop_positions ]
            if stop_pos_aa[-1] == len(translated_cds) and stoptype != "stoploss":
                stop_pos_aa.pop()
        # If there is no actual stops, ignore this CDS and move to the next
        if len(stop_pos_aa) == 0:
            continue
        # first nucleotide of each stop codon, relative to the CDS start
        stop_rel_pos_nt = [ (sp*3)-2 for sp in stop_pos_aa ]
        stop_abs_pos_nt = [ cds[2]+sp for sp in stop_rel_pos_nt ]
        stop_abs_pos_nt = correct_stop_abs_pos(stop_abs_pos_nt, aligned_seqs, cds[2])
        if cds[1] == orf1ab_name:
            # The ribosome shifts makes it so the position after the shift is +1 the actual position on the reference. Shift back if the cds is orf1ab and the position is after the start of the second part of the CDS
            stop_abs_pos_nt = [ stop-1 if stop > orf1ab_positions[1][2] else stop for stop in stop_abs_pos_nt ]
        # Adding padding at the end of the sequence at translation time can shift the position of the stop codon, which can be not at the end of the translated string. Here, after all corrections, we check again based on the absolute position of the stop, compared to the annotation. Please remember that the stop nt position is the first of the 3 nucleotides, so we need to remove 2 from the end of the gene
        if cds[1] == orf1ab_name:
            # the expected stop of Orf1ab sits at the end of its *second* entry
            # (looked up by CDS ID, not by its index in the full CDS list)
            expected_stop_nt = orf1ab_positions[1][3]-2
        else:
            expected_stop_nt = cds[3]-2
        # a stoploss is reported at the expected stop position: never drop it here
        if stop_abs_pos_nt[-1] == expected_stop_nt and stoptype != "stoploss":
            stop_abs_pos_nt.pop()
            stop_pos_aa.pop()
        # Report one stop per element for compatibility with the downstream functions. Differentiate between stopgain and stoploss. Remember to remove the expected stop at the end of the CDS if it exists
        # at this point check if the position has any mismatches with the reference
        for i in range(len(stop_pos_aa)):
            mismatches = find_mismatches_in_stops(aligned_seqs, stop_abs_pos_nt[i])
            if stoptype == "stoploss":
                stoploss_nt = aligned_seqs[1][stop_abs_pos_nt[i]-1]
            else:
                stoploss_nt = "None"
            new_stop = [ stop_abs_pos_nt[i], 3, aligned_seqs[0].id, aligned_seqs[1].id, stoptype, cds[0], cds[1], stop_pos_aa[i], mismatches, stoploss_nt ]
            # Do not append duplicates. This removes duplicate entries for the alternate full read of orf1ab after the cds with the ribosome shift has been processed
            # (two stops are duplicates when start, length, variant type and gene id are all equal)
            already_seen = [ [a[0], a[1], a[4], a[5] ] for a in stop_list ]
            if [ new_stop[0], new_stop[1], new_stop[4], new_stop[5] ] not in already_seen:
                stop_list.append(new_stop)
    return stop_list
           
def correct_stop_abs_pos(stop_abs_pos_nt, aligned_seqs, cds_start):
    '''
    Push each stop position forward by the number of deletions ("-") that
    precede it inside the CDS in the aligned consensus (aligned_seqs[1]).

    The relative positions were computed on the ungapped sequence, so every
    deletion before the stop has to be added back. Moving the stop forward
    can bring more deletions into the preceding stretch, hence the count is
    repeated until it no longer changes.

    The list is updated in place and also returned.
    '''
    for idx, stop in enumerate(stop_abs_pos_nt):
        gaps_counted = 0
        rounds = 0
        settled = False
        while not settled:
            upstream = aligned_seqs[1][cds_start:stop+gaps_counted].seq
            gaps_now = upstream.count("-")
            settled = (gaps_now == gaps_counted)
            gaps_counted = gaps_now
            rounds += 1
            if rounds > 1000000:
                sys.exit("Possible infinite loop in function: Execution interrupted as failsafe")
        stop_abs_pos_nt[idx] = stop + gaps_counted
    return stop_abs_pos_nt

def find_mismatches_in_stops(aligned_seqs, consensus_space_pos):
    """
    Count the mismatches between reference (aligned_seqs[0]) and consensus
    (aligned_seqs[1]) over the 3 nucleotides of a codon starting at the
    1-based position consensus_space_pos.

    Columns where the consensus has a gap are skipped and do not count
    towards the 3 nucleotides; columns where the reference has a gap
    count towards the 3 nucleotides but never as a mismatch.
    """
    column = consensus_space_pos - 1
    still_to_check = 3
    mismatches = 0
    while still_to_check > 0:
        ref_nt = aligned_seqs[0][column]
        cons_nt = aligned_seqs[1][column]
        column += 1
        if cons_nt == "-":
            # deleted in the consensus: not part of the codon
            continue
        still_to_check -= 1
        if ref_nt != "-" and ref_nt != cons_nt:
            mismatches += 1
    return mismatches

def find_indels_in_stops(list_stops, list_dels, list_inserts):
    """
    Flag the stops that have an indel within their codon.

    Stops with an indel within the codon are reported alongside the indel itself:
    - the overlapping indel gets the stop's aminoacid position, mismatches
      and first nucleotide (stop[7], stop[8], stop[9]) appended
    - the stop is removed from list_stops
    Indels that overlap no stop get the placeholders "None", 0, "None" appended.

    All three lists are modified in place and returned as
    (list_stops, list_dels, list_inserts).
    """
    # We need to check if the range of the indel stretch falls in the range of 3 nucleotide from the start position of the stop
    # There can be up to 2 non-contiguous insertions or deletions in the codon, therefore we need to check all
    # E.g. _ _ _T-A-A_ _ _
    # The only exception is if the indel has length multiple of 3: in this case it's going to not be reported, therefore the stopgain still needs to stay standalone
    # A position X of length Y covers X..(X+Y-1); range() takes an exclusive end, hence range(X, X+Y)

    # indices (in list_stops) of the stops that overlap at least one indel.
    # A set, because the same stop can overlap several indels
    overlap_found = set()
    for i, stop in enumerate(list_stops):
        stop_range = range(stop[0], stop[0]+stop[1])
        for indel_list in (list_dels, list_inserts):
            for indel in indel_list:
                indel_range = range(indel[0], indel[0]+indel[1])
                if not set(stop_range).isdisjoint(indel_range) and indel[1]%3 != 0:
                    indel.extend([stop[7], stop[8], stop[9]])
                    overlap_found.add(i)

    # pad the indels that are not linked to any stop, so all entries have 8 fields
    for indel_list, label in ((list_dels, "deletion"), (list_inserts, "insertion")):
        for indel in indel_list:
            if len(indel) == 5:
                indel.extend(["None", 0, "None"])
            if len(indel) != 8:
                print(f"ERROR: invalid {label} length")

    # drop all flagged stops in one pass: removing them one by one would
    # shift the indices of the stops still waiting to be removed
    if overlap_found:
        list_stops[:] = [stop for i, stop in enumerate(list_stops) if i not in overlap_found]
    return list_stops, list_dels, list_inserts
            

def align_with_mafft(reference, consensus):
    """
    we rely on a MAFFT alignment, in order to:
    - detect deletions in consensuses from callers that do not mark them
      (e.g.: bcftools without the `--mark-dels '-'` option)
    - detect insertions in the consensus
        (will appear as deletions *in the reference*)

    reference, consensus: fasta files; their records are paired in file order.
    Returns a flat list of aligned records, alternating reference and
    consensus (one pair per aligned consensus record). Consensus records
    made only of <N>s are skipped with a warning.
    """

    all_align = []

    # Loop through pairs to avoid MAFFT aligning different unrelated segments
    for seq_record, ref_record in zip(SeqIO.parse(consensus, "fasta"), SeqIO.parse(reference, "fasta")):
        # HACK assume seq_record have same order as ref_record
        # TODO test to multi-segmented viruses before the next flu pandemic
        print(f"{seq_record.id} mapped to {ref_record.id}")

        if align_with_mafft.rx_only_n.match(str(seq_record.seq)):
            print(f"Warning: {seq_record.id} contains only <N>s", file=sys.stderr)
            # BUG MAFFT has a bug where if such a only-<N>s sequence is given as an argument, it will generate an alignement offset by 1 bogus gap at each end:
            # > Alignment with 2 rows and 29904 columns
            # > attaaaggtttataccttcccaggtaacaaaccaaccaactttc...aa- NC_045512.2
            # > -nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn...nnn 100100_31_C02-20200508_J3GT6
            # these bogus gaps are not supported by any reads and break analyse_position()
            # we *must* not call MAFFT
            continue

        # NOTE this also handles properly TMPDIR on clusters
        with tempfile.NamedTemporaryFile(mode='w') as tmp:
            # write the pair through the temporary file's own (text) handle,
            # and flush so MAFFT sees the whole content when it opens the file by name
            SeqIO.write([ref_record, seq_record], tmp, "fasta")
            tmp.flush()

            mafft_cline = MafftCommandline(input=tmp.name)
            stdout, stderr = mafft_cline()
            if len(stderr) > 0:
                print(mafft_cline, ":", file=sys.stderr)
                print(stderr, file=sys.stderr)
            # adds the aligned records one by one (reference first, then consensus)
            all_align += AlignIO.read(StringIO(stdout), "fasta")

    return all_align
# static re: only compile it once
align_with_mafft.rx_only_n=re.compile(r'^[Nn]+$')

def align_with_chain(reference, consensus, chain, keepcase = False):
    """
    use a .chain file produced by the consensus caller:
    - this doesn't rely on an external aligner
      - avoids corner cases (e.g. whole N consensus breaks MAFFT)
      - less dependencies
    - consensus caller knows in advance
      - where it put insertions
      - where it put deletions (if not marking them, e.g., `bcftools --mark-dels ')

    reference, consensus: fasta files
    chain:                chain file mapping the reference (source) to the consensus (target)
    keepcase:             keep the case of the sequences instead of lowercasing them

    Returns a flat list of aligned records, alternating reference and
    consensus, one pair per chain of the file.
    Raises Exception on malformed chain files, unsupported strands or
    sequence names missing from the fasta files.

    Format documentation: https://genome.ucsc.edu/goldenPath/help/chain.html
    Code inspiration : https://github.com/liguowang/CrossMap/blob/master/lib/cmmodule/utils.py#L265
    """
    from Bio.Seq import Seq
    from Bio.SeqRecord import SeqRecord
    from Bio.Align import MultipleSeqAlignment


    all_align = []

    # dictionary with each fragment from both reference and samples' consensus
    refs = SeqIO.to_dict(SeqIO.parse(reference, "fasta"))
    cons = SeqIO.to_dict(SeqIO.parse(consensus, "fasta"))

    # state of the chain currently being read (None: not inside a chain)
    ref = con = r_pos = c_pos = None
    # and now process the .chain
    with open(chain, "rt") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue

            t = line.split()
            if t[0] == 'chain' and len(t) in [12, 13]:
                # Example: chain 4900 chrY 58368225 + 25985403 25985638 chr5 151006098 - 43257292 43257528 1

                # score = int(t[1]) # Alignment score
                # chain_id = None if len(t) == 12 else t[12]

                # reference:
                r_name = t[2]      # E.g. chrY
                r_size = int(t[3]) # Full length of the chromosome (not used further)
                r_strand = t[4]    # Must be +
                if r_strand != '+':
                    raise Exception(f"Source strand in a chain file must be +. ({line})")
                r_start = int(t[5]) # Start of source region
                r_end = int(t[6])  # End of source region

                # consensus:
                c_name = t[7]       # E.g. chr5
                c_size = int(t[8])  # Full length of the chromosome (not used further)
                c_strand = t[9]    # + or -
                c_start = int(t[10])
                c_end = int(t[11])

                # if c_strand not in ['+', '-']:
                #     raise Exception("Target strand must be - or +. (%s)" % line)
                if c_strand != '+':
                    raise Exception(f"BUG: Target strand currently only supports +. ({line})")
                if r_name not in refs:
                    raise Exception(f"Source <{r_name}> not found in file {reference}")
                if c_name not in cons:
                    # report the name that was actually looked up: the target's
                    raise Exception(f"Target <{c_name}> not found in file {consensus}")

                ref = refs[r_name]
                con = cons[c_name]

                # handle what's before the start
                al_ref = ref.seq[0:r_start]
                r_pos = r_start

                # consensus side: pad with 'n' or trim so that it is as long as the reference side
                al_con = Seq('') if r_start <= c_start else Seq("n" * (r_start - c_start))
                al_con +=  con.seq[ (0 if c_start <= r_start else c_start - r_start) : c_start]
                c_pos = c_start

                assert len(al_ref) == len(al_con), "internal failure: start side doesn't have the same size"

            elif t[0] != 'chain' and (r_pos is None or c_pos is None):
                raise Exception(f"No header line (chain ...) before line with chain details ({line})")

            elif t[0] != 'chain' and len(t) == 3:
                # alignment data line: size of the ungapped block, followed by
                # the gap on the reference side and the gap on the consensus side

                size, r_gap, c_gap = int(t[0]), int(t[1]), int(t[2])

                # aligned part
                al_ref += ref.seq[r_pos:r_pos+size]
                al_con += con.seq[c_pos:c_pos+size]
                r_pos += size
                c_pos += size

                if r_gap:
                    # gap: reference with no consensus
                    al_ref += ref.seq[r_pos:r_pos+r_gap]
                    al_con += Seq('-' * r_gap)
                    r_pos += r_gap

                if c_gap:
                    # gap: consensus with no reference
                    al_ref += Seq('-' * c_gap)
                    al_con += con.seq[c_pos:c_pos+c_gap]
                    c_pos += c_gap

                # TODO handle '-' strand (reverse direction of consensus)

            elif t[0] != 'chain' and len(t) == 1:
                # last line of a chain: only the size of the final ungapped block
                size = int(t[0])

                # aligned part
                al_ref += ref.seq[r_pos:r_pos+size]
                al_con += con.seq[c_pos:c_pos+size]

                assert len(al_ref) == len(al_con), "internal failure: lenght god out of syn"

                r_pos += size
                c_pos += size

                assert r_pos == r_end, f"internal failure: didn't consume reference to end ({r_pos} != {r_end})"
                assert c_pos == c_end, f"internal failure: didn't consume consensus to end ({c_pos} != {c_end})"

                # adds the two aligned records one by one (reference first, then consensus)
                all_align += MultipleSeqAlignment(
                                [
                                    SeqRecord(al_ref if keepcase else al_ref.lower(), id=r_name),
                                    SeqRecord(al_con if keepcase else al_con.lower(), id=c_name),
                                ],
                                annotations={"tool": "chain"},
                                column_annotations=None
                            )

                # chain finished: a new header is needed before any further data line
                ref = con = r_pos = c_pos = None

            else:
                raise Exception(f"Invalid chain format. ({line})")

    return all_align



def check_dels_gene_region(frameshift_deletions, region_start, region_end, position):
    """
    Return the start positions of the other deletions located in this gene
    region: those whose start falls in [region_start, region_end) and
    differs from `position`.
    """
    return [
        entry[0]
        for entry in frameshift_deletions
        if entry[0] in range(region_start, region_end) and entry[0] != position
    ]

def correct_positions(list_dels,list_inserts, list_stops):
    '''
    Insertions shift the positions --> shift back to reference seq space
    Correct to zero-based (we need this for pysamstats).

    Each entry starts with [start, length, ...]; only the start is changed:
    - deletions:  minus the length of every insertion starting before them
    - stops:      same, then minus 1 to become zero-based
    - insertions: same, then minus 1 to become zero-based

    The three lists are modified in place and returned as
    (list_inserts, list_dels, list_stops).
    '''
    # Freeze start and length of the insertions while they are still in
    # alignment space: every comparison below must be done in that space,
    # not against positions that were already partially shifted back
    insert_spans = [(insert[0], insert[1]) for insert in list_inserts]

    def inserted_before(position):
        # total length of the insertions that start before this alignment position
        return sum(length for start, length in insert_spans if position > start)

    # correct deletions shifted due to insertions occuring before
    for deletion in list_dels:
        deletion[0] -= inserted_before(deletion[0])
    # shift stops due to insertions occuring before, then to zero-based space
    for stop in list_stops:
        stop[0] -= inserted_before(stop[0]) + 1
    # shift insertions due to insertions occuring before, then to zero-based space
    for insert in list_inserts:
        insert[0] -= inserted_before(insert[0]) + 1

    return list_inserts, list_dels, list_stops

def write_english_summary(df_temp):
    """
    Add two plain-English columns to the variant table.

    For every row a sentence locating the variant ('variant_position_english')
    and a diagnosis of its read support ('variant_diagnosis') are produced.
    The helper columns 'stop_mismatches', 'aa_position' and 'stoploss_nt' are
    removed. The data frame is modified in place and also returned.

    Raises ValueError when a row's VARIANT is not one of "deletion",
    "insertion", "stopgain" or "stoploss".
    """
    # VARIANT value -> (label used in the sentence, prefix of the per-strand frequency columns)
    variant_table = {
        'deletion': ('Gap', 'freq_del'),
        'insertion': ('Insertion', 'freq_insert'),
        'stopgain': ('Stopgain', 'freq_stop'),
        'stoploss': ('Stoploss', 'freq_stop'),
    }

    sentences = []
    diagnoses = []
    for _, row in df_temp.iterrows():
        variant = row['VARIANT']
        if variant not in variant_table:
            raise ValueError("Bad variant type", row['VARIANT'])
        label, freq_prefix = variant_table[variant]

        # where the variant is
        if variant == 'deletion' or variant == 'insertion':
            sentence = f"{label} of {row['length']} nucleotide(s) found at refpos {row['start_position']}"
            # the helper columns hold the string "None" when there is nothing to report
            if row['aa_position'] != "None":
                sentence += f", in overlap with an early stop codon at aminoacid position {row['aa_position']}"
            if row['stoploss_nt'] != "None":
                sentence += ", in overlap with a stoploss"
        elif variant == 'stopgain':
            sentence = f"Early Stopgain found at refpos {row['start_position']}, aminoacid position {row['aa_position']}"
        else:
            sentence = f"{label} found at refpos {row['start_position']}"
        if row['stop_mismatches'] > 0:
            sentence += f", with {row['stop_mismatches']} mismatches"
        sentences.append(str(sentence))

        # how well the reads support it
        notes = [
            _support_status(variant, row[freq_prefix + '_fwd'], row[freq_prefix + '_rev'], row)
        ]
        if row['homopolymeric'] == 1:
            notes.append("homopolymeric")
        if (row['pos_critical_dels'] != []) or (row['pos_critical_inserts'] != []):
            notes.append("neighboring indels may restore reading frame")
        diagnoses.append('; '.join(notes))

    for helper_column in ('stop_mismatches', 'aa_position', 'stoploss_nt'):
        df_temp.pop(helper_column)
    df_temp['variant_position_english'] = sentences
    df_temp['variant_diagnosis'] = diagnoses
    return df_temp


def _support_status(name, fwdc, revc, row):
    """
    Word the read support of a variant.

    `fwdc` and `revc` are the fractions of forward and reverse reads carrying
    the variant; `row` supplies 'reads_fwd' and 'reads_rev', used to tell
    apart a strand without support from a strand without any read.
    """
    fwd_major = fwdc > 0.5
    rev_major = revc > 0.5
    if fwd_major and rev_major:
        return f"{name} supported by majority of fwd and rev reads"
    if (fwd_major and row['reads_rev'] == 0) or (rev_major and row['reads_fwd'] == 0):
        return f"only fwd or rev reads available, {name} supported by the majority of them"
    if fwd_major or rev_major:
        return f"{name} supported by majority of fwd or rev reads"

    fwd_minor = fwdc >= 0.05
    rev_minor = revc >= 0.05
    if (fwd_minor and row['reads_rev'] == 0) or (rev_minor and row['reads_fwd'] == 0):
        return f"only fwd or rev reads available, {name} supported by the minority of them [5%-50%]"
    if fwd_minor or rev_minor:
        return f"{name} supported by minority of fwd or rev reads [5%-50%]"
    return f"{name} not supported (covered by less than 5% of fwd and rev reads)"

def analyse_position(bamfile, reference, ref_id, position, stop_specific, gap_length, indel_type, gene_list, cons_id='', based=1):
    """
    Gather the read statistics of one variant (indel or stop) position.

    The per-base variation table of the gene region containing `position` is
    loaded with pysamstats and remembered as attributes of this function, so
    consecutive calls falling in the same region of the same reference do not
    reparse the BAM file.

    Parameters:
        bamfile, reference: alignment and reference FASTA handed to pysamstats
        ref_id: name of the reference sequence
        position: position searched in the `pos` column of the variation table
        stop_specific: three values -- amino-acid position, number of stop
            mismatches and stoploss nucleotide
        gap_length: length of the variant
        indel_type: "stopgain" and "stoploss" get stop statistics, any other
            value is handled as an indel
        gene_list: gene annotations handed to get_gene_at_position
        cons_id: name of the consensus sequence, copied to the output
        based: offset added to `position` for the reported start position,
            also forwarded to check_indels_gene_region

    Returns a dict holding one report row, or an empty dict when the variation
    table has no entry for `position`.
    """
    gene_region = get_gene_at_position(gene_list, ref_id, position)
    region_start = max(gene_region[1], 0)
    region_end = gene_region[2]

    same_region = (
        analyse_position.ref_id == ref_id
        and analyse_position.region_start == region_start
        and analyse_position.region_end == region_end
    )
    if not same_region:
        # cache miss: parse the BAM file for the whole gene region.
        # NOTE the region of interest covers the position anyway, so the stats
        # of the whole region are loaded (rather than position..position+gap_length)
        # and can be re-used by later calls
        analyse_position.indels_gene_reg = pysamstats.load_variation_strand(
            bamfile, fafile=reference, chrom=ref_id,
            start=region_start, end=region_end)
        analyse_position.ref_id = ref_id
        analyse_position.region_start = region_start
        analyse_position.region_end = region_end
    variation_info = analyse_position.indels_gene_reg

    idx_pos = np.where(variation_info.pos == position)
    if idx_pos[0].size == 0:
        # see MAFFT's BUG mentionned in align_with_mafft()
        print(f"Warning no read mapping to {ref_id}:{position}:+{gap_length}", file=sys.stderr)
        return {}

    # the row(s) of the table describing the position
    site = variation_info[idx_pos]

    reads_all = site.reads_all[0]
    reads_fwd = site.reads_fwd[0]
    reads_rev = site.reads_rev[0]

    def strand_freq(count, depth):
        # a strand without reads yields a frequency of 0
        return count/depth if depth else 0

    deletions = site.deletions[0]
    deletions_fwd = site.deletions_fwd[0]
    deletions_rev = site.deletions_rev[0]
    freq_del = deletions/reads_all
    freq_del_fwd = strand_freq(deletions_fwd, reads_fwd)
    freq_del_rev = strand_freq(deletions_rev, reads_rev)

    insertions = site.insertions[0]
    insertions_fwd = site.insertions_fwd[0]
    insertions_rev = site.insertions_rev[0]
    freq_insert = insertions/reads_all
    freq_insert_fwd = strand_freq(insertions_fwd, reads_fwd)
    freq_insert_rev = strand_freq(insertions_rev, reads_rev)

    aa_position, stop_mismatches, stoploss_nt = stop_specific[0], stop_specific[1], stop_specific[2]

    # defaults, each kind of variant only fills in its own part
    stops = stops_fwd = stops_rev = 0
    freq_stop = freq_stop_fwd = freq_stop_rev = 0
    critical_inserts = []
    critical_dels = []
    homopolymeric = 0

    if indel_type == "stopgain" or indel_type == "stoploss":
        # TODO: The coverage for stopgains assumes a standard stop codon table and therefore codons always starting with T. When implementing alternative stop codon tables we need to follow an approach similar to what is done for stoplosses: store the first aminoacid of the codon explicitly and recover the correct coverage here
        if indel_type == "stopgain":
            counted_bases = ["T"]
        else:
            # stoploss: count the reads showing the recorded nucleotide (either case);
            # any other value leaves the counts at 0
            counted_bases = [b for b in "ACGT" if stoploss_nt == b.lower() or stoploss_nt == b]
        for base in counted_bases:
            stops_fwd = getattr(site, base + "_fwd")[0]
            stops_rev = getattr(site, base + "_rev")[0]
        freq_stop_fwd = strand_freq(stops_fwd, reads_fwd)
        freq_stop_rev = strand_freq(stops_rev, reads_rev)
        stops = stops_fwd + stops_rev
        freq_stop = stops/reads_all
    else:
        # NOTE the region of interest covers the position anyway, so we can re-use the whole region stats
        critical_inserts, critical_dels = check_indels_gene_region(
            analyse_position.indels_gene_reg, position, gap_length, indel_type,
            region_start, region_end, based)
        homopolymeric = check_homopolymeric(variation_info, position, gap_length, indel_type)

    record = {
        'ref_id': ref_id,
        'start_position': position+based,
        'length': gap_length,
        'VARIANT': indel_type,
        'gene_region': gene_region[3],
        'aa_position': aa_position,
        'stop_mismatches': stop_mismatches,
        'stoploss_nt': stoploss_nt,
        'reads_all': reads_all,
        'reads_fwd': reads_fwd,
        'reads_rev': reads_rev,
        'deletions': deletions,
        'freq_del': freq_del,
        'freq_del_fwd': freq_del_fwd,
        'freq_del_rev': freq_del_rev,
        'deletions_fwd': deletions_fwd,
        'deletions_rev': deletions_rev,
        'insertions': insertions,
        'freq_insert': freq_insert,
        'freq_insert_fwd': freq_insert_fwd,
        'freq_insert_rev': freq_insert_rev,
        'insertions_fwd': insertions_fwd,
        'insertions_rev': insertions_rev,
        'stops': stops,
        'freq_stop': freq_stop,
        'freq_stop_fwd': freq_stop_fwd,
        'freq_stop_rev': freq_stop_rev,
        'stops_fwd': stops_fwd,
        'stops_rev': stops_rev,
        'matches_ref': site.matches[0],
        'pos_critical_inserts': critical_inserts,
        'pos_critical_dels': critical_dels,
        'homopolymeric': homopolymeric,
        'ref_base': site.ref[0],
        'cons_id': cons_id,
    }
    return record
# keep a static cache between calls: the key (reference + gene region) of the
# variation table loaded last, and the table itself
analyse_position.ref_id=None
analyse_position.region_start=None
analyse_position.region_end=None
analyse_position.indels_gene_reg=None

def remove_df_duplicates(df_temp):
    """
    Drop the rows repeating an already seen 'start_position'.

    If we have a deletion in orf1b that restores the frame for orf1ab, we end
    up with duplicated stop entries, as the report always defines the region
    as orf1ab. We remove duplicates here.

    The first row of every 'start_position' is kept. Rows are selected through
    pandas' own duplicate detection instead of positional numbers handed to
    `drop` (which expects index labels), so the result is also right for a
    data frame whose index is not the default 0..n-1 range.

    Returns a new data frame; `df_temp` itself is left untouched.
    """
    return df_temp.drop_duplicates(subset="start_position")
    
def main():
    """
    Entry point: build the frameshift indels / stops report and write it as TSV.

    The consensus is aligned to the reference (with the provided chain file,
    or with MAFFT when there is none), the deletions, insertions and stops of
    the alignment are listed and moved to reference coordinates, and every
    frameshifting indel and every stop is analysed against the BAM file. The
    resulting table is de-duplicated, optionally completed with the English
    summary columns, and written to the output file.

    Exits with status 1 when the number of alignments is odd.
    """
    args = parse_args()
    bamfile = args.bamfile	    # e.g.: 'REF_aln_410130_171220eg29_H5.bam'
    reference = args.reference	# e.g.: '../references/NC_045512.2.fasta'
    consensus = args.ref_majority_dels	# e.g.: 'ref_majority_dels.fasta'
    chain = args.chain
    orf1ab_name = args.orf1ab # e.g.: 'cds-YP_009724389.1'
    based = args.based # e.g.: 1

    gene_list = parse_gff(args.genes_gff, featuretype='gene') # e.g.: 'Genes_NC_045512.2.GFF3'
    cds_list = parse_gff(args.genes_gff, featuretype='CDS')

    report_columns = (
        'ref_id','start_position','length','VARIANT','gene_region', 'aa_position', 'stop_mismatches', 'stoploss_nt',
        'reads_all','reads_fwd','reads_rev',
        'deletions','freq_del','freq_del_fwd','freq_del_rev',
        'deletions_fwd','deletions_rev',
        'insertions','freq_insert','freq_insert_fwd','freq_insert_rev',
        'insertions_fwd','insertions_rev',
        'stops','freq_stop','freq_stop_fwd','freq_stop_rev',
        'stops_fwd','stops_rev',
        'matches_ref','pos_critical_inserts','pos_critical_dels',
        'homopolymeric','ref_base','cons_id',
    )
    df = pd.DataFrame(columns=report_columns)

    # use mafft if no chain is provided
    if chain:
        align_seqs = align_with_chain(reference, consensus, chain, keepcase = False)
    else:
        align_seqs = align_with_mafft(reference, consensus)

    if len(align_seqs) == 0:
        print("Warning: no usable alignment", file=sys.stderr)
    elif len(align_seqs) % 2 != 0:
        print(f"Fatal error: there are {len(align_seqs)} alignements (odd), they should be in pairs (even)!", file=sys.stderr)
        sys.exit(1)
    else:
        align_dels = list_all_dels(align_seqs)
        align_inserts = list_all_inserts(align_seqs)
        cds_positions = correct_cds_positions(extract_cds_range(cds_list), align_inserts)
        align_stops = list_all_stops(align_seqs, cds_positions, orf1ab_name)

        # convert coordinate from pair-alignment space to reference-space
        corrected_insert, corrected_dels, corrected_stops = correct_positions(align_dels, align_inserts, align_stops)

        corrected_stops, corrected_dels, corrected_insert = find_indels_in_stops(corrected_stops, corrected_dels, corrected_insert)

        # sort by ref_id, then by position to optimize cache hits in analyse_position
        variants = sorted(corrected_dels + corrected_insert + corrected_stops, key=operator.itemgetter(2, 0))
        for entry in variants:
            ref_id = entry[2]
            position = int(entry[0])
            gap_length = int(entry[1])
            cons_id = entry[3]
            indel_type = entry[4]

            # the stop related fields do not sit at the same place in stop and indel entries
            is_stop = indel_type in ("stopgain", "stoploss")
            if is_stop:
                stop_specific = entry[7:10]
            elif indel_type in ("insertion", "deletion"):
                stop_specific = entry[5:8]
            else:
                stop_specific = None

            # only frameshift indels, i.e. length not dividible by 3; or stops
            if not is_stop and gap_length % 3 == 0:
                continue

            pos_dict = analyse_position(bamfile, reference, ref_id, position, stop_specific, gap_length, indel_type,
                                        gene_list, cons_id, based=based)
            if len(pos_dict) == 0:
                # skip when no information extracted
                continue
            # explicitly transform the dictionary in a pandas dataframe to avoid errors
            row_frame = pd.DataFrame.from_dict(pos_dict, orient='index').T
            df = pd.concat([df, row_frame], ignore_index=True)

    df = remove_df_duplicates(df)
    if args.english == True:
        print("adding english language")
        df = write_english_summary(df)

    df.to_csv(args.outfile, sep='\t') # write to tsv-file

# run the report only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()
