#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from miner2 import coexpression
from miner2 import preprocess

DESCRIPTION = """miner-coexpr - MINER cluster expression data."""

def _save_heatmap(matrix, title, path):
    """Draw `matrix` as a heatmap and save the figure to `path`.

    The image uses the "viridis" colormap with the color scale fixed to
    [-1, 1]; the y axis is labelled "Genes" and the x axis "Samples".

    Args:
        matrix: 2-D data handed to ``plt.imshow`` (here a row selection of
            the expression data).
        title: text shown above the plot.
        path: output file handed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(8, 4))
    plt.imshow(matrix, aspect="auto", cmap="viridis", vmin=-1, vmax=1)
    plt.grid(False)
    # Keyword names must be lowercase: mixed-case spellings such as
    # "FontSize" are not accepted by current matplotlib releases.
    plt.ylabel("Genes", fontsize=20)
    plt.xlabel("Samples", fontsize=20)
    plt.title(title, fontsize=20)
    plt.savefig(path, bbox_inches="tight")
    # Release the figure so repeated plots do not accumulate in memory.
    plt.close(fig)


if __name__ == '__main__':
    # Command line: expression matrix, identifier mapping file and output
    # directory, plus optional tuning parameters for the clustering step.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('outdir', help="output directory")
    parser.add_argument('-mg', '--mingenes', type=int, default=6, help="min number genes")
    parser.add_argument('-moxs', '--minoverexpsamp', type=int, default=4,
                        help="minimum overexpression samples")
    parser.add_argument('-mx', '--maxexclusion', type=float, default=0.5,
                        help="maximum samples excluded")
    # The random state is a seed and therefore parsed as an integer; parsing
    # it as a float turned every user-supplied value into e.g. 12.0.
    parser.add_argument('-rs', '--randstate', type=int, default=12,
                        help="random state")
    parser.add_argument('-oxt', '--overexpthresh', type=int, default=80,
                        help="overexpression threshold")

    args = parser.parse_args()

    # Fail early with a readable message when an input file is missing.
    if not os.path.exists(args.expfile):
        sys.exit("expression file not found")
    if not os.path.exists(args.mapfile):
        sys.exit("identifier mapping file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile)

    # Timing covers clustering, revision, plotting and reporting, but not
    # the preprocessing above.
    t1 = time.time()
    init_clusters = coexpression.cluster(exp_data,
                                         minNumberGenes=args.mingenes,
                                         minNumberOverExpSamples=args.minoverexpsamp,
                                         maxSamplesExcluded=args.maxexclusion,
                                         random_state=args.randstate,
                                         overExpressionThreshold=args.overexpthresh)

    revised_clusters = coexpression.reviseInitialClusters(init_clusters, exp_data)
    with open(os.path.join(args.outdir, "coexpressionDictionary.json"), 'w') as out:
        json.dump(revised_clusters, out)

    # Retrieve the first three clusters for visual inspection.  The revised
    # clusters are looked up by the string keys "0", "1", "2"; keys that do
    # not exist are skipped so that small results do not raise a KeyError.
    preview_keys = [key for key in (str(i) for i in range(3))
                    if key in revised_clusters]

    if preview_keys:
        first_clusters = np.hstack([revised_clusters[key] for key in preview_keys])

        # Visualize background expression: a random selection of genes of
        # the same size as the cluster preview.  The size is capped at the
        # number of available rows because sampling is without replacement.
        n_background = min(len(first_clusters), len(exp_data.index))
        background_genes = np.random.choice(exp_data.index, n_background,
                                            replace=False)
        _save_heatmap(exp_data.loc[background_genes, :],
                      "Random selection of genes",
                      os.path.join(args.outdir, "background_expression.pdf"))

        # Visualize the first 3 clusters.
        _save_heatmap(exp_data.loc[first_clusters, :],
                      "First 3 clusters",
                      os.path.join(args.outdir, "first_clusters.pdf"))
    else:
        print("No clusters available for plotting")

    # Report coverage.  np.hstack cannot handle an empty sequence, so an
    # empty clustering result is reported as zero genes.
    if len(init_clusters) > 0:
        clustered_genes = set(np.hstack(init_clusters))
    else:
        clustered_genes = set()
    print("Number of genes clustered: {:d}".format(len(clustered_genes)))
    print("Number of unique clusters: {:d}".format(len(revised_clusters)))

    t2 = time.time()
    print("Completed clustering module in {:.2f} minutes".format((t2-t1)/60.))

    # Disabled code, kept verbatim as an inert string literal: it is never
    # executed.  It sketches a heatmap of the first 10 clusters and a
    # histogram of cluster sizes ("cluster_size_distribution.pdf").
    # NOTE(review): before re-enabling, the integer cluster keys (range(10))
    # and the mixed-case "FontSize" keywords need to be brought in line with
    # the active code above -- confirm against the current cluster format.
    """

    # visualize first 10 clusters
    plt.figure(figsize=(8,8))
    plt.imshow(exp_data.loc[np.hstack([revised_clusters[i] for i in range(10)]),:],
               aspect="auto", cmap="viridis", vmin=-1, vmax=1)
    plt.grid(False)
    plt.ylabel("Genes", FontSize=20)
    plt.xlabel("Samples", FontSize=20)
    plt.title("First 10 clusters", FontSize=20)

    # report coverage
    #print("Number of genes clustered: {:d}".format(len(set(np.hstack(initialClusters)))))
    #print("Number of unique clusters: {:d}".format(len(revisedClusters)))

    # plot histogram of the cluster size distribution
    counts_ = plt.hist([len(revised_clusters[key]) for key in revised_clusters.keys()],
                       bins=100)
    plt.xlabel("Number of genes in cluster", FontSize=14)
    plt.ylabel("Number of clusters", FontSize=14)
    plt.savefig(os.path.join(args.outdir, "cluster_size_distribution.pdf"),
                bbox_inches="tight")
    """
