#!/usr/bin/env python

import logging
import traceback
import sys
import os
import math
import numpy as np
from scipy import stats

# ribonorma-plot: read a delimited table and draw a volcano plot, a histogram
# or a QQ plot of the fold changes it contains.
#
# Column index 1 of the input table is read as the fold change and column
# index 4 as the p-value. Command-line options are given as "-name value"
# pairs; see the usage text printed by the exception handler at the bottom.


def parse_rows(text, separator):
    """Split the raw text of a table into rows of fields.

    Blank lines (including the empty string left behind by a trailing
    newline) are skipped. Without this, a blank line becomes a one-field row
    and the later ``zip(*data)`` transposition truncates every column to that
    single field, which makes the p-value column lookup fail.

    Args:
        text: Full contents of the input file.
        separator: String separating the fields within one line.

    Returns:
        A list of rows, each row being a list of field strings.
    """
    rows = []
    for line in text.split("\n"):
        # Tolerate Windows line endings without changing the field contents.
        line = line.rstrip("\r")
        if line.strip():
            rows.append(line.split(separator))
    return rows


def keep_non_significant(values, p_values, alpha):
    """Return the entries of ``values`` whose paired p-value is not below alpha.

    Args:
        values: Sequence of values, parallel to ``p_values``.
        p_values: Sequence of p-values, one per entry of ``values``.
        alpha: Significance threshold; entries with ``p < alpha`` are dropped.

    Returns:
        A new list; the inputs are left unmodified.
    """
    return [value for value, p in zip(values, p_values) if not p < alpha]


# The trailing "" guarantees that a flag given as the very last argument
# (e.g. "-debug") still has a "value" to read at position + 1.
arguments = sys.argv + [""]

# Defaults; any "-name value" pair on the command line overrides or extends them.
classed_arguments = {
    "infile_separator": "\t",
    "has_labels": "True",
    "alpha": "0.05",
    "foldchange_log_limits": "1.5",
    "type": "volcano",
}

# Directory containing this script; the help page is looked up relative to it.
dir_path = os.path.dirname(os.path.realpath(__file__))

debug = False

try:
    # Every argument starting with "-" names an option whose value is the
    # argument that follows it.
    for position, argument in enumerate(arguments):
        if argument.startswith("-"):
            classed_arguments[argument[1:]] = arguments[position + 1]

    if "debug" in classed_arguments:
        debug = True

    if "h" in classed_arguments or "help" in classed_arguments:
        with open(dir_path + "/help_pages/ribonorma-plot.txt", "r") as help_file:
            print(help_file.read())
        # sys.exit (rather than os._exit) so that buffered stdout is flushed
        # and the help text is not lost when output is piped. SystemExit is
        # not an Exception subclass, so the handler below does not catch it.
        # The exit status stays 1, as before.
        sys.exit(1)

    has_labels = classed_arguments["has_labels"] == "True"

    with open(classed_arguments["in"]) as infile:
        data = parse_rows(infile.read(), classed_arguments["infile_separator"])

    if has_labels:
        # The first row holds column labels, not data.
        labels = data[0]
        data = data[1:]

    # Plot
    # Transpose rows into columns; zip stops at the shortest row.
    transposed_data = [list(column) for column in zip(*data)]
    p_values = [float(x) for x in transposed_data[4]]
    fold_changes = [float(x) for x in transposed_data[1]]

    # Imported here so that argument and input errors surface before the
    # plotting libraries are loaded.
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set_style("whitegrid", {"axes.grid": False})

    alpha = float(classed_arguments["alpha"])
    foldchange_log_limits = float(classed_arguments["foldchange_log_limits"])

    # Small offset so that a value of exactly 0 does not hit log10(0).
    logged_p = [-math.log10(x + 1e-8) for x in p_values]
    logged_fc = [math.log10(x + 1e-8) for x in fold_changes]

    plot_type = classed_arguments["type"]

    if plot_type == "volcano":

        fig, ax = plt.subplots()

        significance_line = -math.log10(alpha)

        # Points above the significance line are highlighted, the rest are grey.
        colours = np.where(np.array(logged_p) > significance_line, "#004594", "grey")
        plt.scatter(logged_fc, logged_p, s=1, c=colours)

        plt.axhline(significance_line, color="black", linestyle="dashed")
        plt.axvline(0, color="black")

        plt.xlabel("log(fold change)")
        plt.ylabel("-log10(p-value)")

        plt.xlim(-foldchange_log_limits, foldchange_log_limits)

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

    elif plot_type == "histogram":

        # Only rows with p >= alpha contribute to the histogram.
        logged_fc = keep_non_significant(logged_fc, p_values, alpha)

        fig, ax = plt.subplots()

        plt.hist(logged_fc, bins=50, range=(-foldchange_log_limits, foldchange_log_limits), density=True, color="#004594")

        plt.xlabel("log(fold change)")
        plt.ylabel("Probability density")

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # Multiply by 1.25 to extend the estimation
        filtered = [x for x in logged_fc if abs(x) < 1.25 * foldchange_log_limits]
        _, shapiro_p = stats.shapiro(filtered)

        # Overlay a zero-mean normal curve whose SD is estimated from the data.
        x = np.linspace(-foldchange_log_limits, foldchange_log_limits, 100)
        plt.plot(x, stats.norm.pdf(x, 0, np.std(filtered)), color="#FF7D00")

        p_value_text = "Shapiro-Wilk p-value: {}".format("%.3g" % shapiro_p)

        fig.text(0.85, 0.83, p_value_text, verticalalignment="top", horizontalalignment="right")
        fig.legend(["Normal distribution (SD estimated)", "Fold change distribution"], loc=(0.55, 0.85))

    elif plot_type == "QQ":

        # Only rows with p >= alpha contribute to the QQ plot.
        logged_fc = keep_non_significant(logged_fc, p_values, alpha)

        # Multiply by 1.25 to extend the estimation
        filtered = [x for x in logged_fc if abs(x) < 1.25 * foldchange_log_limits]

        # probplot returns the (theoretical, ordered) quantiles and the
        # least-squares fit (slope, intercept, r).
        (theoretical, ordered), (slope, intercept, r) = stats.probplot(filtered)

        fig, ax = plt.subplots()

        plt.scatter(theoretical, ordered, color="#004594", s=2)

        plt.ylabel("Ordered RNA counts")
        plt.xlabel("Theoretical normal quantiles")

        # Draw the fitted line across the range -4..4.
        x = np.linspace(-4, 4, 3)
        y = slope * x + intercept

        plt.plot(x, y, color="#FF7D00")

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # Print the intercept with an explicit sign so the equation reads naturally.
        if intercept < 0:
            text = "y = {}x - {}\nR = {}".format("%.3g" % slope, "%.3g" % abs(intercept), "%.3g" % r)
        else:
            text = "y = {}x + {}\nR = {}".format("%.3g" % slope, "%.3g" % intercept, "%.3g" % r)

        fig.text(0.85, 0.45, text, verticalalignment="top", horizontalalignment="right")

    # Save to the requested file, or open an interactive window when no
    # output file was given.
    if "out" in classed_arguments:
        fig.savefig(classed_arguments["out"])

    else:
        plt.show()

except Exception as exception:

    # Any failure (missing -in, unreadable file, malformed numbers, unknown
    # plot type, ...) ends up here. With -debug the traceback is logged;
    # otherwise the usage summary is printed.
    if debug:
        logging.error(traceback.format_exc())

    else:
        print("USAGE\n  ribonorma-plot [-h(elp)] [-type histogram/volcano/QQ] [-in filename] [-out filename]\n    [-infile_separator \\t] [-has_labels True] [-alpha 0.05]\n    [-foldchange_log_limits 1.5] [-debug]\n\nDESCRIPTION\n  Ribonorma Python v1.1\n\nUse '-help' to print detailed descriptions of command line arguments\n========================================================================")
