#!/usr/bin/env python

import logging
import traceback
import sys
import os
from ribonorma import statistics
from scipy import stats

# Raw command line plus one empty padding token, so that a "-flag" given as
# the very last argument still has a following token to read as its value.
arguments = list(sys.argv)
arguments.append("")

# Option defaults, keyed by option name without the leading dash. All values
# are kept as strings because they are later compared against "True" or used
# directly as separators; command-line values overwrite these entries.
classed_arguments = {
    "infile_separator": "\t",
    "outfile_separator": "\t",
    "phenotype_file_separator": "\n",
    "has_labels": "True",
    "compare": "",
    "plot": "True",
    "levene": "False",
    "ovl": "False",
}

# Directory holding this script (symlinks resolved); bundled resources such
# as the help page are located relative to it.
_script_location = os.path.realpath(__file__)
dir_path = os.path.dirname(_script_location)

# Flipped to True once a "-debug" option has been parsed.
debug = False

def _read_text(path):
    """Return the whole content of the text file at *path*, closing it afterwards."""
    with open(path, "r") as handle:
        return handle.read()


def _group_row_by_phenotype(row, column_groups, group_count):
    """Split one table row into per-phenotype lists of floats.

    row          -- list of strings; row[0] is the row id, row[1:] are the values.
    column_groups -- column_groups[k] is the phenotype group index of value row[k + 1].
    group_count  -- number of distinct phenotype groups.

    Returns a list of ``group_count`` lists holding the row's values as floats.
    Raises ValueError when the row has more values than there are phenotype
    entries, or when a value is not a valid float.
    """
    values = row[1:]
    if len(values) > len(column_groups):
        raise ValueError(
            "row '{}' has {} values but only {} phenotypes were given".format(
                row[0], len(values), len(column_groups)))

    grouped = [[] for _ in range(group_count)]
    for group, value in zip(column_groups, values):
        grouped[group].append(float(value))
    return grouped


try:
    # Every token starting with "-" is an option name and the token after it
    # is its value. Pairing each token with its successor never runs past the
    # end because the argument list is padded with one trailing "".
    for argument, value in zip(arguments, arguments[1:]):
        if argument.startswith("-"):
            classed_arguments[argument[1:]] = value

    if "debug" in classed_arguments:
        debug = True

    if "h" in classed_arguments or "help" in classed_arguments:
        print(_read_text(os.path.join(dir_path, "help_pages", "ribonorma-compare.txt")))
        # sys.exit flushes stdout (os._exit does not, which can drop the help
        # text when output is piped). SystemExit is not an Exception subclass,
        # so the handler below does not intercept it. Exit status stays 1.
        sys.exit(1)

    has_labels = classed_arguments["has_labels"] == "True"

    # One list of cells per non-empty line; skipping empty lines keeps a
    # trailing newline from producing a bogus row with an empty id.
    infile_separator = classed_arguments["infile_separator"]
    data = [line.split(infile_separator)
            for line in _read_text(classed_arguments["in"]).split("\n")
            if line]

    if has_labels:
        # The first line is a header and carries no data.
        data = data[1:]

    # The first cell of every row is the row (gene) identifier.
    ids = [row[0] for row in data]

    # One phenotype label per sample column, in column order.
    phenotypes = [x
                  for x in _read_text(classed_arguments["phenotype_file"]).split(classed_arguments["phenotype_file_separator"])
                  if x]
    unique_phenotypes = sorted(set(phenotypes))
    compare_phenotypes = classed_arguments["compare"].split(",")

    genes = [x for x in classed_arguments["genes"].split(",") if x]
    if not genes:
        raise ValueError("no gene ids were given via -genes")

    # list.index raises ValueError for an id that is not in the table.
    indices = [ids.index(x) for x in genes]

    if not compare_phenotypes or len(phenotypes) < len(compare_phenotypes):
        raise ValueError("fewer phenotype entries than phenotypes to compare")

    # Group by phenotypes. Only the requested rows are converted, and the
    # label -> group lookup is built once instead of searching per cell.
    group_of = {name: position for position, name in enumerate(unique_phenotypes)}
    column_groups = [group_of[name] for name in phenotypes]
    samples = [_group_row_by_phenotype(data[x], column_groups, len(unique_phenotypes))
               for x in indices]

    # Perform tests
    transposed_phenos = [list(x) for x in zip(*samples)]

    output = statistics.ANOVA(transposed_phenos, genes)

    plot = classed_arguments["plot"] == "True"

    if not plot:

        import math

        # Rows of the result table; each row is a list of strings.
        writable = []

        # Indexed access on purpose: only len() and [] are relied upon for the
        # object returned by statistics.ANOVA.
        for i in range(len(output)):

            current_datapoint = output[i]
            fold_change = current_datapoint["fold-change"]
            # log10 is undefined for zero and negative values; report nan there
            # instead of aborting the whole table.
            negative_log = -math.log10(fold_change) if fold_change > 0 else math.nan

            writable.append([
                current_datapoint["gene"],
                str(fold_change),
                str(negative_log),
                str(current_datapoint["statistic"]),
                str(current_datapoint["p-value"]),
            ])

        out = "\n".join([classed_arguments["outfile_separator"].join(x) for x in writable])

        if "out" in classed_arguments:
            with open(classed_arguments["out"], "w") as handle:
                handle.write(out)

        else:
            print(out)

    else:

        import matplotlib.pyplot as plt
        import seaborn as sns
        sns.set_style("whitegrid", {"axes.grid": False})

        # The optional extras do not depend on the gene being plotted.
        test_levene = classed_arguments["levene"] == "True"
        calculate_ovl = classed_arguments["ovl"] == "True"

        # One figure per result; samples[i] and genes[i] belong to output[i].
        for i in range(len(output)):

            # Extra lines appended below the p-value annotation.
            addendum = ""

            if calculate_ovl:
                ovl = statistics.calculateOVL(samples[i])
                addendum += "\nOVL: %.3g" % ovl

            if test_levene:
                stat, p = stats.levene(*samples[i], center="mean")
                addendum += "\nLevene's p-value: %.3g" % p

            current_datapoint = output[i]
            fig = plt.figure()

            # Violin per phenotype group with the individual points on top.
            ax = sns.violinplot(data=samples[i], color="#004594")
            sns.stripplot(data=samples[i], color="#FF7D00", jitter=True)

            ax.spines["right"].set_visible(False)
            ax.spines["top"].set_visible(False)

            plt.title(genes[i])
            plt.ylim(bottom=0)

            plt.ylabel("Counts")
            ax.set_xticklabels(unique_phenotypes)

            # NOTE(review): the label says Kruskal-Wallis while the value comes
            # from statistics.ANOVA -- confirm which test that function runs.
            string = "Kruskal-Wallis' p-value = %.3g" % current_datapoint["p-value"] + addendum
            fig.text(0.93, 0.98, string, verticalalignment="top", horizontalalignment="right")

            plt.show()


except Exception as exception:

    # Any failure (missing option, unreadable file, malformed table, ...) ends
    # here: the traceback is logged in debug mode, the usage banner otherwise.
    if debug:
        logging.error(traceback.format_exc())

    else:
        print("USAGE\n  ribonorma-compare [-h(elp)] [-in filename] [-out filename]\n    [-phenotype_file filename] [-infile_separator \\t] [-outfile_separator \\t]\n    [-phenotype_file_separator \\n] [-plot True] [-genes ids] [-has_labels True]\n    [-levene False] [-ovl False] [-debug]\n\nDESCRIPTION\n  Ribonorma Python v1.1\n\nUse '-help' to print detailed descriptions of command line arguments\n========================================================================")
