#!/usr/bin/env python

import argparse
import json
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot
from graphviz import Source

# xgboost requires a different plotting method
from xgboost import plot_tree, to_graphviz

import pickle
import logging

from miner2 import risk_predict, survival, util
from miner2 import GIT_SHA, __version__ as pkg_version

# Banner used as the argparse description. GIT_SHA may still carry its
# "$Id: ... $" keyword wrapper, which is removed here before display.
DESCRIPTION = """miner-riskclassifier - MINER compute risk classifier.
MINER Version %s (Git SHA %s)""" % (
    pkg_version,
    GIT_SHA.replace('$Id: ', '').replace(' $', ''),
)


def cross_dataset_prediction(membership_datasets, survival_datasets, dataset_labels,
                             outdir, method, verbose):
    """Build a risk predictor across data sets and plot precision vs. recall.

    Delegates training to ``risk_predict.generate_predictor`` and writes a
    precision/recall step plot to ``<outdir>/precision_recall.pdf``.

    Args:
        membership_datasets: passed through to ``generate_predictor``.
        survival_datasets: passed through to ``generate_predictor``.
        dataset_labels: passed through to ``generate_predictor``.
        outdir: directory handed to ``generate_predictor`` as
            ``output_directory`` and used for the plot file.
        method: ``"decisionTree"`` or ``"xgboost"``.
        verbose: when true, progress messages are logged and the flag is
            forwarded to ``generate_predictor``.

    Returns:
        The classifier object returned by ``generate_predictor``.

    Raises:
        ValueError: if ``method`` is not one of the supported predictor types
            (``ValueError`` is an ``Exception`` subclass, so existing broad
            handlers keep working).
    """
    if verbose:
        logging.info("calculating cross data set prediction using method: '%s'", method)

    # Only these settings differ between the two supported predictor types;
    # everything else is shared by the single call below.
    if method == "decisionTree":
        method_options = {'best_state': 11, 'test_only': True}
    elif method == "xgboost":
        method_options = {'best_state': None}
    else:
        raise ValueError("unknown predictor type '%s'" % method)

    # The full 8-tuple is unpacked to document (and check) the return shape,
    # even though only the classifier and the two matrices are used here.
    (classifier, class0, class1, mean_aucs, mean_hrs, pct_labeled,
     precision_matrix, recall_matrix) = risk_predict.generate_predictor(
        membership_datasets, survival_datasets, dataset_labels,
        iterations=30, method=method, output_directory=outdir,
        metric='hazard_ratio', separate_results=True,
        class1_proportion=0.30, test_proportion=0.30,
        verbose=verbose, **method_options)

    if verbose:
        logging.info('plotting precision/recall')

    # NOTE(review): index 2 is hard-coded; presumably it selects one specific
    # entry (e.g. the third data set) of the result matrices - confirm against
    # risk_predict.generate_predictor.
    recall = recall_matrix[2]
    precision = precision_matrix[2]

    # Draw on a dedicated figure so that nothing left on pyplot's implicit
    # current figure ends up in the PDF, and close it so it is not leaked.
    fig, ax = plt.subplots()
    try:
        ax.step(recall, precision, alpha=0.8, where='post')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_ylim([0.0, 1.05])
        ax.set_xlim([0.0, 1.0])
        fig.savefig(os.path.join(outdir, 'precision_recall.pdf'))
    finally:
        plt.close(fig)

    return classifier


def visualize_decision_tree(classifier, outdir):
    """Export a scikit-learn decision tree as Graphviz source and as a PDF.

    Writes ``decision_tree_sklearn.gv`` (DOT source) and
    ``decisionTreeAUC.pdf`` (rendered tree) into ``outdir``.

    Args:
        classifier: tree object accepted by ``sklearn.tree.export_graphviz``.
        outdir: existing directory that receives both output files.

    Returns:
        A ``graphviz.Source`` object built from the DOT text of the graph.
    """
    # With out_file=None export_graphviz returns the DOT text, so the tree is
    # exported only once and the same text feeds both the .gv file and pydot.
    dot_source = tree.export_graphviz(classifier, out_file=None)
    with open(os.path.join(outdir, 'decision_tree_sklearn.gv'), 'w') as outfile:
        outfile.write(dot_source)

    # graph_from_dot_data returns a list of graphs; the tree is the first one.
    graph = pydot.graph_from_dot_data(dot_source)[0]

    # visualize decision tree
    graph.set_graph_defaults(size="\"15,15\"")
    decision_tree = Source(graph.to_string())
    graph.write_pdf(os.path.join(outdir, "decisionTreeAUC.pdf"))
    return decision_tree


def visualize_decision_tree_xgboost(classifier, outdir):
    """Export the first tree of an xgboost model as Graphviz source and PDF.

    Writes ``decision_tree_xgboost.gv`` and ``xgboost_tree_0.pdf`` into
    ``outdir``.

    Args:
        classifier: model accepted by ``xgboost.to_graphviz`` and
            ``xgboost.plot_tree``.
        outdir: existing directory that receives both output files.
    """
    dot = to_graphviz(classifier)
    dot.save(os.path.join(outdir, 'decision_tree_xgboost.gv'))

    # Plot the default tree (index 0) so the picture matches both the file
    # name "xgboost_tree_0.pdf" and the .gv export above; the former
    # num_trees=1 selected the tree at index 1 instead.
    ax = plot_tree(classifier)
    fig = ax.figure
    try:
        fig.set_size_inches(50, 10)
        fig.savefig(os.path.join(outdir, "xgboost_tree_0.pdf"))
    finally:
        # Release the (large) figure once it has been written.
        plt.close(fig)


if __name__ == '__main__':
    # Command-line entry point: read the input specification, train the risk
    # predictor, write the tree visualization and pickle the classifier.
    LOG_FORMAT = '%(asctime)s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
                        datefmt='%Y-%m-%d %H:%M:%S \t')

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('input', help="input specification file")
    parser.add_argument('outdir', help="output directory")
    # Restricting the choices makes argparse reject a misspelled method
    # immediately, before the output directory is created and data is loaded.
    parser.add_argument('--method', default="decisionTree",
                        choices=['decisionTree', 'xgboost'],
                        help="predictor method (decisionTree or xgboost)")
    parser.add_argument('-v', '--verbose', action="store_true", help="activate verbose mode")

    args = parser.parse_args()

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Record dependency information alongside the results.
    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    with open(args.input) as infile:
        input_spec = json.load(infile)

    datasets = risk_predict.read_datasets(input_spec, args.verbose)

    # extract the important info
    membership_datasets = [ds['omm'] for ds in datasets]
    survival_datasets = [ds['gs'] for ds in datasets]
    dataset_labels = [ds['label'] for ds in datasets]

    classifier = cross_dataset_prediction(membership_datasets, survival_datasets, dataset_labels,
                                          args.outdir, args.method, args.verbose)

    if args.verbose:
        logging.info("Generating decision tree visualization")

    # Each predictor type needs its own plotting routine.
    if args.method == 'decisionTree':
        visualize_decision_tree(classifier, args.outdir)
    elif args.method == 'xgboost':
        visualize_decision_tree_xgboost(classifier, args.outdir)

    if args.verbose:
        logging.info("Writing serialized classifier")
    clf_filepath = os.path.join(args.outdir, 'miner_alldata_predictor.pkl')
    with open(clf_filepath, 'wb') as outfile:
        pickle.dump(classifier, outfile)
