#!python
# -*- coding: utf-8 -*-

import sys
import garf
import ntpath
import time
import json
import socket
import torch
import click

CONTEXT_SETTINGS = {
    # Make both '-h' and '--help' print the command help text.
    'help_option_names': ['-h', '--help'],
}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('param')
@click.argument('data')
@click.argument('output')
@click.option('--progress-bar/--no-progress-bar', default=True)
def garf_train(param, data, output, progress_bar):
    '''
    \b
    Train an ARF-nn (neural network) from a training dataset.
    <PARAM>  : JSON file with the neural network training parameters
    <DATA>   : training dataset file
    <OUTPUT> : output file for the trained network (written with torch.save)
    '''
    # This docstring is the click --help text. The leading '\b' (a backspace
    # character) tells click not to rewrap the paragraph that follows it.
    param_filename = param
    data_filename = data
    output_filename = output

    # Read training dataset. Only the packed array is used below; the other
    # four returned values are ignored.
    data, _theta, _phi, _E, _w = garf.load_training_dataset(data_filename)

    # Read parameters. The context manager guarantees the handle is closed.
    with open(param_filename, encoding='utf-8') as param_file:
        params = json.load(param_file)
    # The command-line flag overrides any 'progress_bar' entry in the JSON.
    params['progress_bar'] = progress_bar

    # Print info
    print('Training dataset', data_filename)
    garf.print_training_dataset_info(data)

    # Train: assumes columns 0-2 are the inputs (theta, phi, E) and column 3
    # is the target weight w -- TODO confirm against load_training_dataset.
    x = data[:, 0:3]  # Input:  theta, phi, E
    y = data[:, 3]    # Output: w
    start = time.strftime("%c")
    print('\nNeural network parameters', param_filename)
    nn = garf.train_nn(x, y, params)

    # Add infos. This presumably relies on nn['model_data'] being a mutable
    # mapping held by reference, so the updates end up in the saved object.
    now = time.strftime("%c")
    hn = socket.gethostname()
    model_data = nn['model_data']
    model_data['training_filename'] = data_filename
    model_data['training_size'] = len(x)
    model_data['start date'] = start
    model_data['end date'] = now
    model_data['hostname'] = hn

    # Save the output model and its associated data. torch.save opens and
    # closes the file itself, so no separate handle is opened here.
    torch.save(nn, output_filename)
    print("\nNN saved to ", output_filename)


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Run the click command when this file is executed as a script; click
    # reads the PARAM/DATA/OUTPUT arguments from the command line.
    garf_train()
