#!python
# -*- coding: utf-8 -*-

import garf
import time
import json
import socket
import torch
import click

# Context settings handed to click.command below: "-h" is accepted as a help
# flag in addition to "--help".
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument("param")
@click.argument("data")
@click.argument("output")
@click.option("--progress-bar/--no-progress-bar", default=True)
def garf_train(param, data, output, progress_bar):
    """
    \b
    Train an ARF-nn (neural network) from a training dataset.

    \b
    PARAM  : JSON file with the training parameters (must define "RR")
    DATA   : training dataset, read with garf.load_training_dataset
    OUTPUT : file the trained network and its metadata are saved to
    """
    # NOTE: the docstring above is the text click prints for --help; the
    # "\b" markers stop click from re-wrapping the paragraph that follows.
    param_filename = param
    data_filename = data
    output_filename = output

    # Read training dataset. The loader returns five values; only the full
    # array is used here, but all five are unpacked so an unexpected return
    # arity still fails loudly.
    loaded = garf.load_training_dataset(data_filename)
    dataset, _theta, _phi, _e, _w = loaded

    # Read parameters (the context manager closes the file deterministically)
    with open(param_filename) as param_file:
        params = json.load(param_file)
    # The command-line flag is forwarded to garf.train_nn through params
    params["progress_bar"] = progress_bar

    # Print info
    print("Training dataset", data_filename)
    garf.print_training_dataset_info(dataset, params["RR"])

    # Train
    x = dataset[:, 0:3]  # Input:  theta, phi, E
    y = dataset[:, 3]  # Output: w
    start = time.strftime("%c")
    print("\nNeural network parameters", param_filename)
    nn = garf.train_nn(x, y, params)

    # Add infos: the entries are set on the object stored under
    # nn["model_data"], and nn is what gets written below.
    now = time.strftime("%c")
    hn = socket.gethostname()
    model_data = nn["model_data"]
    model_data["training_filename"] = data_filename
    model_data["training_size"] = len(x)
    model_data["start date"] = start
    model_data["end date"] = now
    model_data["hostname"] = hn

    # Save output model and associated data. torch.save opens the path
    # itself, so no separate file handle is opened (and leaked) here.
    torch.save(nn, output_filename)
    print("\nNN saved to ", output_filename)


# -----------------------------------------------------------------------------
# Invoke the click command only when this file is run as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    garf_train()
