#!python

import argparse
import json
import os
import random
import sys
import numpy as np
from console_progressbar import ProgressBar
from Bio import AlignIO

from modelmatcher.models import RateMatrix
from modelmatcher.model_io import read_model
from modelmatcher.version import __version__

def guess_format(filename):
    '''
    Guess the MSA format used in filename by inspecting its first line.

    Returns one of the strings 'stockholm', 'clustal', 'fasta', 'nexus',
    or 'phylip'.

    Raises IOError if the file is empty, and SyntaxError if the first line
    does not match any of the recognised formats.
    '''
    with open(filename) as h:
        first_line = h.readline()

    if len(first_line) == 0:
        raise IOError(f"Error: file '{filename}' is empty. Aborting.")
    if first_line.startswith('# STOCKHOLM 1.0'):
        return 'stockholm'
    # readline() keeps the trailing newline (and the header usually continues
    # after the keyword), so compare the prefix rather than the whole line.
    if first_line.startswith('CLUSTAL'):
        return 'clustal'
    if first_line.startswith('>'):
        return 'fasta'
    if first_line.lower().startswith('#nexus'):
        return 'nexus'

    # If the first line contains two integers, then it ought to be Phylip
    tokens = first_line.split()
    if len(tokens) == 2:
        try:
            int(tokens[0])
            int(tokens[1])
        except ValueError:
            pass                # Not integers: fall through to the error below
        else:
            return 'phylip'
    raise SyntaxError(f'Does not recognize the format in {filename}')


def read_alignment(file, input_format):
    '''
    Factory function. Read the alignment with BioPython's support, and
    return an appropriate alignment.

    file:         Path to the alignment file, or '-' to read from stdin.
    input_format: A format name passed on to AlignIO.read, or 'guess'.
                  With 'guess', the format is determined by guess_format
                  for regular files. Stdin cannot be inspected in advance
                  by guess_format (which opens a named file), so FASTA is
                  assumed there, as stated in the help text of --format.
    '''
    if file == '-':
        file = sys.stdin        # Start reading from stdin if "magic filename"
        if input_format == 'guess':
            # 'guess' is not a format name AlignIO knows about
            input_format = 'fasta'
    elif input_format == 'guess':
        input_format = guess_format(file)
    return AlignIO.read(file, input_format)



def add_to_count_matrix(m, s1, s2):
    '''
    In: Initialised matrix m (indexable as m[i][j]) and two "integer
    sequences" s1 and s2, whose elements are either an amino acid index
    or None.

    The function adds amino acid pair counts to m, in place: for every
    position where both sequences have an amino acid, m[i][j] is increased
    by one. Positions where either element is None are skipped.
    '''
    for i, j in zip(s1, s2):
        # Compare against None explicitly: index 0 is a valid amino acid
        # and must not be mistaken for a missing value.
        if i is not None and j is not None:
            m[i][j] += 1


def convert_sequences(msa):
    '''
    Translate every sequence in the msa from letters to a list of integers,
    using the position of each letter in 'ARNDCQEGHILKMFPSTWYV' (A=0, R=1,
    and so on).

    Returns a list with one "integer sequence" per record in the msa.

    Any character not found in that alphabet (such as an indel symbol)
    becomes None in the converted sequence.
    '''
    amino_acids = 'ARNDCQEGHILKMFPSTWYV'
    index_of = {letter: position for position, letter in enumerate(amino_acids)}

    return [[index_of.get(symbol) for symbol in record.seq] for record in msa]


def compute_aa_frequencies(slist):
    '''
    Take the output of convert_sequences and count the number of each
    amino acid, and return a 20-element NumPy vector of amino acid
    frequencies (counts divided by the total number of amino acids).

    Raises ValueError if slist does not contain a single amino acid, since
    the frequencies are undefined in that case.
    '''
    counts = np.zeros(20)
    for s in slist:
        for aa in s:
            if aa is not None:          # Indels create None values, so must avoid those
                counts[aa] += 1

    sum_of_amino_acids = counts.sum()
    if sum_of_amino_acids == 0:
        # Dividing would only yield a vector of NaN
        raise ValueError('Error: no amino acids were found in the alignment, so amino acid frequencies cannot be computed.')
    return counts / sum_of_amino_acids


import itertools as it

def msa_to_count_matrix(msa, sample_size=0, verbose=False):
    r'''
    Return a symmetric count matrix and a frequency vector, $\pi$, based
    on the input MSA.

    If sample_size is 0, then every sequence pair will be considered,
    otherwise that many pairs will be drawn randomly (uniformly, without
    replacement) for the count matrix. If more pairs are requested than
    there are sequence pairs, all pairs are used.

    With verbose set, progress information is written to stderr.

    Returns a triple:
      * the 20x20 matrix of amino acid pair counts, symmetrised by adding
        its transpose (the counts are not normalised here),
      * the observed amino acid frequencies (from all sequences),
      * the actual sample size, i.e., the number of sequence pairs used.
    '''

    if verbose:
        print('Scanning sequences', file=sys.stderr)
    sequences = convert_sequences(msa)
    frequencies = compute_aa_frequencies(sequences)

    n_seqs = len(sequences)
    n_pairs = n_seqs * (n_seqs - 1) // 2      # Number of unordered sequence pairs
    if sample_size > 0:
        if sample_size > n_pairs:
            if verbose:
                print(f'{sample_size} samples requested, but can only provide {n_pairs} because of the number of sequences.', file=sys.stderr)
            sample_size = n_pairs
        # random.sample needs a sequence to draw from, so the pairs are listed here
        seq_pairs = random.sample(list(it.combinations(sequences, 2)), sample_size)
    else:
        sample_size = n_pairs
        seq_pairs = it.combinations(sequences, 2)   # No sampling: iterate lazily

    if verbose:
        pb = ProgressBar(total=sample_size, prefix='Counting AA pairs:', suffix='', decimals=0, length=50, fill='+', zfill=' ')

    m_sum = np.zeros((20, 20))
    for i, (a, b) in enumerate(seq_pairs, start=1):
        add_to_count_matrix(m_sum, a, b)
        if verbose:
            progress_bar_string = pb.generate_pbar(i)
            print('\r'+progress_bar_string, end='', file=sys.stderr)
    if verbose:
        # End the progress bar line on stderr, where the bar itself was written,
        # so that stdout only carries the actual results.
        print(file=sys.stderr)

    symmetric = m_sum + m_sum.T
    return symmetric, frequencies, sample_size


def diagonalize_with_model(m, Q, freq_vec=None):
    '''
    Try to diagonalize m using the eigenvectors of Q and a vector of amino
    acid frequencies. Return the sum of the absolute values of the
    off-diagonal elements of what was supposed to be the diagonal matrix.

    m:        NumPy matrix with normalised counts of amino acid replacements
    Q:        RateMatrix object for a sequence evolution model (like JTT, WAG, etc).
              Must provide get_eigenvectors() and get_freq().
    freq_vec: NumPy vector with the observed amino acid frequencies in the input data.
              If given as a NumPy array, it is used in the computation. Any other
              value (e.g., None) means that the frequencies of the model,
              Q.get_freq(), are used instead.
    '''
    L_ev, R_ev = Q.get_eigenvectors()

    f = freq_vec if isinstance(freq_vec, np.ndarray) else Q.get_freq()

    right_product = np.matmul(m, R_ev)
    left = np.matmul(L_ev, np.diag(1.0/f))
    prod = np.matmul(left, right_product)

    # Mask with ones everywhere except on the diagonal
    off_diagonal_elems = np.ones(prod.shape)
    np.fill_diagonal(off_diagonal_elems, 0)

    # Measure the deviation as sum of off-diagonal elements
    return np.sum(np.absolute(prod * off_diagonal_elems))


def apply_models(m, freq_vec=None, q=None):
    '''
    Try eigenvectors of standard rate matrices against the count matrix m

    m:        NumPy matrix with normalised counts of amino acid replacements
    freq_vec: NumPy vector with the observed amino acid frequencies in the input data.
              If it is not a NumPy array (e.g., None), then only the model frequencies
              are attempted, otherwise both attempts are made.
    q:        Additional model to apply over the built-in models. It is handled
              exactly like the built-in ones, so it must offer the same methods
              (get_name() is called here). Ignored if None.

    Returns a list of triples: (model name, possible adaption (like +F), score)

    A model that cannot be applied is reported with a warning on stderr and
    left out of the result.
    '''
    matrices = list(RateMatrix.get_all_models())
    if q is not None:
        matrices.append(q)

    try_plus_F = isinstance(freq_vec, np.ndarray)
    diffs = []

    for Q in matrices:
        name = Q.get_name()
        try:
            diff = diagonalize_with_model(m, Q)
            diffs.append((name, '', diff)) # Triple: model name, model adaption (no adaption here), and the modelmatcher score
        except Exception:
            # Best effort: skip this model but keep testing the others.
            # Warnings go to stderr since stdout carries the results.
            print(f"Warning: could not apply the '{name}' model because of numerical problems.", file=sys.stderr)

        if try_plus_F:
            try:
                diff = diagonalize_with_model(m, Q, freq_vec)
                diffs.append((name, '+F', diff)) # Triple: model name, model adaption (+F), and the score
            except Exception:
                print(f"Warning: could not apply the '{name}+F' model because of numerical problems.", file=sys.stderr)
    return diffs


def verify_feasible_aa_freqs(aa_freqs):
    '''
    Check that no element of aa_freqs is zero.

    Returns aa_freqs itself if all its elements are non-zero. Otherwise a
    warning is written to stderr and None is returned.
    '''
    if all(aa_freqs):
        return aa_freqs

    print('Warning: There are some amino acids that have not been observed in the input MSA,', file=sys.stderr)
    print('         so +F models cannot be assessed with modelmatcher.', file=sys.stderr)
    return None


def main(args):
    '''
    Run modelmatcher as directed by args, the namespace of parsed
    command-line options: read the MSA, count amino acid pairs, score
    the models, and write the ranking to stdout in the requested format.

    Errors are reported on stderr and end the program with these exit codes:
      1: the alignment file was not found
      2: SyntaxError while reading the alignment
      3: ValueError while reading the alignment
      4: other IOError while reading the alignment
      5: no usable count matrix could be computed from the alignment
    '''
    if args.list_models:
        for m in RateMatrix.get_all_models():
            print(m.get_name())
        sys.exit()
    if args.verbose:
        print('Reading MSA', file=sys.stderr)
    try:
        msa = read_alignment(args.infile, args.format)
    except FileNotFoundError:
        # Must be tested before IOError: FileNotFoundError is a subclass of it
        print(f"Error: No alignment file '{args.infile}' was found. Aborting.", file=sys.stderr)
        sys.exit(1)
    except IOError as e:
        print(e, file=sys.stderr)
        sys.exit(4)
    except ValueError as e:
        print('Error:', e, file=sys.stderr)
        sys.exit(3)
    except SyntaxError as e:
        print('Error:', e, file=sys.stderr)
        sys.exit(2)

    try:
        N, amino_acid_freqs, sample_size = msa_to_count_matrix(msa, args.sample_size, args.verbose)
    except ValueError as e:
        print(e, file=sys.stderr)
        print(f'Error occurred in {args.infile}.', file=sys.stderr)
        sys.exit(5)             # Nothing to analyse without a count matrix

    n_observations = int(np.sum(N))
    if n_observations == 0:
        # Normalising would divide by zero and every score would be NaN
        print(f"Error: no amino acid pairs could be counted in '{args.infile}'. Aborting.", file=sys.stderr)
        sys.exit(5)
    F = N / n_observations
    if args.verbose:
        print('Testing models', file=sys.stderr)

    if args.no_F_testing:
        amino_acid_freqs = None
    else:
        amino_acid_freqs = verify_feasible_aa_freqs(amino_acid_freqs)

    extra_model = None
    if args.model:
        with open(args.model) as h:
            extra_model = read_model(h, args.model)
    diffs = apply_models(F, amino_acid_freqs, extra_model)

    sorted_analysis = sorted(diffs, key=lambda d: d[2])   # Lowest score first
    if args.output_format == 'json':
        output_json(sorted_analysis, n_observations, args.infile, len(msa), args.sample_size, sample_size)
        return

    # The remaining output functions only need the sorted analysis
    writers = {
        'tabular': output_tabular,
        'raxml': output_raxml_choice,
        'iqtree': output_iqtree_choice,
        'phyml': output_phyml_choice,
        'mrbayes': output_mrbayes_choice,
    }
    if args.output_format not in writers:
        raise Exception('Bug: given output format is not implemented')
    writers[args.output_format](sorted_analysis)


def output_json(sorted_analysis, n_observations, filename, n_seqs, requested_sample_size, actual_sample_size):
    '''
    Print the analysis, together with some meta-data, as one line of JSON
    on stdout.

    The "subsampling" entry is false unless a positive sample size was
    requested, in which case it holds the requested and the actual sample
    sizes.
    '''
    if requested_sample_size > 0:
        subsampling = {
            "requested_sample_size": requested_sample_size,
            "actual_sample_size": actual_sample_size
            }
    else:
        subsampling = False

    # All models, with the adaption (like +F) appended to the name
    model_ranking = [(entry[0] + entry[1], entry[2]) for entry in sorted_analysis]
    # Without the +F models
    simple_model_ranking = [(entry[0], entry[2]) for entry in sorted_analysis if entry[1] == '']

    report = {
        "n_observations": n_observations,
        "infile": filename,
        "n_seqs": n_seqs,
        "subsampling": subsampling,
        "model_ranking": model_ranking,
        "simple_model_ranking": simple_model_ranking,
    }
    json.dump(report, sys.stdout)
    print()



def output_tabular(sorted_analysis):
    '''
    Print a plain two-column table to stdout, one line per entry and in
    the order given: the model name including its adaption (e.g., "LG+F")
    in the left column, and the modelmatcher score, with three decimals,
    in the right column.
    '''
    for model, suffix, score in sorted_analysis:
        label = model + suffix
        print(f'{label:12} {score:>8.3f}')


def output_raxml_choice(sorted_analysis):
    '''
    Write the name of the top-ranked model (the first entry, i.e., the one
    with the lowest score) to stdout.
    '''
    top_ranked = sorted_analysis[0]
    # The printed name is the model name followed by its adaption ('' or '+F')
    print(top_ranked[0] + top_ranked[1])


def output_iqtree_choice(sorted_analysis):
    '''
    Write the best-ranked model that has an IQTREE name to stdout, using
    the IQTREE spelling of the name followed by the adaption (if any).

    Raises an Exception if none of the entries has an IQTREE name.
    '''
    # Translation from modelmatcher model names to IQTREE model names.
    # Models not implemented in IQTREE, and therefore left out: mtDeu, mtPro
    iqtree_names = {
        'BLOSUM62': 'Blosum62',
        'VT': 'VT',
        'WAG': 'WAG',
        'JTT-DCMUT': 'JTTDCMut',
        'JTT': 'JTT',
        'LG': 'LG',
        'FLU': 'FLU',
        'DCMUT': 'DCMut',
        'Dayhoff': 'Dayhoff',
        'HIVb': 'HIVb',
        'cpREV': 'cpREV',
        'rtREV': 'rtREV',
        'HIVw': 'HIVw',
        'MtZoa': 'mtZOA',
        'mtArt': 'mtART',
        'MtMAM': 'mtMAM',
        'MtREV': 'mtREV',
        'mtInv': 'mtInv',
        'mtMet': 'mtMet',
        'mtVer': 'mtVer',
        'PMB': 'PMB',
    }

    # First entry, in ranking order, that IQTREE knows about
    supported = (entry for entry in sorted_analysis if entry[0] in iqtree_names)
    chosen = next(supported, None)
    if chosen is None:
        raise Exception('Bug: we should not end up here (output_iqtree_choice)')

    print(iqtree_names[chosen[0]] + chosen[1])

def output_phyml_choice(sorted_analysis):
    '''
    Write a model choice intended for direct use with PhyML to stdout:
       phyml -i infile.phy -d aa -m $(modelmatcher -of phyml infile.phy)

    Only entries without adaption are considered, because PhyML does not
    support the +F notation (a separate option "-f e" is needed for "+F").
    Of those, the best-ranked entry whose name is among the models
    implemented in PhyML is printed.
    '''
    # Models implemented in PhyML, down-cased, so that names are compared
    # without regard to case.
    phyml_models = ('lg', 'wag', 'jtt', 'mtrev', 'dayhoff', 'dcmut', 'rtrev', 'cprev', 'vt', 'ab', 'blosum62', 'mtmam', 'mtart', 'hivw', 'hivb')
    candidates = [entry for entry in sorted_analysis
                  if entry[1] == '' and entry[0].lower() in phyml_models]
    best_name = candidates[0][0]
    print(best_name)


def output_mrbayes_choice(sorted_analysis):
    '''
    Print the best model for MrBayes to stdout, using the model name as
    spelled by MrBayes.

    Only entries without adaption (no +F) whose name has a MrBayes
    counterpart are considered, and the best-ranked of those is chosen.
    Model names are compared without regard to case.

    Raises IndexError if no entry qualifies.
    '''
    # MrBayes models
    # Note that "jones" means "JTT". We need to map the modelmatcher names
    # (down-cased here) to the names used by MrBayes.
    mrbayes_models = {
        'jtt': 'Jones',
        'dayhoff': 'Dayhoff',
        'mtrev': 'Mtrev',
        'mtmam': 'Mtmam',
        'wag': 'Wag',
        'rtrev': 'Rtrev',
        'cprev': 'Cprev',
        'vt': 'Vt',
        'blosum62': 'Blosum',
        'lg': 'LG',
    }
    restriction = [triple for triple in sorted_analysis
                   if triple[1] == '' and triple[0].lower() in mrbayes_models]
    chosen_model = restriction[0]
    # Print the MrBayes spelling, not the modelmatcher name
    print(mrbayes_models[chosen_model[0].lower()])



if __name__ == '__main__':
    try:
        # For convenience, set some print options for numpy
        np.set_printoptions(precision=3, suppress=True, linewidth=200)

        description_str = '''
Suggest model based on an MSA without using likelihood calculations.

Example usage:
 * Get a simple table of the possible models and their modelmatcher score (lower is better):
      modelmatcher infile.fa

 * Try the builtin models as well as another model defined in a file (PAML format):
       modelmatcher -m my_model.txt infile.fa

 * Get a single choice suitable for PhyML:
       modelmatcher -of phyml infile.phy
   Note that modelmatcher is currently not supporting the "-f e" (estimated amino acid frequencies,
   i.e., +F) option in PhyML.

 * Get a single choice suitable for IQTREE:
       modelmatcher -of iqtree infile.fa

 * Run IQTREE with the best model choice from modelmatcher:
       iqtree -safe -s infile.fa -m $(modelmatcher -of iqtree infile.fa)

'''
        # The raw formatter keeps the line breaks of the description as written above
        ap = argparse.ArgumentParser(description=description_str, formatter_class=argparse.RawDescriptionHelpFormatter)
        ap.add_argument('infile',
                        help='A multi-sequence alignment.')
        ap.add_argument('-f', '--format',
                        choices=['guess', 'fasta', 'clustal', 'nexus', 'phylip', 'stockholm'],
                        default='guess',
                        help="Specify what sequence type to assume. Be specific if the file is not recognized automatically. When reading from stdin, the format is always guessed to be FASTA. Default: %(default)s")
        ap.add_argument('-m', '--model', metavar='filename',
                        help='Add the model given in the file to the comparisons.')
        ap.add_argument('-nf', '--no-F-testing', action='store_true',
                        help='Do not try +F models, i.e., do not test with amino acid frequencies estimated from the MSA.')
        ap.add_argument('-s', '--sample-size', type=int, default=0, metavar='int',
                        help='For alignments with many sequences, decide on an upper bound of sequence pairs to use from the MSA. The computational complexity grows quadratically in the number of sequences, so a choice of 5000 bounds the growth for MSAs with more than 100 sequences.')
        ap.add_argument('-of', '--output_format', default='tabular',
                        choices=['tabular', 'json', 'iqtree', 'raxml', 'phyml', 'mrbayes'],
                        help='Choose output format. Tabular format is default. JSON is for convenient later parsing, with some additional meta-data added. For one-line output convenient for immediate use by inference tools, consider raxml and similar choices. Note that the IQTREE, PhyML, and MrBayes options are restricted to their implemented models. Although PhyML supports the +F models (using the "-f e" option), this is not reflected in the output from "modelmatcher -of phyml ..." at this time.')
        ap.add_argument('--list-models', action='store_true',
                        help='Output a list of models implemented in modelmatcher, then exit.')
        ap.add_argument('--verbose', action='store_true',
                        help='Output progress information')
        ap.add_argument('--version', action='version', version="%(prog)s " + __version__)
        args = ap.parse_args()
        main(args)
    except KeyboardInterrupt:
        # Leave quietly (no traceback) when interrupted, but with a non-zero
        # status (128 + SIGINT), so that a calling script does not mistake
        # the interrupted run for a successful one.
        sys.exit(130)
