#!/usr/bin/env python3

import argparse
import json
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot
import seaborn as sns
from graphviz import Source

# xgboost requires different plotting method
from xgboost import plot_tree

import pickle

from miner2 import risk_predict, survival, util, preprocess, miner
from miner2 import GIT_SHA, __version__ as pkg_version

# Text passed to argparse as the program description. The '$Id: ' prefix and
# the ' $' suffix are removed from GIT_SHA before it is inserted.
DESCRIPTION = """miner-riskpredict - MINER compute risk prediction.
MINER Version %s (Git SHA %s)""" % (
    pkg_version,
    GIT_SHA.replace('$Id: ', '').replace(' $', ''),
)


if __name__ == '__main__':
    # Command line: an input specification file and an output directory.
    cli = argparse.ArgumentParser(description=DESCRIPTION,
                                  formatter_class=argparse.RawDescriptionHelpFormatter)
    cli.add_argument('input', help="input specification file")
    cli.add_argument('outdir', help="output directory")
    # NOTE(review): --method is parsed but not read anywhere in the visible
    # part of this script.
    cli.add_argument('--method', default="decisionTree", help="predictor method (decisionTree or xgboost)")

    args = cli.parse_args()

    # Create the output directory when it does not exist yet.
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # util.write_dependency_infos() writes into run_info.txt in the output directory.
    run_info_path = os.path.join(args.outdir, 'run_info.txt')
    with open(run_info_path, 'w') as run_info:
        util.write_dependency_infos(run_info)

    # The input specification is JSON; its values are used below as file paths.
    with open(args.input) as spec_file:
        input_spec = json.load(spec_file)


    # Expression data and a second value (conv_table) as returned by preprocess.main.
    exp_data, conv_table = preprocess.main(input_spec['exp'], input_spec['idmap'])

    # Three CSV tables, each read with the first column as index and the first
    # row as header.
    common_mutations, translocations, cytogenetics = [
        pd.read_csv(input_spec[spec_key], index_col=0, header=0)
        for spec_key in ('common_mutations', 'translocations', 'cytogenetics')
    ]

    # Keep only the cytogenetics columns that also occur in the expression data.
    shared_cyto_columns = set(cytogenetics.columns) & set(exp_data.columns)
    cytogenetics = cytogenetics.loc[:, list(shared_cyto_columns)]

    # Stack mutations and translocations row-wise, restricted to the columns
    # present in both tables (presumably patient ids, going by the name).
    common_patients_mutations_translocations = list(
        set(translocations.columns) & set(common_mutations.columns))
    merged_mutations = pd.concat(
        [table.loc[:, common_patients_mutations_translocations]
         for table in (common_mutations, translocations)],
        axis=0)



    # JSON inputs, loaded from the paths given in the specification.
    with open(input_spec['coexpression_dictionary']) as handle:
        revised_clusters = json.load(handle)
    with open(input_spec['coexpression_modules']) as handle:
        coexpression_modules = json.load(handle)
    with open(input_spec['regulon_modules']) as handle:
        regulon_modules = json.load(handle)
    with open(input_spec['mechanistic_output']) as handle:
        mechanistic_output = json.load(handle)

    # CSV inputs (first column is the index, first row the header).
    regulon_df = pd.read_csv(input_spec['regulon_df'], index_col=0, header=0)
    oem_matrix = pd.read_csv(input_spec['overexpressed_members'], index_col=0, header=0)
    uem_matrix = pd.read_csv(input_spec['underexpressed_members'], index_col=0, header=0)
    eigengenes = pd.read_csv(input_spec['eigengenes'], index_col=0, header=0)

    # Row labels of these three tables are converted to strings
    # (presumably to match the string keys of the JSON inputs - TODO confirm).
    for frame in (oem_matrix, uem_matrix, eigengenes):
        frame.index = np.array(frame.index).astype(str)

    filtered_causal_results = pd.read_csv(input_spec['filtered_causal_results'],
                                          index_col=0, header=0)

    # Programs are stored under the keys "0" .. "n-1"; turn them into a list
    # ordered by that number.
    with open(input_spec['transcriptional_programs']) as handle:
        transcriptional_programs = json.load(handle)
    program_list = [transcriptional_programs[str(idx)]
                    for idx in range(len(transcriptional_programs))]

    # States are stored the same way.
    with open(input_spec['transcriptional_states']) as handle:
        transcriptional_states = json.load(handle)
    states_list = [transcriptional_states[str(idx)]
                   for idx in range(len(transcriptional_states))]

    # Expression-data columns that belong to no state become one extra state
    # appended at the end.
    final_state = list(set(exp_data.columns) - set(np.hstack(states_list)))
    if final_state:
        states_list.append(final_state)

    states = list(states_list)
    # Over-expressed minus under-expressed membership matrix.
    diff_matrix_mmrf = oem_matrix - uem_matrix

    # Primary survival table: only its first two columns are used, renamed to
    # the column names handed to miner.kmAnalysis below.
    duration_col, status_col = "duration", "observed"
    survival_mmrf = pd.read_csv(input_spec['primary_survival_data'], index_col=0, header=0)
    survival_df_mmrf = survival_mmrf.iloc[:, :2]
    survival_df_mmrf.columns = [duration_col, status_col]
    oem_matrix_mmrf = oem_matrix

    # Kaplan-Meier analysis followed by miner.guanRank; the result is the
    # survival frame used by all analyses below.
    km_df_mmrf = miner.kmAnalysis(survivalDf=survival_df_mmrf,
                                  durationCol=duration_col,
                                  statusCol=status_col)
    guanSurvivalDfMMRF = miner.guanRank(kmSurvival=km_df_mmrf)



    # Figures

    # Heatmap of the over- minus under-expressed matrix, with rows taken in
    # program order and columns in state order.
    heatmap_rows = np.hstack(program_list)
    heatmap_cols = np.hstack(states_list)

    plt.figure(figsize=(8, 8))
    plt.imshow(diff_matrix_mmrf.loc[heatmap_rows, heatmap_cols],
               cmap="bwr", vmin=-1, vmax=1, aspect=0.275)
    plt.grid(False)

    heatmap_path = os.path.join(args.outdir, "regulon_activity_heatmap.pdf")
    plt.savefig(heatmap_path, bbox_inches="tight")

    # Survival analysis of regulons
    cox_regulons_output = miner.parallelMedianSurvivalAnalysis(regulon_modules,
                                                               exp_data,
                                                               guanSurvivalDfMMRF,
                                                               numCores=5)

    # Write the table with its rows ordered numerically by index label.
    regulon_order = np.argsort(np.array(cox_regulons_output.index).astype(int))
    cox_regulons_output = cox_regulons_output.iloc[regulon_order]
    cox_regulons_output.to_csv(os.path.join(args.outdir, 'CoxProportionalHazardsRegulons.csv'))

    # For the console summary the table is re-sorted by hazard ratio, highest first.
    cox_regulons_output = cox_regulons_output.sort_values(by="HR", ascending=False)

    # These rows are regulons; the headers previously said "programs", which
    # made this output indistinguishable from the program summary printed later.
    print("\nHigh-risk regulons:")
    print(cox_regulons_output.iloc[0:5,:])
    print("\nLow-risk regulons")
    print(cox_regulons_output.iloc[-5:,:])

    # Survival analysis of transcriptional programs
    # Map each program number to the unique genes of all regulons in that program.
    pr_genes = {}
    for program_id, program_regulons in enumerate(program_list):
        regulon_genes = [regulon_modules[regulon] for regulon in program_regulons]
        pr_genes[program_id] = list(set(np.hstack(regulon_genes)))

    cox_programs_output = miner.parallelMedianSurvivalAnalysis(pr_genes,
                                                               exp_data,
                                                               guanSurvivalDfMMRF,
                                                               numCores=5)

    # Write the table with its rows ordered numerically by index label.
    program_order = np.argsort(np.array(cox_programs_output.index).astype(int))
    cox_programs_output = cox_programs_output.iloc[program_order]
    programs_csv_path = os.path.join(args.outdir, 'CoxProportionalHazardsPrograms.csv')
    cox_programs_output.to_csv(programs_csv_path)

    # From here on the table is sorted by hazard ratio, highest first.
    cox_programs_output.sort_values(by="HR", ascending=False, inplace=True)

    print("\nHigh-risk programs:")
    print(cox_programs_output.iloc[0:5,:])
    print("\nLow-risk programs")
    print(cox_programs_output.iloc[-5:,:])

    # Kaplan-Meier plot of all programs
    srv = guanSurvivalDfMMRF.copy()

    plt.figure(figsize=(8,8))
    # 'seaborn-whitegrid' was renamed to 'seaborn-v0_8-whitegrid' in newer
    # matplotlib releases; use whichever name this installation provides.
    whitegrid_style = 'seaborn-whitegrid'
    if whitegrid_style not in plt.style.available:
        whitegrid_style = 'seaborn-v0_8-whitegrid'
    plt.style.use(whitegrid_style)
    plt.xlim(-100,2000)

    # Curves to draw, as (program key, colour, alpha). Every program is drawn
    # in light gray first. cox_programs_output is sorted by HR in descending
    # order at this point, so index[0] is the highest-HR program (redrawn in
    # blue) and index[-1] the lowest-HR program (redrawn in red).
    program_curves = [(key, "gray", 0.3) for key in pr_genes]
    program_curves.append((cox_programs_output.index[0], "blue", 1))
    program_curves.append((cox_programs_output.index[-1], "red", 1))

    for key, color, alpha in program_curves:
        cluster = np.array(exp_data.loc[pr_genes[key],:])
        # NOTE(review): the per-sample summary is a mean although the variable
        # is called "median"; kept as is - confirm which statistic is intended.
        median_ = np.mean(cluster, axis=0)
        threshold = np.percentile(median_, 85)

        # Samples at or above the 85th percentile are members (1), all others
        # are not (0). This is done in a single step: assigning 1 first and
        # then zeroing everything below the threshold would also zero the
        # freshly written 1s whenever the threshold is greater than 1.
        membership = np.where(median_ >= threshold, 1.0, 0.0)
        membership_df = pd.DataFrame(membership, index=exp_data.columns, columns=[key])

        cox_hr, cox_p = miner.survivalMembershipAnalysisDirect(membership_df, guanSurvivalDfMMRF)

        groups = [membership_df.index[np.where(membership_df[key]==1)[0]]]
        # The label names the program actually being drawn (the highlighted
        # curves previously reused the last key of the gray loop).
        labels = ["{0}: {1:.2f}".format(str(key), cox_hr)]
        miner.kmplot(srv=srv, groups=groups, labels=labels, xlim_=(-100,1750),
                     filename=None, lw=2, color=color, alpha=alpha)

    plt.savefig(os.path.join(args.outdir,"kmplots_programs.pdf"), bbox_inches="tight")

    # Survival analysis of transcriptional states
    # Kaplan-Meier plot of all states (members of a state vs. everyone else)
    plt.figure(figsize=(8,8))
    # 'seaborn-whitegrid' was renamed to 'seaborn-v0_8-whitegrid' in newer
    # matplotlib releases; use whichever name this installation provides.
    whitegrid_style = 'seaborn-whitegrid'
    if whitegrid_style not in plt.style.available:
        whitegrid_style = 'seaborn-v0_8-whitegrid'
    plt.style.use(whitegrid_style)
    plt.xlim(-100,2000)

    srv = guanSurvivalDfMMRF.copy()

    # Curves to draw, as (column key, label text, member samples, colour, alpha).
    # Every single state is drawn in light gray first.
    state_curves = [(key, str(key), members, "gray", 0.3)
                    for key, members in enumerate(states_list)]

    # Combinations of states that are drawn on top in colour. The indices are
    # hard-coded and therefore specific to one particular set of states.
    # The second combination is the one the original comment called
    # "high-risk states".
    highlighted_states = [((4, 5, 16, 23), "red"),
                          ((10, 14, 21), "blue")]
    for state_ids, color in highlighted_states:
        missing = [i for i in state_ids if i >= len(states_list)]
        if missing:
            # Skip instead of failing with an IndexError, which would also
            # lose every plot produced after this one.
            print("Skipping highlighted states %s: states %s do not exist (%d states available)"
                  % (list(state_ids), missing, len(states_list)))
            continue
        combined = np.hstack([states_list[i] for i in state_ids])
        # The column key stays 0 as before; only the label text now names the
        # combined states instead of repeating the label of state 0.
        label_text = "+".join(str(i) for i in state_ids)
        state_curves.append((0, label_text, combined, color, 1))

    for key, label_text, members, color, alpha in state_curves:
        # Membership column: 1 for the samples in the group, 0 for all others.
        median_df = pd.DataFrame(np.zeros(exp_data.shape[1]),
                                 index=exp_data.columns, columns=[key])
        median_df.loc[members, key] = 1

        cox_hr, cox_p = miner.survivalMembershipAnalysisDirect(median_df, guanSurvivalDfMMRF)

        groups = [median_df.index[np.where(median_df[key]==1)[0]]]
        labels = ["{0}: {1:.2f}".format(label_text, cox_hr)]

        miner.kmplot(srv=srv, groups=groups, labels=labels, xlim_=(-100,1750),
                     filename=None, lw=2, color=color, alpha=alpha)

    plt.savefig(os.path.join(args.outdir, "kmplots_states.pdf"), bbox_inches="tight")

    # Generate boxplot data for transcriptional states
    survival_patients = set(guanSurvivalDfMMRF.index)
    # Columns of the translocation table whose call equals 1 in the given row.
    t414_patients = set(translocations.columns[
        np.where(translocations.loc["RNASeq_WHSC1_Call",:]==1)[0]
    ])
    t1114_patients = set(translocations.columns[
        np.where(translocations.loc["RNASeq_CCND1_Call",:]==1)[0]
    ])

    # States with fewer than min_patients samples in the survival data are skipped.
    min_patients = 5
    ranks = []            # median GuanScore per retained state
    boxplot_data = []     # GuanScores per retained state
    boxplot_names = []    # 1-based state number, repeated once per sample
    boxplot_samples = []  # sample ids per retained state
    boxplot_labels = []
    percent_t414 = []     # fraction of the state's samples with a WHSC1 call
    percent_t1114 = []    # fraction of the state's samples with a CCND1 call
    for key, state in enumerate(states_list):
        state_members = set(state)
        overlap_patients = list(survival_patients & state_members)
        if len(overlap_patients) < min_patients:
            continue
        guan_data = list(guanSurvivalDfMMRF.loc[overlap_patients,"GuanScore"])
        boxplot_samples.append(overlap_patients)
        boxplot_data.append(guan_data)
        boxplot_names.append([1+key] * len(overlap_patients))
        ranks.append(np.median(guan_data))

        # Fractions are relative to the whole state, not only to the samples
        # that have survival data.
        percent_t414.append(float(len(state_members & t414_patients))/len(state))
        percent_t1114.append(float(len(state_members & t1114_patients))/len(state))

    if not ranks:
        raise ValueError("no transcriptional state has at least %d samples with survival data"
                         % min_patients)

    # Retained states ordered by increasing median GuanScore. The per-state
    # lists have different lengths, so they are reordered as Python lists:
    # building a NumPy array from ragged lists fails on recent NumPy versions.
    order = np.argsort(ranks)

    labels = np.hstack([boxplot_names[i] for i in order])
    labels_df = pd.DataFrame(labels)
    labels_df.index = np.hstack([boxplot_samples[i] for i in order])
    labels_df.columns = ["label"]
    plot_data = pd.concat([guanSurvivalDfMMRF.loc[labels_df.index,"GuanScore"],labels_df],axis=1)

    # State numbers in the same order as `ranks`, taken directly from the
    # per-state lists rather than from a set, whose iteration order is not
    # guaranteed to line up with `ranks`.
    rank_order = np.array([names[0] for names in boxplot_names])[order]
    ranked_t414 = np.array(percent_t414)[order]
    ranked_t1114 = np.array(percent_t1114)[order]

    # Violin plots by states
    f, ax = plt.subplots(figsize=(12, 2))

    # One violin per state, in rank_order (increasing median GuanScore).
    # `fliersize` is a boxplot option and is not passed here: recent seaborn
    # versions forward unknown keywords to matplotlib, which rejects it.
    sns.violinplot(x="label", y="GuanScore", data=plot_data,
                   palette="coolwarm", order=rank_order)

    # Add in points to show each observation
    sns.swarmplot(x="label", y="GuanScore", data=plot_data,
                  size=2, color=".3", linewidth=0, order=rank_order)

    # Tweak the visual presentation
    ax.set(ylabel="")
    ax.set(xlabel="")

    # Save figure
    plt.savefig(os.path.join(args.outdir, "violin_states_risk.pdf"), bbox_inches="tight")

    # Boxplots by states
    f, ax = plt.subplots(figsize=(12, 2))

    # One box per state, in the same order; outlier markers are hidden
    # (fliersize=0) because every observation is drawn by the swarm plot.
    sns.boxplot(x="label", y="GuanScore", data=plot_data, fliersize=0,
                palette="coolwarm", order=rank_order)

    # Add in points to show each observation
    sns.swarmplot(x="label", y="GuanScore", data=plot_data,
                  size=2, color=".3", linewidth=0, order=rank_order)

    # Tweak the visual presentation
    ax.set(ylabel="")
    ax.set(xlabel="")
    ax.set(ylim=(-0.1,1.1))

    # Save figure
    plt.savefig(os.path.join(args.outdir, "boxplot_states_risk.pdf"), bbox_inches="tight")

    # t(4;14) and t(11;14) subtypes by states
    plt.figure(figsize=(12, 2))

    N = len(ranks)
    ind = np.arange(N)    # the x locations for the groups
    w = 0.6
    # Stacked bars in percent: t(11;14) at the bottom, t(4;14) on top of it.
    p1 = plt.bar(ind, 100*ranked_t1114, width=w, color='#0A6ECC', edgecolor="white", alpha=1)
    p2 = plt.bar(ind, 100*ranked_t414,
                 bottom=100*ranked_t1114, width=w, color='#E53939', edgecolor="white", alpha=1)

    plt.xlim(-0.5,N-0.5)
    plt.ylim(-5,110)
    # Tick positions and labels are passed positionally: that form is accepted
    # by old and new versions of pyplot.xticks alike, so the labels are always
    # set and no exception has to be swallowed.
    plt.xticks(range(len(rank_order)), list(rank_order))
    plt.legend((p1[0], p2[0]), ('t(11;14)', 't(4;14)'),loc="upper left")

    plt.savefig(os.path.join(args.outdir,"barplot_states_translocations.pdf"),
                bbox_inches="tight")




