#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from miner2 import miner

DESCRIPTION = """miner-subtypes - MINER compute sample subtypes"""
# Only used to build the file name of the final program-activity tSNE plot.
MIN_CORRELATION = 0.2

# Command-line driver: reads an expression matrix, an identifier mapping file
# and a regulons.json file, then writes subtype clustering results (CSV tables
# and PDF plots) into the output directory.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('regulons', help="regulons.json file from miner-mechinf")
    parser.add_argument('outdir', help="output directory")
    args = parser.parse_args()

    if not os.path.exists(args.regulons):
        sys.exit("regulons file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # the identifier conversion table returned by preprocess is not needed here
    exp_data, _ = miner.preprocess(args.expfile, args.mapfile)
    bkgd = miner.backgroundDf(exp_data)

    with open(args.regulons) as infile:
        regulon_modules = json.load(infile)

    overexpressed_members = miner.biclusterMembershipDictionary(regulon_modules, bkgd,
                                                                label=2, p=0.05)
    overexpressed_members_matrix = miner.membershipToIncidence(overexpressed_members,
                                                               exp_data)
    underexpressed_members = miner.biclusterMembershipDictionary(regulon_modules, bkgd,
                                                                 label=0, p=0.05)
    underexpressed_members_matrix = miner.membershipToIncidence(underexpressed_members,
                                                                exp_data)

    # subtypes are derived from the over-expressed membership only; the
    # under-expressed matrix is used further down for the differential matrix
    sample_dictionary = overexpressed_members
    sample_matrix = overexpressed_members_matrix

    # perform initial subtype clustering, keeping clusters with more than 4 members
    similarity_clusters = miner.f1Decomposition(sample_dictionary, thresholdSFM=0.1)
    initial_classes = [i for i in similarity_clusters if len(i) > 4]
    if not initial_classes:
        # np.hstack() below would otherwise fail with an opaque ValueError
        sys.exit("no initial sample clusters with more than 4 members were found")

    # visualize initial results
    sample_freq_matrix = miner.sampleCoincidenceMatrix(sample_dictionary,
                                                       freqThreshold=0.333,
                                                       frequencies=True)
    similarity_matrix = sample_freq_matrix * sample_freq_matrix.T
    miner.plotSimilarity(similarity_matrix,
                         np.hstack(initial_classes), vmin=0, vmax=0.5,
                         title="Similarity matrix", xlabel="Samples",
                         ylabel="Samples", fontsize=14, figsize=(7, 7),
                         savefig=os.path.join(args.outdir,
                                              "similarityMatrix_regulons.pdf"))

    # expand initial subtype clusters
    centroid_clusters, centroid_matrix = miner.centroidExpansion(initial_classes,
                                                                 sample_matrix,
                                                                 f1Threshold=0.1,
                                                                 returnCentroids=True)

    centroid_matrix.to_csv(os.path.join(args.outdir, "centroids.csv"))
    # samples whose column in the membership matrix sums to zero
    unmapped = list(sample_matrix.columns[np.where(sample_matrix.sum(axis=0) == 0)[0]])
    mapped_samples = [i for i in np.hstack(centroid_clusters) if i not in unmapped]
    mapped_clusters = miner.mapExpressionToNetwork(centroid_matrix, sample_matrix,
                                                   threshold=0.05)

    # visualize expanded subtype clusters
    miner.plotSimilarity(similarity_matrix, mapped_samples,
                         vmin=0, vmax=0.333, title="Similarity matrix",
                         xlabel="Samples", ylabel="Samples", fontsize=14,
                         figsize=(7, 7),
                         savefig=os.path.join(args.outdir, "centroidClusters_regulons.pdf"))

    # Generate heatmaps of module activity
    ordered_overexpressed_members = miner.orderMembership(centroid_matrix,
                                                          sample_matrix,
                                                          mapped_clusters,
                                                          ylabel="Modules",
                                                          resultsDirectory=args.outdir)

    ordered_dm = miner.plotDifferentialMatrix(sample_matrix,
                                              underexpressed_members_matrix,
                                              ordered_overexpressed_members,
                                              cmap="bwr", aspect="auto",
                                              saveFile=os.path.join(args.outdir,
                                                                    "centroid_clusters_heatmap.pdf"))

    # Infer transcriptional programs and states
    programs, states = miner.mosaic(dfr=ordered_dm, clusterList=centroid_clusters,
                                    minClusterSize_x=9, minClusterSize_y=5,
                                    allow_singletons=False, max_groups=50,
                                    saveFile=os.path.join(args.outdir,
                                                          "regulon_activity_heatmap.pdf"),
                                    random_state=12)

    # transcriptionalPrograms uses a feature of pandas that does not always
    # exist (pandas.core.indexes), my solution was to patch the problem
    # with an exception handler
    transcriptional_programs, program_regulons = miner.transcriptionalPrograms(programs,
                                                                               regulon_modules)
    # program_regulons is keyed "TP0", "TP1", ...; collect the entries in that order
    program_list = [program_regulons["TP" + str(i)]
                    for i in range(len(program_regulons))]

    mosaic_df = ordered_dm.loc[np.hstack(program_list), np.hstack(states)]
    mosaic_df.to_csv(os.path.join(args.outdir, "regulons_activity_heatmap.csv"))

    # Get eigengenes for all modules
    eigengenes = miner.getEigengenes(regulon_modules, exp_data,
                                     regulon_dict=None, saveFolder=None)

    # write eigengenes to .csv
    eigengenes.to_csv(os.path.join(args.outdir, "eigengenes.csv"))

    # plot eigengenes
    # NOTE(review): this figure is never written to a file by this script
    plt.figure()
    plt.imshow(eigengenes.loc[np.hstack(program_list), np.hstack(states)],
               cmap="viridis", vmin=-0.05, vmax=0.05, aspect="auto")
    plt.grid(False)

    # calculate percent of samples that fall into a state with >= minimum acceptable
    # number of samples (1% of all samples, rounded up)
    min_state_size = int(np.ceil(0.01 * exp_data.shape[1]))
    groups = [state for state in states if len(state) >= min_state_size]
    # summing the lengths also works when no state is large enough, where
    # np.hstack([]) would raise a ValueError
    covered_samples = sum(len(state) for state in groups)

    print('Discovered {:d} transcriptional states and {:d} transcriptional programs'.format(
        len(states), len(transcriptional_programs)))
    print('sample coverage within sufficiently large states: {:.1f}%'.format(
        100 * float(covered_samples) / exp_data.shape[1]))

    # write all transcriptional program genesets to text files for external analysis
    programs_dir = os.path.join(args.outdir, "transcriptional_programs_coexpressionModules")
    if not os.path.isdir(programs_dir):
        os.mkdir(programs_dir)

    for tp in transcriptional_programs:
        np.savetxt(os.path.join(programs_dir, tp + ".txt"),
                   transcriptional_programs[tp], fmt="%1.50s")

    # Determine activity of transcriptional programs in each sample
    states_df = miner.reduceModules(df=ordered_dm, programs=program_list, states=states,
                                    stateThreshold=0.65,
                                    saveFile=os.path.join(args.outdir,
                                                          "transcriptional_programs_vs_samples.pdf"))

    # Cluster patients into subtypes and give the activity of each program in each subtype
    miner.programsVsStates(states_df, states,
                           filename=os.path.join(args.outdir,
                                                 "programs_vs_states.pdf"))

    # Visualize with tSNE
    # Start from a fresh figure: without it the eigengene heatmap above is still
    # the current pyplot figure when the scatter is drawn and saved.
    # (assumes miner.tsne draws on the current figure when plotOnly=True -
    # TODO confirm; the extra figure is harmless otherwise)
    plt.figure()
    miner.tsne(exp_data, perplexity=15, n_components=2, n_iter=1000,
               plotOnly=True, plotColor="red", alpha=0.4)
    plt.savefig(os.path.join(args.outdir, "tsne_gene_expression.pdf"),
                bbox_inches="tight")

    # tSNE applied to df_for_tsne. Consider changing the perplexity in the
    # range of 5 to 50
    df_for_tsne = mosaic_df.copy()
    plt.figure()
    x_embedded = miner.tsne(df_for_tsne, perplexity=30, n_components=2,
                            n_iter=1000, plotOnly=None, plotColor="blue", alpha=0.2)
    tsne_df = pd.DataFrame(x_embedded)
    tsne_df.index = df_for_tsne.columns
    tsne_df.columns = ["tsne1", "tsne2"]
    plt.savefig(os.path.join(args.outdir, "tsne_regulon_activity.pdf"),
                bbox_inches="tight")

    # Disabled example (kept as an inert string): overlay the expression of a
    # single gene on the tSNE embedding.
    """
    nsd2, maf, ccnd1, csk1b, ikzf1, ikzf3, tp53, e2f1, ets1, phf19 = ["ENSG00000109685","ENSG00000178573","ENSG00000110092","ENSG00000173207","ENSG00000185811","ENSG00000161405","ENSG00000141510","ENSG00000101412","ENSG00000134954","ENSG00000119403"]

    target = tp53
    target_expression = np.array(exp_data.loc[target, tsne_df.index])

    plt.scatter(tsne_df.iloc[:,0], tsne_df.iloc[:,1], cmap="bwr", c=target_expression, alpha=0.65)
    plt.savefig(os.path.join(args.outdir,("").join(["labeled_tsne_", target,
                                                    "_overlay_o2.pdf"])),bbox_inches="tight")
    """

    # How many clusters do you expect ? Start with number of states
    num_clusters = len(states)

    # Are the clusters separated how you thought? If not, change the random_state
    # to a different number and retry
    random_state = 12

    # only the per-sample labels are needed for the overlay plot
    _, labels, _ = miner.kmeans(tsne_df, numClusters=num_clusters,
                                random_state=random_state)

    # WW: Vega20 instead of tab20 because it does not exist at my system.
    # The name is chosen once from the registered colormaps rather than with a
    # bare "except:", which also hid unrelated errors and KeyboardInterrupt.
    label_cmap = "tab20" if "tab20" in plt.colormaps() else "Vega20"

    # overlay cluster labels
    plt.figure()
    plt.scatter(tsne_df.iloc[:, 0], tsne_df.iloc[:, 1], cmap=label_cmap,
                c=labels, alpha=0.65)
    plt.savefig(os.path.join(args.outdir, "labeled_tsne_kmeans.pdf"),
                bbox_inches="tight")

    # convert states to tsne labels
    state_labels = miner.tsneStateLabels(tsne_df, states)

    # overlay states cluster labels
    plt.figure()
    plt.scatter(tsne_df.iloc[:, 0], tsne_df.iloc[:, 1], cmap=label_cmap,
                c=state_labels, alpha=0.65)
    plt.savefig(os.path.join(args.outdir, "labeled_tsne_states.pdf"),
                bbox_inches="tight")

    # overlay activity of transcriptional programs
    # the file name embeds the fractional digits of MIN_CORRELATION, e.g. 0.2 -> "0o2"
    correlation_tag = "0o" + str(MIN_CORRELATION).split(".")[1]
    miner.plotStates(states_df, tsne_df,
                     numCols=int(np.sqrt(states_df.shape[0])),
                     saveFile=os.path.join(args.outdir,
                                           "_".join(["states_regulons",
                                                     correlation_tag,
                                                     "tsne.pdf"])),
                     aspect=1, size=10, scale=3)
