#!/usr/bin/env python

import wispy.logger
import wispy.waveutils
import numpy as np
import os
import argparse
import phenom
import tomlkit
from tomlkit import parse
import datetime
import lal
import lalsimulation as lalsim
import matplotlib as mpl
import matplotlib.pyplot as plt

# Reset matplotlib to its packaged defaults and apply the plotting style first,
# then choose the non-interactive Agg backend last. Doing it in this order makes
# the backend selection the final word, independent of whatever "backend" entry
# rcParamsDefault may carry for the installed matplotlibrc.
mpl.rcParams.update(mpl.rcParamsDefault)
plt.style.use("ggplot")
mpl.rcParams.update({"font.size": 16})
mpl.use("agg")


def filter_is_table(doc, is_table):
    """Return the keys of a TOML container whose values are (or are not) tables.

    Iterates over ``doc.keys()`` in order and keeps a key when the value's
    ``is_table()`` result compares equal to ``is_table``. Values that do not
    provide an ``is_table`` method are treated as non-tables instead of raising
    ``AttributeError``. For example, tomlkit hands TOML booleans back as plain
    Python ``bool``.

    Args:
        doc (tomlkit.TOMLDocument or tomlkit.items.Table): container to inspect.
        is_table (bool): ``True`` to keep the keys of sub-tables, ``False`` to
            keep the keys of plain (non-table) values.

    Returns:
        list of str: the matching keys, in document order.
    """

    def _value_is_table(value):
        # Items expose ``is_table()``; unwrapped plain values (e.g. bool) do not.
        check = getattr(value, "is_table", None)
        return check() if callable(check) else False

    return [key for key in doc.keys() if _value_is_table(doc[key]) == is_table]


def get_sub_table_names(doc):
    """List the keys of ``doc`` whose values are sub-tables (e.g. one per approximant)."""
    return filter_is_table(doc, is_table=True)


def get_table_common_options(doc):
    """List the keys of ``doc`` holding plain option values rather than sub-tables."""
    return filter_is_table(doc, is_table=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Generate training data and save to disk for fitting

        Assumes that you want to work with amplitude and phase data only


        """,
    )

    parser.add_argument(
        "--config-file",
        type=str,
        required=True,
        help="path to workflow config.toml file",
    )
    parser.add_argument(
        "--scale-params-file",
        type=str,
        default=None,
        help="path to a data_processing_params.npz written by a previous "
        "(training) run. If given, its tc, amp_scale and phase_scale are reused "
        "so that validation/test data share the training data scaling. "
        "If omitted, the constants are computed from the data generated here.",
    )
    parser.add_argument(
        "-v",
        help="""increase output verbosity
        no -v: WARNING
        -v: INFO
        -vv: DEBUG""",
        action="count",
        dest="verbose",
        default=0,
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="if given then will not throw an error if output directory exists",
    )

    overall_start_time = datetime.datetime.now()

    args = parser.parse_args()
    args_dict = vars(args)

    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = wispy.logger.init_logger(level=level)

    logger.info(f"wispy version: {wispy.__version__}")

    logger.info("==========")
    logger.info("printing command line args")
    for k, v in args_dict.items():
        logger.info(f"{k}: {v}")
    logger.info("==========")

    with open(args.config_file, "r") as f:
        doc = parse(f.read())

    logger.info("==========")
    logger.info("printing toml config contents")
    for k in doc.keys():
        logger.info(f"{k}: {doc[k]}")
    logger.info("==========")

    # Validate the config before touching the filesystem so that a bad config
    # does not leave an empty output directory behind (which would otherwise
    # make the next run fail unless --force is given).
    logger.info("looking for [approx] section")
    try:
        approx = doc["approx"]
    except tomlkit.exceptions.NonExistentKey as error:
        logger.error(error)
        raise

    approx_options = get_table_common_options(approx)
    logger.info("found the follow approx options")
    for k in approx_options:
        logger.info(f"{k} = {approx[k]}")

    approximants = get_sub_table_names(approx)
    logger.info("found the following approximants")
    logger.info(f"{approximants}")
    if not approximants:
        raise ValueError("no approximant sub-tables found in the [approx] section")

    for approximant in approximants:
        logger.info(f"approximant = {approximant}")
        logger.info("options:")
        logger.info(f"{approx[approximant]}")

    output_dir = str(doc["output"])
    logger.info(f"making output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=args.force)

    logger.info("generating waveforms")
    wavegen = {}
    for approximant in approximants:
        start_time = datetime.datetime.now()
        logger.info(f"working: {approximant}")
        kwargs = dict(
            approx=approximant,
            dt=approx["dt"],
            t_min=approx["t_min"],
            t_max=approx["t_max"],
        )
        if approximant == "NR":
            # NR waveforms come from files rather than a mass-ratio grid.
            kwargs.update({"qlist": None, "nrfiles": approx["NR"]["nrfiles"]})
        else:
            table = approx[approximant]
            # NOTE: np.arange excludes q_max (half-open interval).
            kwargs["qlist"] = np.arange(table["q_min"], table["q_max"], table["dq"])
        wavegen[approximant] = wispy.waveutils.gen_model_waveforms(**kwargs)
        logger.info(f"time taken: {datetime.datetime.now() - start_time}")

    logger.info("collecting together all data")
    # Flatten in approximant order; every saved array below follows this order.
    all_wfs = [wf for approximant in approximants for wf in wavegen[approximant]]
    if not all_wfs:
        raise ValueError(
            "no waveforms were generated; check the q ranges / nrfiles in the config"
        )

    logger.info("we are assuming that all waveforms are on a common time grid")
    common_times = all_wfs[0].times
    reference_times = np.asarray(common_times)
    for wf in all_wfs:
        wf_times = np.asarray(wf.times)
        # Same-length but different grids would otherwise be stacked silently.
        if wf_times.shape != reference_times.shape or not np.allclose(
            wf_times, reference_times
        ):
            raise ValueError(
                f"waveform ({wf.approximant_string}, q={wf.q}) "
                "is not on the common time grid"
            )

    all_qs = np.hstack([wf.q for wf in all_wfs]).flatten()
    all_amps = np.vstack([wf.amp for wf in all_wfs])
    all_phases = np.vstack([wf.phase for wf in all_wfs])
    all_approximants = np.hstack([wf.approximant_string for wf in all_wfs]).flatten()

    logger.info("aggregated data")
    logger.info(f"mass-ratio data shape: {all_qs.shape}")
    logger.info(f"amplitude data shape: {all_amps.shape}")
    logger.info(f"phase data shape: {all_phases.shape}")

    logger.info("scaling data")
    logger.info("scaling amplitude by eta")
    all_eta = phenom.eta_from_q(all_qs)
    all_amps = all_amps / all_eta[:, np.newaxis]

    # Either reuse the training run's constants (validation/test data) or derive
    # them from this data set (training data).
    computed_tc = -approx["t_min"] + 1
    loaded_scale_param = None
    if args.scale_params_file is None:
        logger.warning(
            "no --scale-params-file given: computing scaling constants from this "
            "data set (only appropriate when generating training data)"
        )
        tc = computed_tc
    else:
        logger.info(f"loading scaling constants from: {args.scale_params_file}")
        with np.load(args.scale_params_file) as params:
            tc = params["tc"].item()
            loaded_scale_param = {
                "amp": params["amp_scale"].item(),
                "phase": params["phase_scale"].item(),
            }
        if tc != computed_tc:
            logger.warning(
                f"tc from {args.scale_params_file} ({tc}) differs from the value "
                f"implied by this config's t_min ({computed_tc}); "
                "using the loaded value"
            )

    logger.info("scaling phase by leading order TaylorT3 term")
    logger.info(f"tc = {tc}")
    t3_leading = wispy.waveutils.taylorT3_leading_term(
        common_times, all_eta[:, np.newaxis], tc
    )
    all_phases = all_phases / t3_leading

    if loaded_scale_param is None:
        logger.info("scaling by constant to enforce maximum value is 1.0")
        scale_param = {
            "amp": np.around(np.max(all_amps), 2),
            "phase": np.around(np.min(all_phases), 2),
        }
    else:
        scale_param = loaded_scale_param

    logger.info("scale_param:")
    for k, v in scale_param.items():
        logger.info(f"{k}: {v}")
        if v == 0:
            # Rounding to 2 decimals can produce 0 for small data; dividing would
            # fill the outputs with inf/nan.
            raise ValueError(f"scale constant for {k} is zero; cannot divide by it")

    logger.info("dividing amp and phase by constant scale_param")
    all_amps = all_amps / scale_param["amp"]
    all_phases = all_phases / scale_param["phase"]

    logger.info("saving")

    out = os.path.join(output_dir, "approximants.npy")
    logger.info(f"saving approximants: {out}")
    np.save(out, all_approximants)

    out = os.path.join(output_dir, "mass-ratios.npy")
    logger.info(f"saving mass-ratios: {out}")
    np.save(out, all_qs)

    out = os.path.join(output_dir, "times.npy")
    logger.info(f"saving times: {out}")
    np.save(out, common_times)

    out = os.path.join(output_dir, "amplitude.npy")
    logger.info(f"saving amplitude: {out}")
    np.save(out, all_amps)

    out = os.path.join(output_dir, "phase.npy")
    logger.info(f"saving phase: {out}")
    np.save(out, all_phases)

    out = os.path.join(output_dir, "data_processing_params.npz")
    logger.info(f"saving data processing constants: {out}")
    np.savez(out, tc=tc, amp_scale=scale_param["amp"], phase_scale=scale_param["phase"])

    overall_duration = datetime.datetime.now() - overall_start_time
    logger.info(f"total time: {overall_duration}")
    logger.info("finished!")