#!python

import argparse
import json
import os
import sys
import numpy as np
from console_progressbar import ProgressBar
from Bio import AlignIO

from modelmatcher.models import RateMatrix
from modelmatcher.model_io import read_model

_ALPHABET = 'ARNDCQEGHILKMFPSTWYV'

def guess_format(filename):
    '''
    Guess the MSA format used in filename by inspecting its first line.

    Recognised first lines:
      * starts with '# STOCKHOLM 1.0'   -> 'stockholm'
      * starts with 'CLUSTAL'           -> 'clustal'
      * starts with '>'                 -> 'fasta'
      * starts with '#nexus' (any case) -> 'nexus'
      * exactly two integer tokens      -> 'phylip'

    Raises Exception if the first line matches none of the above, which
    includes the case of an empty file.
    '''
    with open(filename) as h:
        first_line = h.readline()

    if first_line.startswith('# STOCKHOLM 1.0'):
        return 'stockholm'
    # readline() keeps the trailing newline, and Clustal headers usually
    # continue with a program name and version, so only the prefix is tested.
    if first_line.startswith('CLUSTAL'):
        return 'clustal'
    # startswith() is safe on an empty string, unlike first_line[0].
    if first_line.startswith('>'):
        return 'fasta'
    if first_line.lower().startswith('#nexus'):
        return 'nexus'

    # If the first line contains two integers, then it ought to be Phylip
    tokens = first_line.split()
    if len(tokens) == 2:
        try:
            int(tokens[0])
            int(tokens[1])
        except ValueError:
            pass
        else:
            return 'phylip'

    raise Exception(f'Does not recognize the format in {filename}')


def read_alignment(file, input_format):
    '''
    Factory function. Read the alignment with BioPython's support, and
    return an appropriate alignment.

    file          Path to the alignment, or '-' to read from stdin.
    input_format  A format name passed on to AlignIO.read, or 'guess'.
                  For a path, 'guess' is resolved with guess_format().
                  Stdin is not inspected, so 'guess' means FASTA there.

    Returns whatever AlignIO.read returns for the input.
    '''
    if file == '-':
        file = sys.stdin        # Start reading from stdin if "magic filename"
        if input_format == 'guess':
            # 'guess' itself is not a format name for AlignIO; the command
            # line help promises FASTA when reading from stdin.
            input_format = 'fasta'
    elif input_format == 'guess':
        input_format = guess_format(file)
    return AlignIO.read(file, input_format)



def add_to_count_matrix(m, s1, s2):
    '''
    Add the amino acid pair counts of two aligned sequences to m.

    m       Count matrix indexed as m[i][j]; it is updated in place.
    s1, s2  Sequences given as residue indices, with None at positions
            that carry no countable residue.

    The sequences are walked in lockstep. A position is skipped when
    either sequence has None there. Nothing is returned.
    '''
    for i, j in zip(s1, s2):
        # Test against None explicitly: index 0 is a valid residue but falsy.
        if i is not None and j is not None:
            m[i][j] += 1


import itertools as it
def msa_to_count_matrix(msa, verbose=False):
    r'''
    Count aligned amino acid pairs and residue frequencies in an MSA.

    msa      Iterable of records whose sequence is available as `.seq`.
    verbose  If true, report progress through ProgressBar.

    Returns a tuple (N, pi):
      N   Symmetric matrix of pair counts accumulated over all pairs of
          sequences. The counts are NOT normalised here.
      pi  Vector of relative residue frequencies, $\pi$.

    Rows, columns and vector entries follow the order of _ALPHABET.
    Symbols outside that alphabet are translated to None and do not
    contribute to the frequency vector.
    '''
    # Prepare to translate aa sequences to integer lists
    index_of = {letter: i for i, letter in enumerate(_ALPHABET)}
    n_letters = len(_ALPHABET)

    sequences = [[index_of.get(c) for c in record.seq] for record in msa]
    n_seqs = len(sequences)

    if verbose:
        n_pairs = n_seqs * (n_seqs - 1) // 2    # Integer division: this is a count
        pb = ProgressBar(total=n_pairs, prefix='Counting AA pairs. Progress:', suffix='', decimals=0, length=50, fill='+', zfill=' ')

    m_sum = np.zeros((n_letters, n_letters))
    for done, (a, b) in enumerate(it.combinations(sequences, 2), start=1):
        add_to_count_matrix(m_sum, a, b)
        if verbose:
            pb.print_progress_bar(done)

    # Now compute the aa frequency vector
    if verbose:
        pb = ProgressBar(total=n_seqs, prefix='Counting AA freqs. Progress:', suffix='', decimals=0, length=50, fill='+', zfill=' ')
    f_vec = np.zeros(n_letters)
    for done, s in enumerate(sequences, start=1):
        if verbose:
            pb.print_progress_bar(done)
        for aa in s:
            # Indexing a numpy array with None means "new axis", so
            # f_vec[None] += 1 would increment every entry at once.
            if aa is not None:
                f_vec[aa] += 1

    symmetric = m_sum + m_sum.T
    scale = np.sum(f_vec)
    return symmetric, f_vec / scale


def diagonalize_with_model(m, Q):
    '''
    Measure how far the eigenvectors of Q are from diagonalizing m.

    With (L, R) = Q.get_eigenvectors() and f = Q.get_freq(), the product
    L * diag(1/f) * m * R is formed. Had Q diagonalized m, that product
    would be a diagonal matrix, so the deviation returned is the sum of
    the absolute values of its off-diagonal elements.
    '''
    left_vectors, right_vectors = Q.get_eigenvectors()
    freqs = Q.get_freq()

    scaled_left = np.matmul(left_vectors, np.diag(1.0 / freqs))
    candidate = np.matmul(scaled_left, np.matmul(m, right_vectors))

    # Ones everywhere except on the diagonal: multiplying by this mask
    # keeps only the off-diagonal elements.
    mask = np.ones((20, 20))
    np.fill_diagonal(mask, 0)

    return np.absolute(candidate * mask).sum()

def apply_models(m, q=None):
    '''
    Try the eigenvectors of the standard rate matrices against the count matrix m.

    The candidates are the models from RateMatrix.get_all_models(),
    followed by q if a (truthy) q is given. Returns a list of
    (model name, deviation) tuples in candidate order, where the
    deviation comes from diagonalize_with_model().
    '''
    candidates = list(RateMatrix.get_all_models())
    if q:
        candidates.append(q)

    return [(model.get_name(), diagonalize_with_model(m, model))
            for model in candidates]

def main(args):
    '''
    Rank the rate matrix models against the alignment named in args.

    args is the parsed command line. The attributes read are infile,
    format, model, json and verbose. The ranking, smallest deviation
    first, is written to stdout: as one JSON object when args.json is
    set, otherwise as one "name deviation" line per model.

    Exits with an error message when no amino acid pairs could be
    counted, because the count matrix cannot be normalised in that case.
    '''
    if args.verbose:
        print('Reading MSA', file=sys.stderr)
    msa = read_alignment(args.infile, args.format)

    # Only the pair counts are used here; the frequency vector is not.
    N, _ = msa_to_count_matrix(msa, args.verbose)
    n_observations = int(np.sum(N))
    if n_observations == 0:
        # Dividing by zero would give an all-NaN matrix, NaN scores for
        # every model and invalid JSON output.
        sys.exit(f'No amino acid pairs could be counted in {args.infile}')
    F = N / n_observations
    if args.verbose:
        print('Testing models', file=sys.stderr)

    extra_model = None
    if args.model:
        with open(args.model) as h:
            extra_model = read_model(h, args.model)
    diffs = apply_models(F, extra_model)
    ranking = sorted(diffs, key=lambda d: d[1])

    if args.json:
        json_data = {"n_observations": n_observations,
                     "infile": args.infile,
                     "n_seqs": len(msa),
                     "model_ranking": ranking,
                     }
        json.dump(json_data, sys.stdout)
        print()
    else:
        for name, d in ranking:
            print(f'{name:12} {d:>8.3f}')

if __name__ == '__main__':
    # Script entry point: parse the command line and hand over to main().
    # Ctrl-C anywhere in here ends the program quietly via sys.exit().
    try:
        # Numpy print options, set once for convenience
        np.set_printoptions(precision=3, suppress=True, linewidth=200)

        parser = argparse.ArgumentParser(
            description='Suggest model based on an MSA without using likelihood calculations.')
        parser.add_argument(
            'infile',
            help='A multi-sequence alignment.')
        parser.add_argument(
            '-f', '--format',
            choices=['guess', 'fasta', 'clustal', 'nexus', 'phylip', 'stockholm'],
            default='guess',
            help="Specify what sequence type to assume. Be specific if the file is not recognized automatically. When reading from stdin, the format is always guessed to be FASTA. Default: %(default)s")
        parser.add_argument(
            '-m', '--model',
            metavar='filename',
            help='Add the model given in the file to the comparisons.')
        parser.add_argument(
            '-j', '--json',
            action='store_true',
            help='Output the information in JSON format, for easier later parsing.')
        parser.add_argument(
            '--verbose',
            action='store_true',
            help='Output progress information')

        args = parser.parse_args()
        main(args)
    except KeyboardInterrupt:
        sys.exit()
