#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import logging

from miner2 import coexpression, preprocess, mechanistic_inference as mechinf, miner
from miner2 import GIT_SHA
from miner2 import __version__ as MINER_VERSION
from miner2 import util

# Passed as ``min_number_genes`` to mechinf.get_regulons() when regulons are
# built from the coregulation modules.
MIN_REGULON_GENES = 5

# Help text handed to argparse as the parser description.  The version has its
# leading "miner2 " removed and the Git SHA is stripped of what looks like a
# "$Id: ... $" keyword-expansion wrapper.
DESCRIPTION = """miner-causalinference - MINER causal inference.
MINER Version {} (Git SHA {})""".format(
    str(MINER_VERSION).replace('miner2 ', ''),
    GIT_SHA.replace('$Id: ', '').replace(' $', ''),
)

def _require_file(path, description):
    """Exit the script with "<description> not found" unless *path* exists."""
    if not os.path.exists(path):
        sys.exit("%s not found" % description)


def _filter_causal_results(causal_results, pvalue_threshold=0.05,
                           alignment_threshold=0.5):
    """Return the rows of *causal_results* that pass every causal-flow filter.

    A row is kept only when all of the following hold:
      * "-log10(p)_Regulon_stratification" >= -log10(pvalue_threshold)
        (the regulon is differentially active w.r.t. the mutation),
      * "-log10(p)_MutationRegulatorEdge" >= -log10(pvalue_threshold)
        (the regulator is differentially active w.r.t. the mutation),
      * "RegulatorRegulon_Spearman_p-value" <= pvalue_threshold
        (the regulator is significantly correlated to the regulon),
      * "Fraction_of_edges_correctly_aligned" >= alignment_threshold
        (at least that fraction of the differentially active targets
        downstream of the regulator move consistently with the mutation).

    Rows holding NaN in any of these columns fail the comparison and are
    dropped.  Row order and index labels are preserved.
    """
    min_neg_log_p = -np.log10(pvalue_threshold)
    keep = (
        (causal_results["-log10(p)_Regulon_stratification"] >= min_neg_log_p)
        & (causal_results["Fraction_of_edges_correctly_aligned"]
           >= alignment_threshold)
        & (causal_results["RegulatorRegulon_Spearman_p-value"]
           <= pvalue_threshold)
        & (causal_results["-log10(p)_MutationRegulatorEdge"] >= min_neg_log_p)
    )
    return causal_results[keep]


def main():
    """Command-line entry point for miner-causalinference.

    Loads the expression data, regulons, coherent-members matrix and the
    common-mutation / translocation / cytogenetics matrices named on the
    command line, runs miner.causalNetworkAnalysis once per matrix (results
    directory ``<outdir>/causal_analysis``), and writes to ``outdir``:

    * ``completeCausalResults.csv`` - all causal results read back from
      ``causal_analysis``,
    * ``wiring_diagram.csv`` - passed as ``savefile`` to miner.wiringDiagram,
    * ``filteredCausalResults.csv`` - the causal flows that pass every
      filter in ``_filter_causal_results``.

    Exits with a "<file> not found" message if any input file is missing.
    """
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG,
                        datefmt='%Y-%m-%d %H:%M:%S \t')

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('coreg', help="coregulationModules.json file from miner-mechinf")
    parser.add_argument('coher', help="coherentMembers.csv file from miner-bcmembers")

    parser.add_argument('cmfile', help="common mutations file")
    parser.add_argument('tlfile', help="translocations file")
    parser.add_argument('cgfile', help="cytogenetics file")
    parser.add_argument('outdir', help="output directory")

    args = parser.parse_args()

    # Validate every input up front so a bad path fails fast, before the
    # expensive preprocessing step runs.
    _require_file(args.expfile, "expression file")
    _require_file(args.mapfile, "identifier mapping file")
    _require_file(args.coreg, "coregulation modules file")
    _require_file(args.coher, "coherent members file")
    _require_file(args.cmfile, "common mutations file")
    _require_file(args.tlfile, "translocations file")
    _require_file(args.cgfile, "cytogenetics file")
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    exp_data, _conv_table = preprocess.main(args.expfile, args.mapfile)
    with open(args.coreg) as infile:
        coregulation_modules = json.load(infile)
    regulons = mechinf.get_regulons(coregulation_modules,
                                    min_number_genes=MIN_REGULON_GENES,
                                    freq_threshold=0.333)
    regulon_modules, regulon_df = mechinf.get_regulon_dictionary(regulons)
    coherent_samples_matrix = pd.read_csv(args.coher, index_col=0, header=0)

    # (causalFolder name, event matrix) for each kind of genomic event.
    # All matrices are read before the eigengene computation so malformed
    # CSVs are reported early.
    mutation_matrices = [
        ("causal_results_common_mutations",
         pd.read_csv(args.cmfile, index_col=0, header=0)),
        ("causal_results_translocations",
         pd.read_csv(args.tlfile, index_col=0, header=0)),
        ("causal_results_cytogenetics",
         pd.read_csv(args.cgfile, index_col=0, header=0)),
    ]

    eigengenes = miner.getEigengenes(regulon_modules, exp_data,
                                     regulon_dict=None, saveFolder=None)
    # Multiply eigengenes by the ratio of the 95th percentiles of the
    # expression data and the eigengenes, bringing them onto the expression
    # data's scale.
    eigen_scale = np.percentile(exp_data, 95) / np.percentile(eigengenes, 95)
    eigengenes = eigen_scale * eigengenes
    # NOTE(review): index labels are cast to str, presumably so they match
    # the regulon identifiers used in regulon_df -- confirm.
    eigengenes.index = np.array(eigengenes.index).astype(str)

    # Perform causal analysis for each event matrix.
    result_dir = os.path.join(args.outdir, "causal_analysis")
    for causal_folder, mutation_matrix in mutation_matrices:
        logging.info("running causal analysis: %s", causal_folder)
        miner.causalNetworkAnalysis(regulon_matrix=regulon_df,
                                    expression_matrix=exp_data,
                                    reference_matrix=eigengenes,
                                    mutation_matrix=mutation_matrix,
                                    resultsDirectory=result_dir,
                                    minRegulons=1,
                                    significance_threshold=0.05,
                                    causalFolder=causal_folder)

    # Compile all causal results.
    causal_results = miner.readCausalFiles(result_dir)
    causal_results.to_csv(os.path.join(args.outdir, "completeCausalResults.csv"))

    # The return value is not needed here; the call is kept for the output
    # it presumably writes to ``savefile``.
    miner.wiringDiagram(causal_results, regulon_modules,
                        coherent_samples_matrix,
                        include_genes=False,
                        savefile=os.path.join(args.outdir, 'wiring_diagram.csv'))

    # Generate filtered causal flows.
    filtered_results = _filter_causal_results(causal_results)
    logging.info("%d of %d causal flows passed all filters",
                 len(filtered_results), len(causal_results))
    filtered_results.to_csv(os.path.join(args.outdir,
                                         "filteredCausalResults.csv"))


if __name__ == '__main__':
    main()
