#!/usr/bin/env python

import os
import glob
from pathlib import Path
import numpy as np

import wispy.model_utils
import wispy.utils

import wispy.logger
import argparse

import matplotlib as mpl
import matplotlib.pyplot as plt

# Use the non-interactive Agg backend: figures are only written to files via
# savefig below, so no GUI window is needed.
mpl.use("agg")

from cycler import cycler
from itertools import cycle  # NOTE(review): appears unused in the visible code

# Reset rcParams to matplotlib's built-in defaults (overriding anything loaded
# from a user matplotlibrc), then apply the "ggplot" style and a 16 pt base
# font. Order matters: the reset must come before the style is applied.
mpl.rcParams.update(mpl.rcParamsDefault)
plt.style.use("ggplot")
mpl.rcParams.update({"font.size": 16})

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Fit training data using an autoencoder""",
    )

    parser.add_argument(
        "--glob-str",
        type=str,
        help="""argument to pass to glob
    Could be something like 'runs/run_0*/fit/history.pickle'
    """,
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        help="output directory name",
    )

    parser.add_argument(
        "--metric-name",
        type=str,
        help="name of the metric to plot. e.g. 'mse' - if not given then will plot the loss",
        default="loss",
    )

    parser.add_argument(
        "--fit-config-file",
        type=str,
        help="path to fit config toml file",
        required=False,
    )

    parser.add_argument(
        "-v",
        help="""increase output verbosity
        no -v: WARNING(")
        -v: INFO
        -vv: DEBUG""",
        action="count",
        dest="verbose",
        default=0,
    )

    args = parser.parse_args()
    args_dict = vars(args)

    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = wispy.logger.init_logger(level=level)

    logger.info(f"wispy version: {wispy.__version__}")

    logger.info("==========")
    logger.info("printing command line args")
    for k in args_dict.keys():
        logger.info(f"{k}: {args_dict[k]}")
    logger.info("==========")

    logger.info(f"making output dir: {args.output_dir}")
    os.makedirs(f"{args.output_dir}", exist_ok=True)

    history_files = glob.glob(args.glob_str)

    logger.info("setting up cycler")
    # using these links I can create an extension to the standard
    # mpl cycle which then iterates over linestyle to get some
    # more unique styles
    # https://matplotlib.org/stable/tutorials/intermediate/color_cycle.html
    # https://matplotlib.org/cycler/
    ls_cycle = cycler("ls", ["-", "--", "-."])
    sty_cycle = ls_cycle * plt.rcParams["axes.prop_cycle"]
    plt.rc("axes", prop_cycle=sty_cycle)

    metric_name = args.metric_name
    metric_val_name = f"val_{metric_name}"

    logger.info(f"metric_name: {metric_name}")
    logger.info(f"metric_val_name: {metric_val_name}")

    hists = {}
    # names = []
    for hf in history_files:
        if args.fit_config_file:
            config_file = args.fit_config_file
        else:
            config_file = glob.glob(os.path.join(hf.split("/fit")[0], "config.toml"))[0]
        run_dir_name = Path(os.path.realpath(config_file)).parent.name
        doc = wispy.utils.read_and_parse_toml(config_file)
        # name = doc["fit"]["optimizer"]["name"]
        # names.append(name)
        # tag = f"{run_dir_name}:{name}"
        tag = f"{run_dir_name}"
        hist = wispy.model_utils.load_history(hf)
        last_metric = hist[metric_name][-1]

        # print_string = f"{name}: last loss = {last_metric:.4e}"
        print_string = f"{tag}: last {metric_name} = {last_metric:.4e}"
        try:
            last_val_metric = hist[metric_val_name][-1]
            print_string += f" - last {metric_val_name} = {last_val_metric:.4e}"
        except:
            logger.info("no validation found")
        logger.info(print_string)

        # if no variable learning rate then look up what the constant learning
        # rate was and insert it
        if "lr" not in hist.keys():
            lr = doc["fit"]["optimizer"]["kwargs"]["learning_rate"]
            hist.update({"lr": lr * np.ones_like(hist[metric_name])})

        hists.update({tag: hist})

    legend_kwargs = dict(
        loc="center left", bbox_to_anchor=(1, 0.5), ncol=2, prop={"size": 6}
    )

    logger.info("plotting")

    plt.figure(figsize=(10, 12))
    plt.subplot(2, 1, 1)
    for k, v in hists.items():
        plt.plot(v[metric_name], label=k)
    plt.xscale("log")
    plt.yscale("log")
    plt.legend(**legend_kwargs)
    plt.title(metric_name)

    plt.subplot(2, 1, 2)
    for k, v in hists.items():
        if metric_val_name in v.keys():
            plt.plot(v[metric_val_name], label=k)
    plt.xscale("log")
    plt.yscale("log")
    plt.legend(**legend_kwargs)
    plt.title(f"{metric_val_name}")
    plt.tight_layout()
    outname = os.path.join(args.output_dir, f"{metric_name}.png")
    logger.info(f"plot save: {outname}")
    plt.savefig(outname)
    plt.close()

    plt.figure(figsize=(10, 12))
    for k, v in hists.items():
        if "lr" in v.keys():
            plt.plot(v["lr"], label=k)
    plt.xscale("log")
    plt.yscale("log")
    plt.legend(**legend_kwargs)
    plt.title("learning rate")
    plt.tight_layout()
    outname = os.path.join(os.path.join(args.output_dir, "lr.png"))
    logger.info(f"plot save: {outname}")
    plt.savefig(outname)
    plt.close()
