#!/usr/bin/env python

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow import keras
from tensorflow.keras import layers

import wispy.utils
import wispy.logger
import wispy.callbacks
import wispy.model_utils
import numpy as np
import os
import argparse
import subprocess
import tomlkit
from tomlkit import parse
import datetime

# import tqdm
# import tqdm.keras


def numpy_into_tf_dataset(x, y, batch_size):
    """Wrap paired numpy arrays in a batched ``tf.data.Dataset``.

    Both arrays are sliced along their first axis, so element ``i`` of the
    dataset is the pair ``(x[i], y[i])``. Consecutive pairs are then grouped
    into batches of ``batch_size`` (``drop_remainder`` is left at its default).

    Args:
        x (numpy array): model inputs; the first axis indexes samples.
        y (numpy array): model targets; the first axis must match ``x``.
        batch_size (int): number of samples per batch.

    Returns:
        tf.data.Dataset: dataset yielding ``(x_batch, y_batch)`` tuples.
    """
    paired_samples = tf.data.Dataset.from_tensor_slices((x, y))
    batched = paired_samples.batch(batch_size)
    return batched


def check_gpu():
    """Log the GPUs TensorFlow can see, then list them with ``nvidia-smi -L``.

    The ``nvidia-smi`` output goes straight to this process's stdout; it does
    not pass through the logger. If the ``nvidia-smi`` binary cannot be found,
    that fact is logged and otherwise ignored.
    """
    logger.info("running 'tf.config.list_physical_devices('GPU')'")
    visible_gpus = tf.config.list_physical_devices("GPU")
    logger.info(visible_gpus)

    smi_command = ["nvidia-smi", "-L"]
    logger.info("running 'nvidia-smi -L'")
    try:
        subprocess.call(smi_command)
    except FileNotFoundError:
        # nvidia-smi is not installed or not on PATH; this check is best-effort.
        logger.info("could not run 'nvidia-smi -L'")


def set_gpu_memory_growth():
    """Enable on-demand memory growth on every physical GPU TensorFlow sees.

    Does nothing when no GPU is visible. If TensorFlow refuses the setting, the
    RuntimeError is logged and re-raised. This happens when the GPUs have
    already been initialised.
    """
    physical_gpus = tf.config.list_physical_devices("GPU")
    if not physical_gpus:
        return

    # Memory growth has to be configured identically on all GPUs,
    # so the same setting is applied to each device in turn.
    try:
        logger.info("running: tf.config.experimental.set_memory_growth")
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        # Raised when memory growth is set after the GPUs were initialised.
        logger.info(err)
        raise err


def build_encoder(input_shape, latent_dim, units, acts):
    """Build the dense encoder: input vector -> hidden Dense layers -> latent vector.

    Args:
        input_shape (int): length of each 1-D input vector.
        latent_dim (int): size of the latent (bottleneck) output layer.
        units (iterable of int): width of each hidden Dense layer, in order.
            May be empty, in which case the input is projected directly onto
            the latent space.
        acts (iterable): activation for each hidden layer, paired with ``units``.

    Returns:
        keras.Model: encoder mapping ``(batch, input_shape)`` to
        ``(batch, latent_dim)``.

    Note:
        Reads the module-level ``args`` (set when this file is run as a script)
        to decide whether to print the Keras model summary.
    """
    units = list(units)
    acts = list(acts)
    if len(units) != len(acts):
        # zip() stops at the shorter list, so the extra entries are dropped.
        # Warn instead of silently building a smaller network than configured.
        logger.warning(
            f"encoder units ({len(units)}) and activations ({len(acts)}) "
            "differ in length; extra entries are ignored"
        )

    timeseries_input = keras.Input(shape=(input_shape,))
    # Start the chain from the input tensor. If `units` is empty, the latent
    # layer attaches directly to the input instead of hitting an unbound `x`.
    x = timeseries_input
    for unit, act in zip(units, acts):
        x = layers.Dense(unit, activation=act)(x)
    # No activation argument: the latent projection is linear.
    latent_output = layers.Dense(latent_dim)(x)
    encoder_model = keras.Model(timeseries_input, latent_output)

    logger.info("=" * 20)
    logger.info("encoder summary")
    logger.info("=" * 20)
    if args.verbose:
        encoder_model.summary()

    return encoder_model


def build_decoder(input_shape, latent_dim, units, acts):
    """Build the dense decoder: latent vector -> hidden Dense layers -> output vector.

    Args:
        input_shape (int): length of the reconstructed output vector; this
            matches the encoder's input length.
        latent_dim (int): size of the latent input vector.
        units (iterable of int): width of each hidden Dense layer, in order.
            May be empty, in which case the latent vector is projected
            directly onto the output.
        acts (iterable): activation for each hidden layer, paired with ``units``.

    Returns:
        keras.Model: decoder mapping ``(batch, latent_dim)`` to
        ``(batch, input_shape)`` through a linear output layer.

    Note:
        Reads the module-level ``args`` (set when this file is run as a script)
        to decide whether to print the Keras model summary.
    """
    units = list(units)
    acts = list(acts)
    if len(units) != len(acts):
        # zip() stops at the shorter list, so the extra entries are dropped.
        # Warn instead of silently building a smaller network than configured.
        logger.warning(
            f"decoder units ({len(units)}) and activations ({len(acts)}) "
            "differ in length; extra entries are ignored"
        )

    latent_input = keras.Input(shape=(latent_dim,))
    # Start the chain from the latent input. If `units` is empty, the output
    # layer attaches directly to it instead of hitting an unbound `x`.
    x = latent_input
    for unit, act in zip(units, acts):
        x = layers.Dense(unit, activation=act)(x)
    timeseries_output = layers.Dense(input_shape, activation="linear")(x)
    decoder_model = keras.Model(latent_input, timeseries_output)

    logger.info("=" * 20)
    logger.info("decoder summary")
    logger.info("=" * 20)
    if args.verbose:
        decoder_model.summary()

    return decoder_model


def build_and_compile_autoencoder(
    input_shape, encoder_model, decoder_model, opt, loss, metrics
):
    """Chain the encoder and decoder into one compiled end-to-end Keras model.

    Args:
        input_shape (int): length of each 1-D input vector.
        encoder_model (keras.Model): maps inputs to the latent space.
        decoder_model (keras.Model): maps latent vectors back to input space.
        opt: optimizer passed to ``Model.compile``.
        loss: loss passed to ``Model.compile``.
        metrics: metrics passed to ``Model.compile`` (may be None).

    Returns:
        keras.Model: the compiled autoencoder. It shares its layers, and
        therefore its weights, with ``encoder_model`` and ``decoder_model``.

    Note:
        Reads the module-level ``args`` (set when this file is run as a script)
        to decide whether to print the Keras model summary.
    """
    inputs = keras.Input(shape=(input_shape,))
    reconstruction = decoder_model(encoder_model(inputs))
    autoencoder = keras.Model(inputs, reconstruction)

    banner = "=" * 20
    for line in (banner, "autoencoder summary", banner):
        logger.info(line)
    if args.verbose:
        autoencoder.summary()

    autoencoder.compile(optimizer=opt, loss=loss, metrics=metrics)
    return autoencoder


def run_ae_fit(
    train_dataset,
    input_shape,
    epochs,
    stop_threshold,
    opt,
    validation_dataset,
    latent_dim,
    verbose,
    loss,
    encoder_units,
    encoder_acts,
    decoder_units,
    decoder_acts,
    learning_rate_scheduler=None,
    metrics=None,
):
    """Build, compile and train the autoencoder, timing the ``fit`` call.

    Args:
        train_dataset: training data passed to ``Model.fit``.
        input_shape (int): length of each 1-D input vector.
        epochs (int): maximum number of training epochs.
        stop_threshold: value passed to ``wispy.callbacks.ThresholdCallback``.
            Presumably this stops training once a threshold is reached; see
            that class to confirm.
        opt: optimizer used to compile the autoencoder.
        validation_dataset: validation data for ``fit``, or None.
        latent_dim (int): size of the latent layer.
        verbose: verbosity passed to ``Model.fit``.
        loss: loss used to compile the autoencoder.
        encoder_units, encoder_acts: hidden-layer widths and activations
            for the encoder.
        decoder_units, decoder_acts: hidden-layer widths and activations
            for the decoder.
        learning_rate_scheduler: optional Keras callback. It is placed before
            the threshold callback in the callback list.
        metrics: optional metrics passed to ``Model.compile``.

    Returns:
        tuple: ``(ae_model, ae_history, encoder_model, decoder_model, duration)``.
        ``duration`` is the wall-clock ``datetime.timedelta`` spent in ``fit``.
    """
    encoder = build_encoder(input_shape, latent_dim, encoder_units, encoder_acts)
    decoder = build_decoder(input_shape, latent_dim, decoder_units, decoder_acts)
    autoencoder = build_and_compile_autoencoder(
        input_shape, encoder, decoder, opt, loss, metrics
    )

    # The optional learning-rate scheduler goes first, then the threshold callback.
    fit_callbacks = [learning_rate_scheduler] if learning_rate_scheduler else []
    fit_callbacks.append(wispy.callbacks.ThresholdCallback(stop_threshold))

    logger.info("starting fit")
    fit_started = datetime.datetime.now()
    history = autoencoder.fit(
        train_dataset,
        epochs=epochs,
        callbacks=fit_callbacks,
        verbose=verbose,
        validation_data=validation_dataset,
    )
    elapsed = datetime.datetime.now() - fit_started

    logger.info("fit complete")
    logger.info(f"duration: {elapsed}")

    return autoencoder, history, encoder, decoder, elapsed


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Fit training data using an autoencoder""",
    )

    # Required: the script cannot do anything without a config. Without this,
    # a missing flag only surfaced later as open(None) -> TypeError.
    parser.add_argument(
        "--config-file",
        type=str,
        required=True,
        help="path to workflow config.toml file",
    )
    parser.add_argument(
        "-v",
        help="""increase output verbosity
        no -v: WARNING
        -v: INFO
        -vv: DEBUG""",
        action="count",
        dest="verbose",
        default=0,
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="if given then will not throw an error if output directory exists",
    )

    overall_start_time = datetime.datetime.now()

    # NOTE: `args` and `logger` are intentionally module-level globals. The
    # functions defined above read `args.verbose` and `logger` directly.
    args = parser.parse_args()
    args_dict = vars(args)

    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = wispy.logger.init_logger(level=level)

    logger.info(f"wispy version: {wispy.__version__}")
    logger.info(f"Using TensorFlow v{tf.__version__}")

    logger.info("==========")
    logger.info("printing command line args")
    for key, value in args_dict.items():
        logger.info(f"{key}: {value}")
    logger.info("==========")

    with open(args.config_file, "r", encoding="utf-8") as f:
        text = f.read()

    doc = parse(text)

    logger.info("==========")
    logger.info("printing toml config contents")
    wispy.utils.recursive_dict_print(doc, print_fn=logger.info)
    logger.info("==========")

    # The [gpu] section is optional. os.environ only accepts strings, so a TOML
    # integer (e.g. CUDA_VISIBLE_DEVICES = -1 or 0) is converted first. An
    # unquoted integer previously raised TypeError or was silently skipped.
    cuda_visible_devices = doc.get("gpu", {}).get("CUDA_VISIBLE_DEVICES")
    if cuda_visible_devices not in (None, ""):
        if not isinstance(cuda_visible_devices, str):
            cuda_visible_devices = str(cuda_visible_devices)
        logger.info("setting CUDA_VISIBLE_DEVICES")
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
        if cuda_visible_devices != "-1":
            check_gpu()
            set_gpu_memory_growth()

    output_dir = f"{doc['output']}"
    logger.info(f"making output dir: {output_dir}")
    os.makedirs(output_dir, exist_ok=args.force)

    batch_size = doc["fit"]["batch_size"]

    logger.info("loading training data")
    logger.info(f"{doc['data']['y_train']}")
    y_train = np.load(doc["data"]["y_train"])
    logger.info(f"y_train.shape: {y_train.shape}")
    if y_train.ndim != 2:
        # The network's Input layer is 1-D per sample, so data must be
        # (samples, time points).
        raise ValueError(
            f"y_train must be 2-D (samples, time points), got shape {y_train.shape}"
        )

    input_shape = y_train.shape[1]  # number of time points
    logger.info(f"input_shape for network: {input_shape}")

    logger.info("converting training data numpy arrays into tf Dataset")
    # Autoencoder: each sample is both the input and the reconstruction target.
    train_dataset = (
        numpy_into_tf_dataset(y_train, y_train, batch_size)
        .cache()
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

    logger.info("deleting y_train numpy")
    del y_train

    if "y_val" in doc["data"]:
        logger.info("found validation data")
        logger.info("loading validation data")
        logger.info(f"{doc['data']['y_val']}")
        y_val = np.load(doc["data"]["y_val"])
        logger.info(f"y_val.shape: {y_val.shape}")
        logger.info("converting validation data numpy arrays into tf Dataset")
        validation_dataset = (
            numpy_into_tf_dataset(y_val, y_val, batch_size)
            .cache()
            .prefetch(tf.data.experimental.AUTOTUNE)
        )
        logger.info("deleting y_val numpy")
        del y_val
    else:
        logger.info("no validation data found")
        validation_dataset = None

    logger.info("setting up optimizer function")
    opt_cfg = doc["fit"]["optimizer"]
    to_eval = "{package}.{function}".format(
        package=opt_cfg["package"],
        function=opt_cfg["name"],
    )
    logger.info(f"opt: {to_eval}")
    # kwargs is optional; without it the optimizer uses its own defaults.
    opt_kwargs = opt_cfg.get("kwargs", {})
    logger.info(f"opt kwargs: {opt_kwargs}")
    # NOTE(review): eval() executes whatever the config file contains.
    # Only run this script on trusted config files.
    opt = eval(to_eval)(**opt_kwargs)

    # [fit.callbacks] is optional.
    callbacks_cfg = doc["fit"].get("callbacks", {})
    if "learning_rate_scheduler" in callbacks_cfg:
        lrs_cfg = callbacks_cfg["learning_rate_scheduler"]
        logger.info("setting up learning_rate_scheduler function")
        lrs_func_name = lrs_cfg["name"]
        to_eval = "{package}.{function}".format(
            package=lrs_cfg["package"],
            function=lrs_func_name,
        )
        logger.info("constructing learning_rate_scheduler function")
        logger.info(f"learning_rate_scheduler: {to_eval}")
        lrs_kwargs = lrs_cfg.get("kwargs", {})
        logger.info(f"kwargs: {lrs_kwargs}")
        # NOTE(review): eval() on config contents; trusted config files only.
        learning_rate_fn = eval(to_eval)(**lrs_kwargs)

        # Objects that are already Keras callbacks (e.g. ReduceLROnPlateau)
        # are used as-is. Anything else is treated as a schedule and wrapped
        # in LearningRateScheduler.
        if lrs_func_name == "ReduceLROnPlateau" or isinstance(
            learning_rate_fn, tf.keras.callbacks.Callback
        ):
            learning_rate_scheduler = learning_rate_fn
        else:
            learning_rate_scheduler = tf.keras.callbacks.LearningRateScheduler(
                learning_rate_fn
            )
    else:
        logger.info("not using learning_rate_scheduler")
        # Log-only: do not crash when the optimizer kwargs omit learning_rate.
        learning_rate = opt_kwargs.get("learning_rate", "optimizer default")
        logger.info(f"learning_rate: {learning_rate}")
        learning_rate_scheduler = None

    if "metrics" in doc["fit"]:
        metrics = doc["fit"]["metrics"]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(metrics, list):
            raise TypeError("metrics must be a list")
    else:
        metrics = None

    ae_model, ae_history, encoder_model, decoder_model, duration = run_ae_fit(
        train_dataset=train_dataset,
        input_shape=input_shape,
        epochs=doc["fit"]["epochs"],
        stop_threshold=doc["fit"]["stop_threshold"],
        opt=opt,
        latent_dim=doc["fit"]["latent_dim"],
        validation_dataset=validation_dataset,
        verbose=doc["fit"]["verbose"],
        loss=doc["fit"]["loss"],
        encoder_units=doc["fit"]["encoder"]["units"],
        encoder_acts=doc["fit"]["encoder"]["activations"],
        decoder_units=doc["fit"]["decoder"]["units"],
        decoder_acts=doc["fit"]["decoder"]["activations"],
        learning_rate_scheduler=learning_rate_scheduler,
        metrics=metrics,
    )

    logger.info("saving model")

    try:
        filename = os.path.join(output_dir, "ae_model.h5")
        logger.info(f"saving autoencoder: {filename}")
        ae_model.save(filename)
    except TypeError as e:
        # Best-effort: a failed full-autoencoder save is logged, not fatal.
        # The encoder and decoder are still saved separately below.
        logger.warning(e)
        logger.warning("saving autoencoder failed!")

    filename = os.path.join(output_dir, "encoder_model.h5")
    logger.info(f"saving encoder: {filename}")
    encoder_model.save(filename)

    filename = os.path.join(output_dir, "decoder_model.h5")
    logger.info(f"saving decoder: {filename}")
    decoder_model.save(filename)

    filename = os.path.join(output_dir, "history.pickle")
    logger.info(f"saving history: {filename}")
    wispy.model_utils.save_history(ae_history.history, filename)

    filename = os.path.join(output_dir, "duration.pickle")
    logger.info(f"saving duration: {filename}")
    wispy.model_utils.save_datetime(duration, filename)

    overall_duration = datetime.datetime.now() - overall_start_time
    logger.info(f"overall duration: {overall_duration}")