#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

import csv
import logging
import re
import sys

from hgvs.dataproviders.uta import connect
from hgvs.parser import Parser
from hgvs.variantmapper import EasyVariantMapper


# Configure the root logger so records at INFO and above are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# One CIGAR element: a run length followed by a single op code
# (one of '=', 'I', 'D', 'X').
cigarop_re = re.compile(r'(?P<l>\d+)(?P<op>[=IDX])')


def generate_variants(r):
    """Yield an HGVS n. variant string for every non-match op in a row's CIGAR.

    :param r: mapping-like alignment row providing these keys:

        * ``cigar``      -- CIGAR string made of ``=``, ``X``, ``I``, ``D`` ops
        * ``tx_ac``      -- transcript accession used as the variant prefix
        * ``tx_start_i`` -- integer offset added to the position within the
          aligned transcript sequence (treated as 0-based: 1 is added to
          obtain the first affected base)
        * ``tx_aseq``    -- aligned transcript sequence; ``-`` characters are
          stripped before indexing
        * ``alt_aseq``   -- aligned alternate sequence; ``-`` characters are
          stripped before indexing

    :returns: generator of strings, in CIGAR order:

        * ``X`` of length 1  -> ``ac:n.<pos><ref>><alt>``
        * ``X`` of length >1 -> ``ac:n.<start>_<end>delins<alt>``
        * ``D`` of length 1  -> ``ac:n.<pos>del``
        * ``D`` of length >1 -> ``ac:n.<start>_<end>del``
        * ``I``              -> ``ac:n.<left>_<right>ins<alt>``

    ``=`` ops produce nothing; they only advance the sequence offsets.
    """
    ac = r['tx_ac']
    tx_start_i = r['tx_start_i']
    # Gap characters are removed so that the running offsets below index
    # real bases only.
    tseq = r['tx_aseq'].replace('-', '')
    gseq = r['alt_aseq'].replace('-', '')

    ts = 0  # bases of tseq consumed so far
    gs = 0  # bases of gseq consumed so far
    for m in cigarop_re.finditer(r['cigar']):
        op = m.group('op')
        length = int(m.group('l'))

        if op != '=':
            # Offset of the op on the transcript; vs + 1 is the first
            # affected base and vs + length the last one.
            vs = tx_start_i + ts
            if op == 'X':
                gq = gseq[gs:gs + length]
                if length == 1:
                    tq = tseq[ts:ts + 1]
                    v = "{ac}:n.{vs}{tq}>{gq}".format(ac=ac, vs=vs + 1, tq=tq, gq=gq)
                else:
                    v = "{ac}:n.{vs}_{ve}delins{gq}".format(
                        ac=ac, vs=vs + 1, ve=vs + length, gq=gq)
            elif op == 'D':
                if length == 1:
                    v = "{ac}:n.{vs}del".format(ac=ac, vs=vs + 1)
                else:
                    # The end is vs + length so the range covers exactly
                    # `length` bases, matching the delins case above.
                    v = "{ac}:n.{vs}_{ve}del".format(
                        ac=ac, vs=vs + 1, ve=vs + length)
            else:  # op == 'I' (the regex admits no other op)
                # gs counts gap-free bases, so slice the gap-stripped
                # sequence rather than the raw aligned one.
                gq = gseq[gs:gs + length]
                # The insertion sits between the two bases flanking vs.
                v = "{ac}:n.{vs}_{ve}ins{gq}".format(ac=ac, vs=vs, ve=vs + 1, gq=gq)

            yield v

        # '=', 'X' and 'D' consume transcript bases; '=', 'X' and 'I'
        # consume alternate-sequence bases.
        if op in '=XD':
            ts += length
        if op in '=XI':
            gs += length


if __name__ == "__main__":
    # Script entry point: read discrepancy rows from the database, turn each
    # CIGAR mismatch into an HGVS n. variant, project it to c. and g.
    # coordinates, and write one TSV line per variant to stdout.
    #
    # Usage: optionally pass transcript accessions as arguments (for example
    # NM_178040.2 NM_015512.4) to restrict the query; with no arguments every
    # row is processed.
    hdp = connect()
    hp = Parser()
    evm = EasyVariantMapper(hdp)

    fieldnames = 'hgvs_n hgvs_c hgvs_g tx_ac strand exon'.split()
    # str('\t') yields the interpreter's native string type on both Python 2
    # and Python 3; a bytes delimiter is rejected by the Python 3 csv module.
    out = csv.DictWriter(sys.stdout, fieldnames=fieldnames, delimiter=str('\t'), lineterminator='\n')
    out.writeheader()

    tx_list = sys.argv[1:]
    query = [
        "select D.*",
        "from reece.discrep_mv D",
        "join reece.ej_gene_transcript_v GL on D.tx_ac=GL.tx_ac",
        ]
    if tx_list:
        # NOTE(review): accessions are interpolated directly into the SQL
        # text. Acceptable for operator-supplied arguments, but this should
        # become a parameterized query if the input is ever untrusted.
        quoted = ",".join(["'{ac}'".format(ac=ac) for ac in tx_list])
        query.append("where D.tx_ac in ({txs})".format(txs=quoted))
    query.append("order by D.tx_ac,D.ord")

    cur = hdp._execute(" ".join(query))
    for row in cur:
        # Best effort per row: a failure in variant generation, parsing or
        # mapping is reported as a '#'-prefixed line and the run continues
        # with the next row.
        try:
            for var_str in generate_variants(row):
                var_n = hp.parse_hgvs_variant(var_str)
                var_c = evm.n_to_c(var_n)
                var_g = evm.n_to_g(var_n)
                out.writerow({
                    'hgvs_n': str(var_n),
                    'hgvs_c': str(var_c),
                    'hgvs_g': str(var_g),
                    'tx_ac': row['tx_ac'],
                    'strand': row['alt_strand'],
                    # 'ord' is incremented by one for the reported exon number.
                    'exon': row['ord']+1,
                    })
        except Exception as e:
            print("# {row[tx_ac]}, exon {row[ord]}: {e}".format(row=row, e=e))
