#!/usr/bin/env python

"""
dora the explorer

Fits data using the ANN models defined in pugna.explorer.

Provides a command-line interface for experimenting with hyperparameters.
"""

import pugna.explorer
import pugna.callbacks
import pugna.data
import pugna.model_utils
import pugna.logger
import pugna.learning_rate_schedulers
import pandas as pd
import numpy as np
import tensorflow as tf
import datetime
import os
import subprocess
import argparse
import matplotlib as mpl
# Select the non-interactive Agg backend *before* pyplot is imported so the
# plots below can be written to file on machines without a display.
mpl.use('agg')
import matplotlib.pyplot as plt  # noqa: E402  (must follow mpl.use)

# Start from Matplotlib's built-in default style (rcdefaults skips
# style-blacklisted keys such as the backend), then enlarge fonts for the plots.
mpl.rcdefaults()
mpl.rcParams.update({"font.size": 16})

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--output-dir", type=str,
                        help="output directory", required=True)

    parser.add_argument("--X-scale-method", type=str, default="None",
                        help="sklearn method to scale X data", choices=["None", "MinMaxScaler", "StandardScaler"])
    parser.add_argument("--y-scale-method", type=str, default="None",
                        help="sklearn method to scale y data", choices=["None", "MinMaxScaler", "StandardScaler"])

    parser.add_argument("--X-data-train", type=str,
                        help="path to X.npy (training)", required=True)
    parser.add_argument("--y-data-train", type=str,
                        help="path to y.npy (training)", required=True)

    parser.add_argument("--X-data-val", type=str,
                        help="path to X.npy (validation)", required=True)
    parser.add_argument("--y-data-val", type=str,
                        help="path to y.npy (validation)", required=True)

    parser.add_argument("-v", "--verbose",
                        help="""
                        increase output verbosity
                        no -v: WARNING
                        -v: INFO
                        -vv: DEBUG
                        """,
                        action='count', default=0)

    # ANN arguments
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of epochs to run for")

    parser.add_argument("--n-blocks", type=int, default=3,
                        help="number of blocks")
    parser.add_argument("--units-per-layer", type=int, default=256,
                        help="number of units per layers")
    parser.add_argument("--layers-per-block", type=int, default=3,
                        help="number of Dense layers before a BatchNormalisation Layer is added if batch_norm is True")
    parser.add_argument("--activation", type=str, default='leaky_relu',
                        help="name of activation function for each layer",
                        choices=['relu', 'leaky_relu', 'prelu',
                                 'elu', 'softplus', 'tanh', 'selu']
                        )
    parser.add_argument("--leaky-relu-alpha", type=float, default=0.3,
                        help="leaky-relu alpha parameter")

    parser.add_argument("--learning-rate", type=float, default=0.001,
                        help="learning rate. If using lrs then this is the initial learning rate.")

    parser.add_argument("--use-lrs", help="use learning rate schedular. By default this is false.",
                        action="store_true")
    parser.add_argument("--lrs-final-learning-rate", type=float, default=1e-5,
                        help="final learning rate.")
    parser.add_argument("--lrs-decay-rate", type=int, default=10,
                        help="""learning rate schedular (lrs) option.
                        factor by which learning rate decays
                        """)
    parser.add_argument("--lrs-decay-steps", type=int, default=1000,
                        help="""learning rate schedular (lrs) option.
                        number of epochs before decaying.
                        """)

    parser.add_argument("--batch-norm",
                        help="""add BatchNormalization layers in each hidden
                        layer before activation function. Default is false.""",
                        action="store_true")

    parser.add_argument("--optimizer", type=str, default='adam',
                        help="name of optimizer to use",
                        choices=['adam'])
    parser.add_argument("--loss", type=str, default='mse',
                        help="name of loss function to use",
                        choices=['mse', 'mape'])

    parser.add_argument("--batch-size-factor", type=float, default=1,
                        help="""mini-batch size specified by positive factor.
                        batch size is calculated as X.shape[0] / batch_size_factor
                        If 1 then mini-batch is the entire data set.""")

    args = parser.parse_args()

    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = pugna.logger.init_logger(level=level)
    logger.info("running dora")
    logger.info(f"verbosity turned on at level: {level}")

    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
    logger.info(
        f"tf using {tf.config.threading.get_inter_op_parallelism_threads()} inter_op_parallelism_threads thread(s)")
    logger.info(
        f"tf using {tf.config.threading.get_intra_op_parallelism_threads()} intra_op_parallelism_threads thread(s)")

    if "OMP_NUM_THREADS" not in os.environ:
        logger.info("'OMP_NUM_THREADS' not set. Setting it now.")
        os.environ["OMP_NUM_THREADS"] = "1"
    logger.info(f"OMP_NUM_THREADS: {os.environ['OMP_NUM_THREADS']}")

    if int(os.environ['OMP_NUM_THREADS']) != 1:
        logger.warning(
            f"OMP_NUM_THREADS is not 1! value: {os.environ['OMP_NUM_THREADS']}")

    logger.info("running 'tf.config.list_physical_devices('GPU')'")
    logger.info(tf.config.list_physical_devices('GPU'))

    try:
        logger.info("running 'nvidia-smi -L'")
        subprocess.call(["nvidia-smi", "-L"])
    except FileNotFoundError:
        logger.info("could not run 'nvidia-smi -L'")

    logger.info("NN SETTINGS:")
    logger.info(f"epochs: {args.epochs}")
    logger.info(f"n-blocks: {args.n_blocks}")
    logger.info(f"units-per-layer: {args.units_per_layer}")
    logger.info(f"layers-per-block: {args.layers_per_block}")
    logger.info(f"activation: {args.activation}")
    logger.info(f"leaky-relu-alpha: {args.leaky_relu_alpha}")
    logger.info(f"learning-rate: {args.learning_rate}")
    logger.info(f"use-lrs: {args.use_lrs}")
    logger.info(f"lrs-final-learning-rate: {args.lrs_final_learning_rate}")
    logger.info(f"lrs-decay-rate: {args.lrs_decay_rate}")
    logger.info(f"lrs-decay-steps: {args.lrs_decay_steps}")
    logger.info(f"batch-norm: {args.batch_norm}")
    logger.info(f"optimizer: {args.optimizer}")
    logger.info(f"loss: {args.loss}")
    logger.info(f"batch-size-factor: {args.batch_size_factor}")

    logger.info(f"making output dir: {args.output_dir}")
    os.makedirs(f"{args.output_dir}", exist_ok=True)

    logger.info(f"loading: {args.X_data_train}")
    X_train = np.load(args.X_data_train)

    logger.info(f"loading: {args.y_data_train}")
    y_train = np.load(args.y_data_train)

    logger.info(f"loading: {args.X_data_val}")
    X_val = np.load(args.X_data_val)

    logger.info(f"loading: {args.y_data_val}")
    y_val = np.load(args.y_data_val)

    logger.info(f"X_train.shape: {X_train.shape}")
    logger.info(f"y_train.shape: {y_train.shape}")
    logger.info(f"X_val.shape: {X_val.shape}")
    logger.info(f"y_val.shape: {y_val.shape}")

    logger.info(f"X scale method: {args.X_scale_method}")
    logger.info(f"y scale method: {args.y_scale_method}")

    if args.X_scale_method == "None":
        X_train_scaled = X_train.copy()
        X_val_scaled = X_val.copy()
    else:
        X_scalers = pugna.data.make_scalers(
            X_train, method=args.X_scale_method)
        X_train_scaled = pugna.data.apply_scaler(
            X_train, X_scalers)
        X_val_scaled = pugna.data.apply_scaler(
            X_val, X_scalers)

        outname = os.path.join(
            args.output_dir, f"X_scalers.npy")

        pugna.data.save_scalers(
            Scalers=X_scalers, filename=outname)

    if args.y_scale_method == "None":
        y_train_scaled = y_train.copy()
        y_val_scaled = y_val.copy()
    else:
        y_scalers = pugna.data.make_scalers(
            y_train, method=args.y_scale_method)
        y_train_scaled = pugna.data.apply_scaler(
            y_train, y_scalers)
        y_val_scaled = pugna.data.apply_scaler(
            y_val, y_scalers)

        outname = os.path.join(
            args.output_dir, f"y_scalers.npy")

        pugna.data.save_scalers(
            Scalers=y_scalers, filename=outname)

    ####
    batch_size = int(X_train.shape[0] / args.batch_size_factor)

    logger.info(f"X_train.shape: {X_train.shape[0]}")
    logger.info(f"batch size factor: {args.batch_size_factor}")
    logger.info(f"batch size: {batch_size}")

    if args.loss == 'mse':
        loss = 'mse'
    elif args.loss == 'mape':
        loss = tf.keras.losses.MeanAbsolutePercentageError()

    if args.optimizer:
        optimizer = tf.keras.optimizers.Adam

    # learning_rate_fn = tf.keras.optimizers.schedules.InverseTimeDecay(
    #   args.learning_rate, args.lrs_decay_steps, args.lrs_decay_rate, staircase=True)

    learning_rate_fn = pugna.learning_rate_schedulers.InverseTimeDecay_WithFinalLR(
        args.learning_rate,
        args.lrs_final_learning_rate,
        args.lrs_decay_steps,
        args.lrs_decay_rate,
        staircase=True)

    logger.info("begining fits")

    X = X_train_scaled.copy()
    y = y_train_scaled.copy()
    X_val = X_val_scaled.copy()
    y_val = y_val_scaled.copy()

    input_dim = X.shape[1]
    output_dim = y.shape[1]

    model = explorer.build_model(
        input_dim=input_dim,
        output_dim=output_dim,
        units_per_layer=args.units_per_layer,
        layers_per_block=args.layers_per_block,
        n_blocks=args.n_blocks,
        activation=args.activation,
        batch_norm=args.batch_norm,
        summary=True)
    model = explorer.compile_model(
        model,
        learning_rate=args.learning_rate,
        loss=loss,
        optimizer=optimizer)
    callbacks = [pugna.callbacks.PrintDot()]
    if args.use_lrs:
        callbacks.append(
            tf.keras.callbacks.LearningRateScheduler(learning_rate_fn))

    starttime = datetime.datetime.now()
    history = model.fit(
        X,
        y,
        epochs=args.epochs,
        batch_size=batch_size,
        verbose=0,
        callbacks=callbacks,
        validation_data=(X_val, y_val)
    )
    endtime = datetime.datetime.now()

    duration = endtime - starttime

    logger.info("fits complete")
    logger.info(f"The time cost: {duration}")

    logger.info("saving model")
    outname = os.path.join(args.output_dir, f"model")
    pugna.model_utils.save_model_json(model, outname)
    pugna.model_utils.save_model_h5(model, outname)

    logger.info("saving history")
    outname = os.path.join(args.output_dir, f"history.pickle")
    pugna.model_utils.save_history(history.history, outname)

    last_loss = history.history['loss'][-1]
    logger.info(f"last loss: {last_loss}")
    last_val_loss = history.history['val_loss'][-1]
    logger.info(f"last val_loss: {last_val_loss}")

    if args.use_lrs:
        outname = os.path.join(args.output_dir, "lr.png")
        logger.info("saving learning rate plot")
        plt.figure(figsize=(14, 7))
        plt.plot(history.history['lr'])
        plt.yscale('log')
        plt.savefig(outname, bbox_inches='tight')
        plt.close()

    outname = os.path.join(args.output_dir, "loss.png")
    logger.info("saving loss plot")
    plt.figure(figsize=(14, 7))
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.yscale('log')
    plt.legend(loc='center left', fancybox=True,
               framealpha=0., bbox_to_anchor=(1.05, 0.5))
    plt.savefig(outname, bbox_inches='tight')
    plt.close()
