#!python3
import time
import sys
import os
from collections.abc import Iterable
from importlib.metadata import version
import warnings

import pycurl
import numpy as np
from numpy.random import default_rng
import seaborn as sns
import pandas as pd
import big_o as bo

from gues import grues
from gues.sample import sample_causal_dag


def complexity(rng, path, how):
    """Estimate GrUES run-time complexity and save the result as a CSV table.

    One data set is sampled and then reused for three `big_o` timing
    experiments, each varying a single parameter (number of nodes, sample
    size, Markov chain length) while the others stay fixed.

    Parameters
    ----------
    rng : numpy.random.Generator
        Random generator used for sampling the data set and for the MCMC
        runs, so that the experiment inputs follow the caller's seed.
    path : str
        Output directory prefix (expected to end with a path separator);
        the table is written to ``f"{path}complexity_table.csv"``.
    how : str
        Accepted for a uniform interface with the other reproduction
        functions; not used here.
    """
    # the results seem relatively consistent across different random seeds

    # create one data set to run all tests on
    _, dataset = sample_causal_dag(
        num_nodes=50, edge_prob=0.5, samp_size=100000, rng=rng
    )

    # define functions to call GrUES while varying only the specified parameter
    def vary_num_nodes(num_nodes):  # n
        obj = grues.InputData(dataset[:1000, :num_nodes], rng)
        obj.mcmc(max_moves=10)

    def vary_samp_size(sample_size):  # s
        obj = grues.InputData(dataset[:sample_size, :10], rng)
        obj.mcmc(max_moves=10)

    def vary_mc_len(mc_len):  # l
        obj = grues.InputData(dataset[:1000, :10], rng)
        obj.mcmc(max_moves=mc_len)

    # run the experiments; only the best-fitting complexity class is kept
    n_best, _ = bo.big_o(vary_num_nodes, bo.datagen.n_, min_n=5, max_n=50)
    s_best, _ = bo.big_o(vary_samp_size, bo.datagen.n_, min_n=1000, max_n=100000)
    l_best, _ = bo.big_o(vary_mc_len, bo.datagen.n_, min_n=5, max_n=1000)
    # the complexity estimates hold when Markov chain length is
    # increased to 1,000,000, so it's set to only 1,000 here to
    # greatly reduce experiment run time;

    # make and save table (header row followed by one row per parameter)
    table = np.array(
        (
            ("varied parameter", "estimated complexity"),
            ("number of nodes", f"{n_best}"),
            ("sample size", f"{s_best}"),
            ("Markov chain length", f"{l_best}"),
        )
    )
    # "%.40s" truncates every cell to at most 40 characters
    np.savetxt(f"{path}complexity_table.csv", table, fmt="%.40s", delimiter=",")


def hist(rng, path, how):
    """Produce the histogram of visited graphs for a 3-variable example.

    Parameters
    ----------
    rng : numpy.random.Generator
        Parent generator; a fresh generator is derived from it when the
        Markov chain is computed from scratch.
    path : str
        Output directory prefix (expected to end with a path separator).
        The figure is saved to ``path + "hist.png"``.
    how : str
        ``"fresh"`` recomputes the Markov chain, ``"download"`` fetches the
        stored chain, and any other value reuses the chain file already on
        disk under ``path + "hist_results/"``.
    """
    subdir = "hist_results/"
    chain_file = path + subdir + "n3_markov_chain.npy"

    if how == "fresh":
        rng = default_rng(rng.integers(0, 10**10, 7))  # set new random seed
        precision = np.eye(3, dtype=float)
        precision[0, 2] = 0.9247
        precision = precision.dot(precision.T)
        cov = np.linalg.inv(precision)
        sample = rng.multivariate_normal(np.zeros(3), cov, 1000)

        model = grues.InputData(sample, rng)
        model.mcmc(max_moves=10000, prior=(2, 3))

        # only the inverse index (which unique graph each step visited) is
        # needed; the other outputs are discarded
        _, inv = np.unique(
            model.uec_markov_chain[-1000:],
            return_inverse=True,
            axis=0,
        )
        # raw strings keep the LaTeX backslashes without relying on
        # invalid escape sequences
        graph_names = np.array(
            [
                r"$\mathcal{U}_1$",
                r"$\mathcal{U}_4$",
                r"$\mathcal{U}_2$",
                r"$\mathcal{U}_6$",
                r"$\mathcal{U}_3$",
                r"$\mathcal{U}_7$",
                r"$\mathcal{U}_5$",
                r"$\mathcal{U}_8$",
            ]
        )
        # flatten defensively: some NumPy releases return a
        # multi-dimensional inverse when `axis` is given
        interpretable_markov_chain = graph_names[np.ravel(inv)]

        os.makedirs(os.path.dirname(path + subdir), exist_ok=True)
        np.save(chain_file, interpretable_markov_chain)
    elif how == "download":
        download(path, subdir, ("n3_markov_chain.npy",))

    interpretable_markov_chain = np.load(chain_file)
    sns.set(font_scale=2)
    ax = sns.histplot(np.sort(interpretable_markov_chain), stat="probability")
    ax.figure.tight_layout()
    ax.figure.savefig(path + "hist.png")
    ax.figure.clf()


def _density_sweep(
    rng, path, how, subdir, num_runs_per, num_nodes, samp_size, mc_len, prior
):
    """Run (or fetch) one sweep over edge densities and plot its results.

    Shared implementation behind `uni`, `tri_n`, `n_8` and `n_10`, which
    differ only in the experiment parameters they pass in.

    Parameters
    ----------
    rng : numpy.random.Generator
        Generator handed on to `run_experiments`.
    path : str
        Output directory prefix (expected to end with a path separator).
    how : str
        ``"fresh"`` runs the experiments, ``"download"`` fetches stored
        results, any other value only redoes the plots from files on disk.
    subdir : str
        Results sub-directory (relative to `path`, with trailing separator).
    num_runs_per : int
        Number of runs per density.
    num_nodes, samp_size, mc_len : int
        Number of nodes, sample size and Markov chain length per run.
    prior : int or None
        Prior setting forwarded unchanged to `run_experiments`.
    """
    densities = np.round(np.arange(0.1, 1, 0.1), 2)
    results_dir = path + subdir

    def archive_name(density):
        # file name of the aggregated results for one density
        return f"n={num_nodes}_d={density}_ss={samp_size}_mc={mc_len}.npz"

    # run experiments, unless called to only redo plots
    if how == "fresh":
        os.makedirs(os.path.dirname(results_dir), exist_ok=True)
        for density in densities:
            print(f"Running density={density}")
            run_experiments(
                results_dir,
                num_runs_per,
                num_nodes,
                density,
                samp_size,
                mc_len,
                prior,
                rng,
            )
    elif how == "download":
        download(path, subdir, (archive_name(p) for p in densities))

    # load all results
    results = {p: np.load(results_dir + archive_name(p)) for p in densities}

    # make plots
    make_plots(
        path,
        results_dir,
        results,
        num_runs_per,
        densities,
        num_nodes,
        mc_len,
        incl_map=True,
    )


def uni(rng, path, how):
    """Density sweep on 5 nodes with no prior (`prior=None`)."""
    _density_sweep(
        rng,
        path,
        how,
        subdir="uni_results/",
        num_runs_per=100,
        num_nodes=5,
        samp_size=1000,
        mc_len=10000,
        prior=None,
    )


def tri_n(rng, path, how):
    """Density sweep on 5 nodes with the prior set to the number of nodes."""
    _density_sweep(
        rng,
        path,
        how,
        subdir="tri_n_results/",
        num_runs_per=100,
        num_nodes=5,
        samp_size=1000,
        mc_len=10000,
        prior=5,
    )


def n_8(rng, path, how):
    """Density sweep on 8 nodes (5 runs per density, chain length 1,000,000)."""
    _density_sweep(
        rng,
        path,
        how,
        subdir="n_8_results/",
        num_runs_per=5,
        num_nodes=8,
        samp_size=1000,
        mc_len=1000000,
        prior=8,
    )


def n_10(rng, path, how):
    """Density sweep on 10 nodes (50 runs per density, sample size 10,000)."""
    _density_sweep(
        rng,
        path,
        how,
        subdir="n_10_results/",
        num_runs_per=50,
        num_nodes=10,
        samp_size=10000,
        mc_len=100000,
        prior=10,
    )


# functions to actually run (sets of) experiments and make plots


def experiment(num_nodes, edge_prob, samp_size, mc_len, prior, rng, name):
    """Sample one causal model, run GrUES on its data, and save the outcome.

    Parameters
    ----------
    num_nodes, edge_prob, samp_size
        Forwarded to `sample_causal_dag` to generate the true DAG, its UEC
        and a data set.
    mc_len : int
        Maximum number of MCMC moves (``max_moves``).
    prior : int, tuple or None
        If an int, it is expanded to ``(num_sources, prior)``, where
        ``num_sources`` is the number of all-zero columns of the true DAG;
        any other value is passed to ``mcmc`` unchanged.
    rng : numpy.random.Generator
        Generator used for sampling and for the timed MCMC run.
    name : str
        Path prefix of the output archive; ``".npz"`` is appended.

    Returns
    -------
    tuple
        The values of the saved archive, in this order: runtime, true_dag,
        true_uec, true_bic, true_nuc, it_uec, map_uec, bic_uec, nuc_uec,
        best_bic, best_nuc, bic_mc, nuc_mc, uec_mc, dataset.
    """
    true_uec, true_dag, dataset = sample_causal_dag(
        num_nodes,
        edge_prob,
        weight_interv=(0, 1),
        samp_size=samp_size,
        return_uec=True,
        rng=rng,
    )

    if isinstance(prior, int):
        num_sources = (true_dag.sum(0) == 0).sum()
        prior = (num_sources, prior)

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a long run
    start = time.perf_counter()
    obj = grues.InputData(dataset, rng)
    obj.mcmc(
        init=("gauss", 0.05),
        max_moves=mc_len,
        prior=prior,
    )
    runtime = time.perf_counter() - start

    # a single move started from the true UEC, to obtain the scores
    # (`best_bic`, `best_nuc`) recorded as the true model's reference values
    true = grues.InputData(dataset)
    true.mcmc(init=true_uec, max_moves=1)

    # insertion order matters: the returned tuple follows it, and
    # `run_experiments` unpacks the tuple positionally
    results = {
        "runtime": runtime,
        "true_dag": true_dag.astype(int),
        "true_uec": true_uec.astype(int),
        "true_bic": true.best_bic,
        "true_nuc": true.best_nuc,
        "it_uec": obj.indep_test_uec,
        "map_uec": obj.map_uec,
        "bic_uec": obj.bic_uec,
        "nuc_uec": obj.nuc_uec,
        "best_bic": obj.best_bic,
        "best_nuc": obj.best_nuc,
        "bic_mc": obj.bic_markov_chain,
        "nuc_mc": obj.nuc_markov_chain,
        "uec_mc": obj.uec_markov_chain,
        "dataset": dataset,
    }
    np.savez_compressed(name + ".npz", **results)

    return tuple(results.values())


def run_experiments(path, num_runs, num_nodes, density, samp_size, mc_len, prior, rng):
    """Repeat `experiment` `num_runs` times and save the stacked results.

    Each run writes its own archive (``f"{path}r{run}_d={density}.npz"``,
    via `experiment`); afterwards all runs are stored together in
    ``f"{path}n={num_nodes}_d={density}_ss={samp_size}_mc={mc_len}.npz"``.
    """
    square = (num_runs, num_nodes, num_nodes)
    chain = (num_runs, mc_len)

    # pre-allocated per-run storage; the key order mirrors the order of the
    # tuple returned by `experiment`
    results = {
        "runtimes": np.empty(num_runs, float),
        "true_dags": np.empty(square, int),
        "true_uecs": np.empty(square, int),
        "true_bics": np.empty(num_runs, float),
        "true_nucs": np.empty(num_runs, float),
        "it_uecs": np.empty(square, int),
        "map_uecs": np.empty(square, int),
        "bic_uecs": np.empty(square, int),
        "nuc_uecs": np.empty(square, int),
        "best_bics": np.empty(num_runs, float),
        "best_nucs": np.empty(num_runs, float),
        "bic_mcs": np.empty(chain, float),
        "nuc_mcs": np.empty(chain, float),
        "uec_mcs": np.empty(chain + (num_nodes, num_nodes), int),
        "datasets": np.empty((num_runs, samp_size, num_nodes), float),
    }
    slots = tuple(results.values())

    for run in range(num_runs):
        run_name = f"{path}r{run}_d={density}"
        print(run + 1, " out of ", num_runs)
        outcome = experiment(
            num_nodes, density, samp_size, mc_len, prior, rng, run_name
        )
        # store each returned value in row `run` of its matching array
        for slot, value in zip(slots, outcome):
            slot[run] = value

    archive = f"{path}n={num_nodes}_d={density}_ss={samp_size}_mc={mc_len}.npz"
    np.savez_compressed(archive, **results)


def download(path, subdir, end_paths: Iterable[str]):
    """Download stored intermediate results into ``path + subdir``.

    Parameters
    ----------
    path : str
        Local output directory prefix (expected to end with a separator).
    subdir : str
        Sub-directory, used both locally and as part of the remote URL.
    end_paths : iterable of str
        File names to fetch; each is read from
        ``base_url + subdir + end_path`` and written to
        ``path + subdir + end_path``.

    Raises
    ------
    pycurl.error
        If a transfer fails, including when the server answers with an
        HTTP error status.
    """
    print("Downloading intermediate results...")
    base_url = (
        "https://codeberg.org/alex-markham/reproduced_astat_results/raw/branch/main/"
    )
    # the target directory is the same for every file, so create it once
    os.makedirs(os.path.dirname(path + subdir), exist_ok=True)
    for end_path in end_paths:
        with open(path + subdir + end_path, "wb") as f:
            c = pycurl.Curl()
            try:
                c.setopt(c.URL, base_url + subdir + end_path)
                # fail on HTTP errors instead of saving the error page as if
                # it were a results file
                c.setopt(c.FAILONERROR, True)
                c.setopt(c.WRITEDATA, f)
                c.perform()
            finally:
                # release the handle even when the transfer raises
                c.close()


def make_plots(
    path, subpath, results, num_runs_per, densities, num_nodes, mc_len, incl_map
):
    """Summarise experiment results and write the plot data, table and figures.

    Parameters
    ----------
    path : str
        Unused; kept for interface compatibility with existing callers.
    subpath : str
        Results directory (with trailing separator). Plot data is saved
        inside it; the table and figures are saved next to it, under the
        prefix ``subpath[:-9]``.
    results : mapping
        Maps each density to the loaded results for that density, providing
        the entries "true_uecs", "it_uecs", "bic_uecs", "nuc_uecs" and, when
        `incl_map` is true, "map_uecs", "bic_mcs" and "true_bics".
    num_runs_per : int
        Number of runs per density.
    densities : numpy.ndarray
        The densities that were evaluated.
    num_nodes : int
        Number of nodes in each graph.
    mc_len : int
        Markov chain length, used to normalise visit counts.
    incl_map : bool
        Whether to include the MAP estimate and the HPD credible sets.
    """
    num_densities = len(densities)
    # proportion of runs in which each estimate equals the true UEC
    correct_bics = np.empty_like(densities, float)
    correct_nucs = np.empty_like(densities, float)
    correct_its = np.empty_like(densities, float)
    if incl_map:
        correct_maps = np.empty_like(densities, float)
        # look at 9 different sizes of HPD credible sets: (0.1, ..., 0.9)
        interval_maps = np.empty((9, num_densities), float)
        interval_sizes = np.empty((9, num_densities), float)

    # structural accuracy of each estimate relative to the true UEC
    sim_T_B = np.empty_like(densities, float)
    sim_T_N = np.empty_like(densities, float)
    sim_T_I = np.empty_like(densities, float)
    if incl_map:
        sim_T_M = np.empty_like(densities, float)

    # for Fig 10(b): number of off-diagonal entries of an adjacency matrix
    norm = num_nodes * (num_nodes - 1)

    # fill in arrays for correct and accuracy plots
    for idx, density in enumerate(densities):
        res_dens = results[density]

        # fetch each entry once; indexing a loaded .npz archive re-reads the
        # array from disk on every access
        true_uecs = res_dens["true_uecs"]
        agree_it = true_uecs == res_dens["it_uecs"]
        agree_bic = true_uecs == res_dens["bic_uecs"]
        agree_nuc = true_uecs == res_dens["nuc_uecs"]

        # blue, orange, green in Fig 10(a)
        correct_its[idx] = agree_it.all((1, 2)).mean()
        correct_bics[idx] = agree_bic.all((1, 2)).mean()
        correct_nucs[idx] = agree_nuc.all((1, 2)).mean()

        # blue, orange, green in Fig 10(b); stored at `idx` so the values
        # line up with `densities` exactly like the Fig 10(a) arrays do
        sim_T_I[idx] = ((agree_it.sum((1, 2)) - num_nodes) / norm).mean()
        sim_T_B[idx] = ((agree_bic.sum((1, 2)) - num_nodes) / norm).mean()
        sim_T_N[idx] = ((agree_nuc.sum((1, 2)) - num_nodes) / norm).mean()

        if incl_map:
            agree_map = true_uecs == res_dens["map_uecs"]
            true_bics = res_dens["true_bics"]

            # per-run membership in, and size of, the 9 HPD credible sets
            in_interval = np.empty((9, num_runs_per), bool)
            interval_size = np.empty((9, num_runs_per), int)

            # fill in HPD credible sets (purple and brown in Fig 10(a))
            for jdx, bic_chain in enumerate(res_dens["bic_mcs"]):
                # posterior frequency of every distinct value in the chain
                uniques, counts = np.unique(bic_chain, return_counts=True)
                true_rss = true_bics[jdx]

                # most frequently visited values first
                order = np.argsort(counts)[::-1]
                ranked_uniques = uniques[order]
                cdf = np.cumsum(counts[order]) / mc_len

                for kdx in range(9):
                    interval = (kdx + 1) / 10
                    interval_mask = cdf <= interval
                    interval_size[kdx, jdx] = interval_mask.sum()
                    in_interval[kdx, jdx] = true_rss in ranked_uniques[interval_mask]

            # red in Fig 10(a)
            correct_maps[idx] = agree_map.all((1, 2)).mean()

            # for purple and brown in Fig 10(a)
            interval_maps[:, idx] = in_interval.mean(1)

            # for Table 2
            interval_sizes[:, idx] = interval_size.mean(1)

            # red in Fig 10(b)
            sim_T_M[idx] = ((agree_map.sum((1, 2)) - num_nodes) / norm).mean()

    # save data for Fig 10(a); raw strings keep the LaTeX backslashes
    # without relying on invalid escape sequences
    estimates = np.array(
        (
            r"$\hat\mathcal{U}_{\mathrm{IT}}$",
            r"$\hat\mathcal{U}_{\ell_0}$",
            r"$\hat\mathcal{U}_{\mathrm{nuc}}$",
        )
    )
    prop_corr = np.hstack((correct_its, correct_bics, correct_nucs))
    if incl_map:
        map_estimates = np.array(
            (
                r"$\hat\mathcal{U}_{\mathrm{MAP}}$",
                r"$\hat\mathcal{U}_{[0.1]}$",
                r"$\hat\mathcal{U}_{[0.2]}$",
            )
        )
        estimates = np.hstack((estimates, map_estimates))
        map_prop = np.hstack(
            (
                correct_maps,
                interval_maps[0],
                interval_maps[1],
            )
        )
        prop_corr = np.hstack((prop_corr, map_prop))

    plot_data_correct = {
        "density": np.tile(densities, len(estimates)),
        "estimate": np.repeat(
            estimates,
            num_densities,
        ),
        "proportion correct": prop_corr,
    }
    if incl_map:
        plot_data_correct["interval_maps"] = interval_maps
        plot_data_correct["interval_sizes"] = interval_sizes

    np.save(f"{subpath}plot_data_correct.npy", plot_data_correct)

    # save data for Fig 10(b)
    comp = np.array(
        (
            r"$(\mathcal{U}, \hat\mathcal{U}_{\mathrm{IT}})$",
            r"$(\mathcal{U}, \hat\mathcal{U}_{\ell_0})$",
            r"$(\mathcal{U}, \hat\mathcal{U}_{\mathrm{nuc}})$",
        )
    )
    shs = np.hstack((sim_T_I, sim_T_B, sim_T_N))
    if incl_map:
        comp = np.hstack((comp, (r"$(\mathcal{U}, \hat\mathcal{U}_{\mathrm{MAP}})$")))
        shs = np.hstack((shs, sim_T_M))
    plot_data_accuracy = {
        "density": np.tile(densities, len(comp)),
        "comparison": np.repeat(
            comp,
            num_densities,
        ),
        "average SHS": shs,
    }
    np.save(f"{subpath}plot_data_accuracy.npy", plot_data_accuracy)

    # format data for Fig 10(a)
    plot_data_correct = np.load(
        f"{subpath}plot_data_correct.npy", allow_pickle=True
    ).item()

    # `subpath[:-9]` drops the last 9 characters of `subpath`; the callers in
    # this file pass directories ending in "_results/" (9 characters)
    out_prefix = subpath[:-9]

    if incl_map:
        # save csv file for Table
        # NOTE(review): rows of `interval_sizes` are credible-set levels
        # (0.1, ..., 0.9) but are labelled with `densities`; the labels only
        # match because the densities are also 0.1, ..., 0.9 -- confirm
        # before changing the densities
        interval_sizes = plot_data_correct["interval_sizes"]
        table = np.hstack((densities[:, None].astype(str), interval_sizes.astype(str)))
        table = np.vstack((np.hstack(("t", densities)), table))
        np.savetxt(f"{out_prefix}_table.csv", table, fmt="%1s", delimiter=",")

        # these entries are 2-D and are not columns of the plotting frame
        del plot_data_correct["interval_maps"]
        del plot_data_correct["interval_sizes"]

    correct_df = pd.DataFrame.from_dict(plot_data_correct)

    # format data for Fig 10(b)
    plot_data_accuracy = np.load(
        f"{subpath}plot_data_accuracy.npy", allow_pickle=True
    ).item()
    accuracy_df = pd.DataFrame.from_dict(plot_data_accuracy)

    # plot Fig 10(a)
    sns.set(font_scale=1)
    ax = sns.barplot(
        data=correct_df, x="density", y="proportion correct", hue="estimate"
    )
    sns.move_legend(ax, "lower center", bbox_to_anchor=(0.5, 1), ncol=3, title="")
    ax.set(ylim=(0, 1.05))
    ax.figure.tight_layout()
    ax.figure.savefig(f"{out_prefix}_correct.png")
    ax.figure.clf()

    # plot Fig 10(b)
    ax = sns.barplot(
        data=accuracy_df,
        x="density",
        y="average SHS",
        hue="comparison",
    )
    sns.move_legend(ax, "lower center", bbox_to_anchor=(0.5, 1), ncol=4, title="")
    ax.figure.tight_layout()
    ax.figure.savefig(f"{out_prefix}_accuracy.png")
    ax.figure.clf()


# script logic for CLI
def parse_cli_args(argv):
    """Return ``(how, figures)`` parsed from a ``sys.argv``-style list.

    Parameters
    ----------
    argv : sequence of str
        ``argv[1]`` (optional) selects how results are obtained and defaults
        to ``"fresh"``; ``argv[2]`` (optional) is a comma-separated list of
        figure names, or ``all``, and defaults to ``all``.

    Returns
    -------
    tuple
        ``how`` as a string and ``figures`` as a list of strings, where
        ``["all"]`` stands for every figure.

    Raises
    ------
    ValueError
        If `how` is not a supported mode or any figure name is unknown.
    """
    valid_how = ("fresh", "download", "redo_plots")
    valid_figures = ("complexity", "hist", "uni", "tri_n", "n_8", "n_10")

    how = argv[1] if len(argv) > 1 else "fresh"
    if how not in valid_how:
        raise ValueError(f"`{how}` not valid. Pick one of {valid_how}.")

    # a missing figure argument means "all"; it must stay a list, otherwise
    # the caller would iterate over the characters of a string
    figures = argv[2].split(",") if len(argv) > 2 else ["all"]
    # every requested figure has to be known, not just one of them
    if figures != ["all"] and not all(fig in valid_figures for fig in figures):
        raise ValueError(
            f"`{figures}` not valid. Pick `all` or a subset of {valid_figures}."
        )

    return how, figures


if __name__ == "__main__":
    # check versions to ensure accurate reproduction
    if version("gues") != "0.2.0":  # also check other packages?
        warnings.warn(
            "Current Python package versions unsupported. In case of problems, first "
            "reinstall with `pipenv install 'gues [reproduce_astat] == 0.2.0'` to "
            "ensure installed versions are supported."
        )
    # prevent deprecation warnings from cluttering output
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    # check command line arguments to determine which results to compute
    how, figures = parse_cli_args(sys.argv)

    path = "reproduced_astat_results/"
    os.makedirs(os.path.dirname(path), exist_ok=True)

    seed = 1312

    reproduction_dict = {
        "complexity": complexity,
        "hist": hist,
        "uni": uni,
        "tri_n": tri_n,
        "n_8": n_8,
        "n_10": n_10,
    }

    print(f"Saving results to {path}...")
    if figures == ["all"]:
        figures = list(reproduction_dict)
    for fig in figures:
        print(f"Reproducing Figure {fig}...")
        # every figure starts from a generator with the same fixed seed
        reproduction_dict[fig](default_rng(seed), path, how)
    print("Done!")
