#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

"""compare hgvs and mutalyzer outputs

eg$ ./bin/hgvs-mutalyzer-compare -vv -eid -H acmg-g-current-hgvs.tsv -M acmg-g-current-mzr.tsv -B bermuda.tsv

The input files must be formatted as TSV, following Mutalyzer's positionConverter format:

Input Variant	Errors	Chromosomal Variant	Coding Variant(s)
NC_000001.10:g.100586496G>C		NC_000001.10:g.100586496G>C	NM_194292.1:c.483+438C>G	XM_005270551.1:c.-19+438C>G

For each matching input (g.) variant, compare the set of c. variants
computed by each tool.  The cases are:

empty -- both tools return empty sets
complete match -- both tools return identical sets
intersection match -- all intersection keys match
subset match -- some (not all) intersection keys match
failed -- no intersection keys match

"""


import argparse
import codecs
import collections
import csv
import gzip
import logging
import os
import re
import sys


def parse_args(argv):
    """Parse the command-line arguments of the comparison script.

    ``argv`` is the list of argument strings, without the program name.

    Returns the ``argparse.Namespace`` holding ``bermuda_filename``,
    ``hgvs_filename`` and ``mutalyzer_filename`` (all required), the boolean
    switches ``match_eq_delins_on_location``, ``rewrite_identity_substitutions``
    and ``strip_delins`` (default False), and the ``verbose`` count (default 0).
    """
    # NOTE(review): the descriptive module string follows the __future__
    # import, so it is presumably not picked up as __doc__ (description would
    # then be None) -- confirm and move the string if --help should show it.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (long flag, short flag, add_argument keywords); the registration order
    # is the order in which options are listed by --help.
    option_table = (
        ('--bermuda-filename', '-B', dict(required=True)),
        ('--hgvs-filename', '-H', dict(required=True)),
        ('--match-eq-delins-on-location', '-e', dict(default=False, action='store_true')),
        ('--mutalyzer-filename', '-M', dict(required=True)),
        ('--rewrite-identity-substitutions', '-i', dict(default=False, action='store_true')),
        ('--strip-delins', '-d', dict(default=False, action='store_true')),
        ('--verbose', '-v', dict(default=0, action='count')),
    )
    for long_flag, short_flag, settings in option_table:
        parser.add_argument(long_flag, short_flag, **settings)

    return parser.parse_args(argv)

def gzopen(fn,mode='rb'):
    """Open ``fn`` with ``mode``; names ending in '.gz' go through gzip.

    Returns the file object from ``gzip.open`` for '.gz' names and from the
    built-in ``open`` for everything else.
    """
    if fn.endswith('.gz'):
        return gzip.open(fn, mode)
    return open(fn, mode)

def read_pc_format(fn):
    """Yield one dict per usable row of a positionConverter-format TSV file.

    ``fn`` is opened through ``gzopen`` in 'r' mode.  The first line must be
    exactly the positionConverter header; otherwise an AssertionError is
    raised when iteration starts.

    Each yielded dict has the keys:
      'input'  -- column 1, the input variant
      'errors' -- column 2 (always empty for yielded rows)
      'g_var'  -- column 3, the chromosomal variant
      'c_vars' -- list of the remaining columns that start with 'NM_'

    Lines starting with '#' and rows whose errors column is non-empty are
    skipped.
    """

    # The with-block closes the file once the generator is exhausted or
    # closed; previously the handle was left open.
    with gzopen(fn,'r') as fh:
        hdr = fh.readline()
        assert hdr == 'Input Variant\tErrors\tChromosomal Variant\tCoding Variant(s)\n', \
            "unexpected header line in {}: {!r}".format(fn, hdr)
        for line in fh:
            if line.startswith("#"):
                continue
            # strip only the line terminator so empty trailing columns survive
            vals = line.strip('\n\r').split('\t')
            d = {'input': vals[0],
                 'errors': vals[1],
                 'g_var': vals[2],
                 'c_vars': [c_var for c_var in vals[3:] if c_var.startswith('NM_')]}
            if d['errors']:
                continue
            yield d
    
def read_pc(fn):
    """Load a positionConverter file into a dict keyed by input variant.

    Every record produced by ``read_pc_format(fn)`` is stored under its
    'input' value; when an input occurs more than once the last record wins.
    """
    records = {}
    for rec in read_pc_format(fn):
        records[rec['input']] = rec
    return records

def strip_delins(recs):
    """Drop the deleted sequence/length from delins variants, in place.

    ``recs`` maps keys to record dicts that carry a 'c_vars' list of variant
    strings.  In every variant, 'del' followed by one or more digits or
    A/C/G/T characters and then 'ins' is rewritten to plain 'delins'.
    Each record's 'c_vars' list is replaced; nothing is returned.
    """
    # Raw string: '\d' in a plain literal is an invalid escape on Python 3.
    deleted_seq_re = re.compile(r'del[\dACGT]+ins')
    # .values() works on Python 2 and 3 (itervalues() is Python 2 only).
    for rec in recs.values():
        rec['c_vars'] = [deleted_seq_re.sub('delins', cv) for cv in rec['c_vars']]
def rewrite_identity_substitutions(recs):
    """Rewrite substitutions of a base by itself (e.g. 'A>A') as '=', in place.

    ``recs`` maps keys to record dicts that carry a 'c_vars' list of variant
    strings.  Every occurrence of one of A/C/G/T, then '>', then the same
    letter is replaced by '='.  Each record's 'c_vars' list is replaced;
    nothing is returned.
    """
    identity_re = re.compile(r'([ACGT])>\1')
    # .values() works on Python 2 and 3 (itervalues() is Python 2 only).
    for rec in recs.values():
        rec['c_vars'] = [identity_re.sub('=', cv) for cv in rec['c_vars']]

def vars_match(opts,hv,mv):
    """Tell whether hgvs variant ``hv`` and Mutalyzer variant ``mv`` agree.

    Identical strings always match.  Otherwise, when
    ``opts.match_eq_delins_on_location`` is set, the two are compared after
    removing trailing '=' characters from ``hv`` and everything from
    'delins' onward (when at least one character follows it) from ``mv``.
    Returns a bool.
    """
    identical = (hv == mv)
    # opts is consulted only when the strings differ
    if identical or not opts.match_eq_delins_on_location:
        return identical
    h_location = hv.rstrip('=')
    m_location = re.sub('delins.+','',mv)
    return h_location == m_location


if __name__ == '__main__':
    # Script entry point: load the hgvs and Mutalyzer positionConverter
    # files, classify each shared genomic variant by how well the two sets
    # of c. variants agree, print one report line per variant and finally a
    # summary table.
    logging.basicConfig(level=logging.WARN)
    logger = logging.getLogger(__name__)

    opts = parse_args(sys.argv[1:])
    if opts.verbose:
        logger.setLevel(logging.INFO if opts.verbose == 1 else logging.DEBUG)

    # bermuda: transcript accession ('tx_ac') -> row; the file is closed as
    # soon as it has been loaded.
    with gzopen(opts.bermuda_filename,'r') as bermuda_fh:
        bermuda = {r['tx_ac']:r for r in csv.DictReader(bermuda_fh,delimiter=str('\t'))}
    # "dirty" accessions are those whose s_status contains 'D' or 'I'
    # (presumably deletions/insertions in the alignment -- confirm).
    # .items() works on Python 2 and 3 (iteritems() is Python 2 only).
    dirty_acs = set(k for k,r in bermuda.items() if ('D' in r['s_status'] or 'I' in r['s_status']))

    h_recs = read_pc(opts.hgvs_filename)
    logger.info("read {n} genomic variants from {opts.hgvs_filename}".format(n=len(h_recs),opts=opts))

    m_recs = read_pc(opts.mutalyzer_filename)
    logger.info("read {n} genomic variants from {opts.mutalyzer_filename}".format(n=len(m_recs),opts=opts))

    # optional normalizations, applied to both sides before comparing
    if opts.strip_delins:
        strip_delins(h_recs)
        strip_delins(m_recs)
    if opts.rewrite_identity_substitutions:
        rewrite_identity_substitutions(h_recs)
        rewrite_identity_substitutions(m_recs)

    g_hgvs_keys = set(h_recs) & set(m_recs)
    logger.info("{n} genomic variants in common".format(n=len(g_hgvs_keys)))

    match_bins = collections.defaultdict(set)   # category -> genomic variants
    match_counts = collections.Counter()        # category -> # shared transcripts
    hmk_tot = 0
    h_missing = set()
    m_missing = set()

    for g_hgvs in sorted(g_hgvs_keys):
        # g_hgvs comes from the key intersection, so both lookups succeed
        h_rec = h_recs[g_hgvs]
        m_rec = m_recs[g_hgvs]

        h_cs_d = { v.partition(':')[0]:str(v) for v in h_rec['c_vars'] }
        m_cs_d = { v.partition(':')[0]:str(v) for v in m_rec['c_vars'] }

        # "keys" are transcript accessions for most of the following code
        hk = set(h_cs_d)             # hgvs keys
        mk = set(m_cs_d)             # mzr keys
        hok = hk - mk                # hgvs only keys
        mok = mk - hk                # mzr only keys
        hmk = hk & mk                # intersection keys

        eqk = set(k for k in hmk if vars_match(opts,h_cs_d[k],m_cs_d[k]))  # match keys
        nek = hmk - eqk                                                   # mismatch keys
        nek_dirty = nek & dirty_acs
        nek_clean = nek - dirty_acs

        hmk_tot += len(hmk)

        if not hmk:
            # no transcript accessions in common
            if not hk and not mk:
                match = 'two-sided empty'
            elif not hk:
                match = 'one-sided empty (hgvs)'
                h_missing |= mk
            elif not mk:
                match = 'one-sided empty (mutalyzer)'
                m_missing |= hk
            else:
                # Both tools reported transcripts, but none in common.  This
                # case used to fall through with no category and trip an
                # assertion.
                match = 'no common transcripts'
        elif not nek:
            # all intersection keys eq.  A separate 'perfect match' category
            # (len(eqk) == len(mk) == len(hk)) is deliberately not reported.
            match = 'intersection match'
        elif not nek_dirty:
            match = 'all-clean mismatch'
        elif not nek_clean:
            match = 'all-dirty mismatch'
        else:
            match = 'semi-dirty mismatch'

        match_bins[match].add(g_hgvs)
        match_counts[match] += len(hmk)

        msg = "{g_hgvs}: {match}; keys: {ho}/{eqk}/{hmk}/{mo}".format(
            g_hgvs=g_hgvs, match=match,
            ho=len(hok), eqk=len(eqk), hmk=len(hmk), mo=len(mok))
        # keys are sorted so the report is reproducible between runs
        if eqk:
            matches = [ (h_cs_d[k],m_cs_d[k]) for k in sorted(eqk) ]
            msg += "\n  matches:" + '; '.join(str(mm) for mm in matches)
        if nek_clean:
            mismatches = [ (h_cs_d[k],m_cs_d[k]) for k in sorted(nek_clean) ]
            msg += "\n  clean mismatches:" + '; '.join(str(mm) for mm in mismatches)
        if nek_dirty:
            mismatches = [ (h_cs_d[k],m_cs_d[k]) for k in sorted(nek_dirty) ]
            msg += "\n  dirty mismatches:" + '; '.join(str(mm) for mm in mismatches)
        print(msg)

    # summary rows, in display order
    match_keys = [
        'two-sided empty',
        'one-sided empty (hgvs)',
        'one-sided empty (mutalyzer)',
        'no common transcripts',
        'perfect match',
        'intersection match',
        'all-clean mismatch',
        'all-dirty mismatch',
        'semi-dirty mismatch',
        ]

    for match in match_keys:
        print("{n:6d} {mc:6d} {match}".format(n=len(match_bins[match]),mc=match_counts[match],match=match))
    print("{n:6d} genomic variants compared".format(n=len(g_hgvs_keys)))
    print("{n:6d} transcript variants compared".format(n=hmk_tot))

    print("{n:6d} transcripts not in hgvs ({acs})".format(n=len(h_missing),acs=','.join(sorted(h_missing))))
    print("{n:6d} transcripts not in mutalyzer ({acs})".format(n=len(m_missing),acs=','.join(sorted(m_missing))))
