#!python

import ConfigParser
import argparse
import os
import re
import sys
import textwrap
import traceback
import warnings
import pandas
from shutil import copyfile
from collections import OrderedDict
import framed.community.smetana as smetana
import framed
from framed.experimental.medium import minimal_medium
from framed.solvers.solver import Status



# Text labels for the numeric solver status codes. The pipeline looks up
# `sol.status` here when filling the "mmedia_status" / "fba_status" columns
# and when logging the FBA result.
status_codes = dict([
    (0, "Unknown"),
    (1, "Optimal"),
    (-1, "Suboptimal"),
    (-4, "Infeasible_or_Unbounded"),
    (-3, "Infeasible"),
    (-2, "Unbounded"),
])


def row_string(row):
    """Render the values of `row` as one tab-separated, newline-terminated line.

    `row` must provide `itervalues()` (e.g. a Python 2 dict/OrderedDict).
    `None` values become empty cells; every other value is passed through
    `str()`.
    """
    cells = []
    for value in row.itervalues():
        if value is None:
            cells.append("")
        else:
            cells.append(str(value))
    return "\t".join(cells) + "\n"


class Logger(object):
    """In-memory log buffer that is written out in one piece by `flush()`."""

    def __init__(self):
        # Text collected since the last flush.
        self._log = ""

    def log(self, *args, **kwargs):
        """Append a message to the buffer.

        The first positional argument is the message. If more positional
        arguments follow, the message is treated as a `str.format` template
        and they are its arguments; a single argument is stored verbatim.
        The keyword argument `newline` (default True) controls whether a
        line break is appended. Any other keyword argument is ignored.
        """
        newline = kwargs.pop('newline', True)

        template = args[0]
        params = args[1:]
        if params:
            txt = template.format(*params)
        else:
            txt = template

        self._log += txt
        if newline:
            self._log += "\n"

    def flush(self, error=False):
        """Write the buffer plus a separator line and reset the buffer.

        Output goes to stderr when `error` is true, otherwise to stdout.
        After the call the buffer holds a single newline.
        """
        self._log += "=================================================================\n"
        if error:
            stream = sys.stderr
        else:
            stream = sys.stdout
        stream.write(self._log)
        stream.flush()
        self._log = "\n"


def framed_smetana_pipeline():
    """Command-line entry point: run SMETANA analyses on a list of communities.

    Steps:
      1. Parse the command line and read the configuration file (the
         available options are described in the parser epilog).
      2. Load the community descriptions and resolve the SBML file of every
         member species.
      3. Select the communities belonging to this part (--part out of
         --parts-total, taken with a stride over the full list).
      4. For each selected community: build the merged community model,
         compute a minimal medium on top of the inorganic compounds, run the
         selected algorithms (mip, mro, fba, scscore, mpscore, muscore) and
         write one tab-separated row to the output file.
      5. Copy the configuration file next to the output as ``config.cfg``
         unless such a file already exists.

    Progress is logged to stdout. Communities that are skipped (missing
    models, invalid biomass reaction, no minimal medium) or that fail with an
    exception are reported on stderr and processing continues with the next
    community.

    Raises:
        NameError: if the ``algorithms`` option lists an unknown algorithm.
        SystemExit: raised by argparse on invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description="Run SMATANA on list of organisms communities",
                                     epilog=textwrap.dedent('''\
        configuration file (* = mandatory):
        [communities]
        
        communities*            Path to tab separated file with community description. Two columns are required, 
                                community ID (column:id) and list of community species (column:species, separator 
                                <communities_separator>)
                                
        communities_separator   Separator character used to separate community members (default: ",")         
                 
        models_dir*             Path to directory where SBML models are located. 
        
        models_locations        Path to tab separated file mapping species ID (column:id) in <communities> to 
                                SBML file (column:path) relative <models_dir> path. If not provided model ID is assumed to be 
                                SBML file name
                                
        compounds_inorganic*    Path to tab separated file with list of metabolites always present in the environment. 
                                Metabolite ID is the only required column (column:id). If not the always-present environment 
                                is assumed to be empty
                                
        compounds_exchanged*    Path to tab separated file with list of metabolites that can be exchanged between species. 
                                Two columns are required, metabolite ID (column:id) and metabolite name (column:name) 
                                
        compounds_external*     Path to tab separated file with list of metabolites that can be added to the minimal media. 
                                Two columns are required, metabolite ID (column:id) and metabolite name (column:name) 
                
        [framed]
        
        flavor                  SBML flavor to use when parsing the models. Allowed values are 'cobra', 'cobra:other', 
                                'seed', 'bigg' and 'fbc2'
                                
        biomass_regex           Regular expression to match biomass reaction (default: "iomass")
        
        extracellular_comp_id   Extracellular compartment id (default: Extracellular)
        
        exch_met_to_rxn_format  Conversion string to convert exchange metabolites to exchange reaction identifiers.
                                Note: this is evaluated as a python format string, so it must be a nested string.
                                (Default value, appropriate for BiGG ids: "'R_EX_{}_pool'")
        
        smetana_n_solutions     How many unique solutions to calculate for Metabolite Uptake and Species Coupling scores (default: 50)
        
        max_uptake              How much of uptake/exchange should be allowed for active exchange reactions (default: 1000)
        
        min_growth              Minimal growth (default: 0.1)
        
        min_mass_weight         Calculate minimal media using mass weight minimization (default: False)
             
        algorithms              Comma-separated list of algorithms to run. Available options: smetana, muscore, mpscore,
                                scscore, fba, mip, mro
    
    '''))

    parser.add_argument('config', help='Path to configuration file (see epilog)')
    parser.add_argument('output', nargs="?",
                        help='Path to output file (have to include {i} and {total} placeholders. example: "results/output_{i}_{total}.tsv")')
    parser.add_argument('--part', dest='part', type=int, help='Which part is calculated (start from 1)', default=1)
    parser.add_argument('--parts-total', dest='parts_total', type=int, help='Total number of parts', default=1)
    parser.add_argument('--no-warnings', dest='no_warnings', action="store_true", help="Suppress warnings")
    args = parser.parse_args()

    # 'output' is declared optional (nargs="?") but every step below needs it;
    # fail with a usage message instead of crashing inside os.path.dirname(None).
    if args.output is None:
        parser.error("output path is required")

    # An out-of-range part used to either raise IndexError or (part=0) silently
    # select the last part; reject it up front.
    if args.parts_total < 1 or not 1 <= args.part <= args.parts_total:
        parser.error("--part must be between 1 and --parts-total (got {} of {})".format(args.part, args.parts_total))

    output_dir = os.path.dirname(args.output)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.no_warnings:
        warnings.simplefilter("ignore", UserWarning)

    #
    # Load configuration
    #
    logger = Logger()
    config = ConfigParser.ConfigParser()
    config.read(args.config)
    # NOTE: the separator is later passed to re.split(), i.e. it is interpreted as a regular expression.
    # Surrounding double quotes are stripped so that whitespace separators can be written in the config file.
    communities_sep = config.get("communities", "communities_separator").strip('"') if config.has_option("communities", "communities_separator") else ","
    flavor = config.get("framed", "flavor") if config.has_option("framed", "flavor") else "fbc2"
    n_solutions = config.getint("framed", "smetana_n_solutions") if config.has_option("framed", "smetana_n_solutions") else 50
    # Boolean option: getboolean() accepts 1/0 (the values getint() used to accept) as well as true/false, yes/no, on/off
    min_mass_weight = config.getboolean("framed", "min_mass_weight") if config.has_option("framed", "min_mass_weight") else False
    min_growth = config.getfloat("framed", "min_growth") if config.has_option("framed", "min_growth") else 0.1

    algorithms = config.get("framed", "algorithms") if config.has_option("framed", "algorithms") else None
    all_algorithms = {"fba", "mip", "mro", "smetana", "mpscore", "scscore", "muscore"}
    if algorithms is None:
        algorithms = all_algorithms
    else:
        algorithms = set(x.strip() for x in algorithms.split(","))
        # "smetana" is a shortcut for the three scores it is calculated from
        if "smetana" in algorithms:
            algorithms = (algorithms | {"mpscore", "scscore", "muscore"}) - {"smetana"}
        unknown_algorithms = algorithms - all_algorithms
        if unknown_algorithms:
            raise NameError("Unknown algorithm in section [framed/algorithms]: {}".format(",".join(sorted(unknown_algorithms))))

    max_uptake = config.getfloat("framed", "max_uptake") if config.has_option("framed", "max_uptake") else 1000.0
    re_biomass = re.compile(config.get("framed", "biomass_regex")) if config.has_option("framed", "biomass_regex") else re.compile("iomass")
    extracellular_comp_id = config.get("framed", "extracellular_comp_id") if config.has_option("framed", "extracellular_comp_id") else "Extracellular"
    exch_met_to_rxn_format = config.get("framed", "exch_met_to_rxn_format") if config.has_option("framed", "exch_met_to_rxn_format") else "'R_EX_{}_pool'"
    models_dir = config.get("communities", "models_dir")
    if config.has_option("communities", "models_locations"):
        models_locations = {r['id']: r['path'] for _, r in
                            pandas.read_table(config.get("communities", "models_locations"), comment="#").iterrows()}
    else:
        models_locations = {}

    # Compound tables are kept as id -> name mappings ordered by id, so that the output columns have a stable order
    compounds_exchanged = {r['id']: r['name'] for _, r in pandas.read_table(config.get("communities", "compounds_exchanged"), comment="#").iterrows()}
    compounds_exchanged = OrderedDict(sorted(compounds_exchanged.iteritems(), key=lambda x: x[0]))
    compounds_external = {r['id']: r['name'] for _, r in pandas.read_table(config.get("communities", "compounds_external"), comment="#").iterrows()}
    compounds_external = OrderedDict(sorted(compounds_external.iteritems(), key=lambda x: x[0]))
    env_external = framed.Environment.from_compounds(compounds_external.keys(), exchange_format=exch_met_to_rxn_format, max_uptake=max_uptake)
    rxn_external = set(env_external)

    # Only the ids of inorganic compounds are used below, so the 'name' column is optional (as documented in the epilog)
    compounds_inorganic = {r['id']: (r['name'] if 'name' in r else r['id'])
                           for _, r in pandas.read_table(config.get("communities", "compounds_inorganic"), comment="#").iterrows()}
    compounds_inorganic = OrderedDict(sorted(compounds_inorganic.iteritems(), key=lambda x: x[0]))
    env_inorganic = framed.Environment.from_compounds(compounds_inorganic.keys(), exchange_format=exch_met_to_rxn_format, max_uptake=max_uptake)
    rxn_inorganic = set(env_inorganic)

    compounds_exchanged_blacklist = set()
    if config.has_option("communities", "compounds_exchanged_blacklist"):
        compounds_exchanged_blacklist = {r['id'] for _, r in pandas.read_table(config.get("communities", "compounds_exchanged_blacklist"), comment="#").iterrows()}

    for met in compounds_exchanged_blacklist:
        # A blacklisted compound that is not in the exchanged list is simply ignored
        compounds_exchanged.pop(met, None)

    #
    # Load communities descriptions
    #
    communities = []
    for _, r in pandas.read_table(config.get("communities", "communities"), comment="#").iterrows():
        community_models = re.split(communities_sep, r['species'].strip())
        community_model_paths = []
        missing_models = []
        for m_id in community_models:
            # Without a <models_locations> mapping the model ID itself is the SBML file name
            m_file = models_locations.get(m_id, m_id)
            m_path = os.path.join(models_dir, m_file)
            # When a mapping is provided, every species has to be listed in it
            if (models_locations and m_id not in models_locations) or not os.path.exists(m_path):
                missing_models.append(m_id)

            community_model_paths.append(m_path)

        communities.append({'id': r['id'], 'species': community_models, 'paths': community_model_paths,
                            'missing_models': missing_models})
    # The widest community determines how many per-organism columns the table gets
    max_community_size = max(len(com["species"]) for com in communities) if communities else 0

    #
    # Prepare template for results table and header
    #
    row_template = OrderedDict([("community_id", ""), ("size", "")])
    if "mip" in algorithms:
        row_template["mip"] = ""
    if "mro" in algorithms:
        row_template["mro"] = ""
    if "fba" in algorithms:
        row_template["fba_status"] = ""
        row_template["fba_community"] = ""
    row_template["mmedia_status"] = ""

    for i in range(1, max_community_size + 1):
        row_template["org{}".format(i)] = ""

    if "fba" in algorithms:
        for i in range(1, max_community_size + 1):
            row_template["fba.org{}".format(i)] = ""

    for m_id in compounds_external:
        row_template["mmedia.{}".format(m_id)] = 0

    if "muscore" in algorithms:
        for i in range(1, max_community_size + 1):
            for m_id in compounds_exchanged:
                row_template["muscore.org{}.{}".format(i, m_id)] = 0

    if "mpscore" in algorithms:
        for i in range(1, max_community_size + 1):
            for m_id in compounds_exchanged:
                row_template["mpscore.org{}.{}".format(i, m_id)] = 0

    if "scscore" in algorithms:
        for i in range(1, max_community_size + 1):
            for j in range(1, max_community_size + 1):
                row_template["scscore.org{}.org{}".format(i, j)] = 0

    # Part k (1-based) processes every parts_total-th community starting with the k-th one
    communities_sample = communities[args.part - 1::args.parts_total]

    # 3_program_validate.prof
    validate = False # validate MILP solutions

    #
    # Main loop
    #
    with open(args.output.format(i=args.part, total=args.parts_total), "w") as output:
        # Write header
        missing_communities = [m['id'] for m in communities if len(m['missing_models'])]
        missing_communities_sample = [m['id'] for m in communities_sample if len(m['missing_models'])]
        output.write("\t".join(row_template) + "\n")
        logger.log("Total number of communities: {} (sample: {})", len(communities), len(communities_sample))
        logger.log("Missing models: {} (sample: {}): {}", len(missing_communities), len(missing_communities_sample),
                   ", ".join(missing_communities_sample))
        logger.flush()

        for community_i, community_data in enumerate(communities_sample, start=1):
            try:
                row = row_template.copy()
                # 1-based position of this community in the complete (unsampled) list, for logging only
                community_all_i = next(
                    c_i for c_i, c in enumerate(communities, start=1) if c['id'] == community_data['id'])

                logger.log("{}/{:<5} (all: {}/{:<5}) [{}]: {}", community_i, len(communities_sample), community_all_i,
                           len(communities), community_data['id'], ", ".join(community_data["species"]))

                if len(community_data['missing_models']):
                    logger.log("Skiping because missing models: {}".format(", ".join(community_data['missing_models'])),
                               newline=True)
                    logger.flush(error=True)
                    continue
                #
                # Read SBML files representing community organisms
                #
                models = []
                models_i = {}  # model id -> 1-based organism index used in the column names
                biomass_valid = True
                for i, path in enumerate(community_data['paths'], start=1):
                    logger.log("Reading {}:'{}' SBML file...".format(i, path), newline=True)
                    model = framed.load_cbmodel(path, flavor=flavor)
                    model.id = re.sub("[^A-Za-z0-9]", "_", community_data['species'][i - 1])
                    # search() instead of match(): the pattern (default "iomass") may occur anywhere in the
                    # reaction id, while match() would only accept ids starting with it
                    biomass_candidates = [r for r in model.reactions if re_biomass.search(r)]
                    if len(biomass_candidates) != 1:
                        biomass_valid = False
                        break

                    model.biomass_reaction = biomass_candidates[0]
                    models.append(model)
                    models_i[model.id] = i

                if not biomass_valid:
                    # biomass_candidates still refers to the model the loop stopped at
                    if len(biomass_candidates) == 0:
                        logger.log("Biomass not found!")
                    elif len(biomass_candidates) > 1:
                        logger.log("Multiple biomass candidates found: {}".format(", ".join(biomass_candidates)))

                    logger.flush(error=True)
                    continue

                #
                # Print community description
                #
                community_exchange_compounds = {m for ml in models
                                                for exch_id in ml.get_exchange_reactions()
                                                for m in ml.reactions[exch_id].stoichiometry}

                # Every exchangeable compound that is not in the <compounds_exchanged> list is blacklisted
                community_exchange_blacklist = community_exchange_compounds - set(compounds_exchanged)
                # TODO: 4_program_initCopy.prof
                community_interacting = framed.Community(community_data['id'], models, extracellular_compartment_id=extracellular_comp_id,
                                      create_biomass=True, interacting=True, copy_models=False,
                                      exchanged_metabolites_blacklist=community_exchange_blacklist)

                rxn_com_inorganic = set(community_interacting.merged.get_exchange_reactions()) & rxn_inorganic
                rxn_com_external = set(community_interacting.merged.get_exchange_reactions()) & rxn_external
                logger.log("Inorganic media size: {}", len(rxn_com_inorganic))
                for model_i, model_id in enumerate(community_data['species'], start=1):
                    row["org{}".format(model_i)] = model_id
                row['community_id'] = community_data['id']
                row['size'] = len(models)

                #
                # Calculate minimal media for community. Predefined inorganic compounds are always present !
                #
                env_inorganic.apply(community_interacting.merged, inplace=True)

                rxn_com_minimal_candidates = rxn_com_external - rxn_com_inorganic
                rxn_com_minimal, sol = minimal_medium(community_interacting.merged, exchange_reactions=rxn_com_minimal_candidates, validate=validate, min_mass_weight=min_mass_weight)
                rxn_com_minimal = rxn_com_minimal | rxn_com_inorganic
                env_com_minimal = framed.Environment.from_reactions(rxn_com_minimal, max_uptake=max_uptake)

                row["mmedia_status"] = status_codes[sol.status]
                if sol.status != Status.OPTIMAL:
                    # The row is still written so that the failure is visible in the results table
                    logger.log("Failed to find media for community")
                    logger.flush(error=True)
                    output.write(row_string(row))
                    output.flush()
                    continue

                env_com_minimal.apply(community_interacting.merged, inplace=True)

                logger.log("\nInteracting community media ({}'inorganic + {}'minimal + {}'external):\n{}",
                           len(rxn_com_inorganic),
                           len(rxn_com_minimal - rxn_com_inorganic),
                           len(rxn_com_external - rxn_com_minimal - rxn_com_inorganic), "-"*60)

                for r in rxn_com_minimal:
                    # Metabolite of the community exchange reaction (first one in its stoichiometry)
                    ext_m = next(iter(community_interacting.merged.reactions[r].stoichiometry))

                    # Translate it back to the metabolite id used in the organism models
                    orig_m = list({org_m.original_metabolite
                                   for org_exch in community_interacting.organisms_exchange_reactions.itervalues()
                                   for org_m in org_exch.itervalues()
                                   if org_m.extracellular_metabolite == ext_m})
                    if not orig_m:
                        raise KeyError("No original metabolite found for exchange reaction '{}'".format(r))
                    orig_m = orig_m[0]

                    row_id = "mmedia.{}".format(orig_m)
                    if row_id not in row:
                        raise KeyError("Column '{}' was not found in row".format(row_id))
                    # 1 = predefined inorganic compound, 2 = compound added by the minimal media calculation
                    row[row_id] = 1 if orig_m in compounds_inorganic else 2

                    logger.log("    {}{}", compounds_external[orig_m], "" if r in rxn_com_inorganic else " (added to minimal media)")

                #
                # Calculate MIP and MRO
                #
                if "mip" in algorithms:
                    mip, mip_extras = smetana.mip_score(community_interacting, env_inorganic, validate=validate)
                    row["mip"] = mip
                    logger.log("MIP: {}", mip)

                if "mro" in algorithms:
                    mro, mro_extras = smetana.mro_score(community_interacting, env_inorganic, validate=validate)
                    row["mro"] = mro
                    logger.log("MRO: {}", mro)

                if "fba" in algorithms:
                    community_fba = community_interacting.copy(create_biomass=False, interacting=False)
                    env_com_minimal.apply(community_fba.merged, inplace=True)

                    sol = framed.FBA(community_fba.merged, get_values=True, get_shadow_prices=True)
                    row["fba_status"] = status_codes[sol.status]
                    logger.log("Growth on minimal media (uptakes 1.0/org): {:.1f} ({}) = ",
                               sol.fobj, status_codes[sol.status], newline=False)

                    if sol.status == Status.OPTIMAL:
                        row["fba_community"] = sol.fobj
                        for m_i, model_id in enumerate(community_fba.organisms, start=1):
                            b = community_fba.organisms_biomass_reactions[model_id]
                            # Per-organism growth is appended to the same log line; only the last one ends it
                            logger.log("{:.1f}'{}{}", sol.values[b], community_data['species'][m_i - 1], " + " if m_i < len(community_fba.organisms) else "",
                                       newline=(len(community_fba.organisms) == m_i))
                            row_id = "fba.org{}".format(models_i[model_id])
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = sol.values[b]

                #
                # Apply minimal+inorganic media
                #
                env_com_minimal.apply(community_interacting.merged, inplace=True)

                #
                # Calculate SMETANA
                #
                if "scscore" in algorithms:
                    scscores, scextras = smetana.species_coupling_score(community_interacting, environment=env_com_minimal,
                                                                min_growth=min_growth,
                                                                max_uptake=max_uptake, n_solutions=n_solutions)
                    logger.log("Species Coupling Scores:\n{}", "-"*60)
                    for m_receiver, donors in scscores.iteritems():
                        if len(donors):
                            donors_str = ["{}'{:.0f}%".format(m_donor, 100*freq) for m_donor, freq in donors.iteritems()]
                            logger.log("    {} <= {}".format(m_receiver, ",".join(donors_str)))
                        for m_donor, scscore in donors.iteritems():
                            row_id = "scscore.org{}.org{}".format(models_i[m_receiver], models_i[m_donor])
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = scscore

                if "mpscore" in algorithms:
                    mpscores, mpextras = smetana.metabolite_production_score(community_interacting, env_com_minimal, min_growth=min_growth)
                    logger.log("\nMetabolite Production Scores:\n{}", "-"*60)
                    for model_id, metabolites in mpscores.iteritems():
                        if len(metabolites):
                            logger.log("    {} => {}".format(model_id, ",".join([compounds_exchanged.get(m, m) for m in metabolites])))
                        for m_id in metabolites:
                            row_id = "mpscore.org{}.{}".format(models_i[model_id], m_id)
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = 1

                if "muscore" in algorithms:
                    logger.log("\nMetabolite Uptake Scores:\n{}", "-"*60)
                    muscores, muextras = smetana.metabolite_uptake_score(community_interacting, env_com_minimal,
                                                                 min_mass_weight=min_mass_weight, min_growth=min_growth,
                                                                 max_uptake=max_uptake,
                                                                 n_solutions=n_solutions)
                    for model_id, metabolites in muscores.iteritems():
                        if len(metabolites):
                            logger.log("    {} <= {}".format(model_id, ", ".join(["{}'{:.0f}%".format(compounds_exchanged.get(m, m), 100*freq) for m, freq in metabolites.iteritems()])))
                        for m_id, muscore in metabolites.iteritems():
                            row_id = "muscore.org{}.{}".format(models_i[model_id], m_id)
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = muscore

                # The combined SMETANA score needs all three partial scores; it is only logged, not written to the table
                if "scscore" in algorithms and "mpscore" in algorithms and "muscore" in algorithms:
                    smetana_score = smetana.calculate_smetana_score(community=community_interacting, scscores=scscores, mpscores=mpscores, muscores=muscores, report_zero_scores=False)
                    logger.log("\nSmetana Scores (sum={}):\n{}", sum(s.score for s in smetana_score), "-"*60)
                    if smetana_score:
                        for s in smetana_score:
                            logger.log("    {}/{}/{} => {:.0f}%", s.donor_organism, compounds_exchanged.get(s.metabolite, s.metabolite),
                                s.receiver_organism, 100*s.score)

                output.write(row_string(row))
                output.flush()
                logger.flush(error=False)
            except Exception:
                # A failing community must not stop the whole run: report it and go on with the next one.
                # KeyboardInterrupt/SystemExit are deliberately not caught, so the run can still be aborted.
                logger.log("-------------------- START_ERROR --------------------")
                logger.log(traceback.format_exc())
                logger.log("-------------------- END_ERROR --------------------")
                logger.flush(error=True)

    # Keep a copy of the configuration next to the results (an existing copy is never overwritten)
    out_config = os.path.join(os.path.dirname(args.output), "config.cfg")
    if not os.path.exists(out_config):
        copyfile(args.config, out_config)


# Run the pipeline only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    framed_smetana_pipeline()
