#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import json
import time
import socket
import gatetools.phsp as phsp
import gaga
import copy
import numpy as np
from colorama import init
from colorama import Fore, Style
import torch
import datetime

# click context settings: accept '-h' in addition to the default '--help'
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('phsp_filename',
                type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument('json_filename',
                type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument('output_filename')
@click.option('--epoch', '-e',
              help='Number of epoch (overwrite the epoch value of the json file)',
              type=int,
              default=None)
@click.option('--progress-bar/--no-progress-bar', default=True)
@click.option('--plot/--no-plot', default=False)
@click.option('--plot-every-epoch', default=400)
@click.option('--w-e', default=-1, help='[experimental] Compute sliced wasserstein every n epoch (slow)')
@click.option('--w-n', default=1e5, help='[experimental] Wasserstein nb of samples')
@click.option('--w-l', default=1e2, help='[experimental] Wasserstein nb of direction')
@click.option('--w-p', default=1, help='[experimental] Wasserstein distance p')
@click.option('--keys', '-k', default='',
              help='only learn the given keys, replace keys in json (as a str list such that "X Y Z")')
@click.option('--validation-dataset', default=None,
              help='Validation dataset filename, to compute loss per epoch (slow)',
              type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.option('--validation-every-epoch', default=400,
              help='Compute loss on validation dataset every epoch')
@click.option('--start-pth', default=None,
              help='Start from already computed G and D, read in the given file')
def gaga_train(phsp_filename,
               json_filename,
               output_filename,
               epoch,
               progress_bar,
               plot,
               plot_every_epoch,
               w_e, w_n, w_l, w_p,
               keys,
               validation_dataset,
               validation_every_epoch,
               start_pth):
    '''
    \b
    Train GAN to learn a PHSP (Phase Space File)

    \b
    <PHSP_FILENAME>   : input PHSP file (.npy)
    <JSON_FILENAME>   : input json file with all GAN parameters
    <OUTPUT_FILENAME> : output GAN as pth file
    '''
    # NOTE: the docstring above is what click prints for --help; the '\b'
    # lines are click markers that disable re-wrapping of the next paragraph.
    #
    # Overall flow:
    #   1. read the json parameters and complete them with the command line
    #   2. load the training dataset and keep only the requested keys
    #   3. normalise the data (zero mean, unit std per column)
    #   4. build and train the GAN
    #   5. save params, training info and G/D weights with torch.save

    # term color
    init()

    # read parameters (the context manager closes the file right away)
    with open(json_filename) as param_file:
        params = json.load(param_file)

    # command line values are stored in (and override) the json parameters
    params['progress_bar'] = progress_bar
    params['plot'] = plot
    params['plot_every_epoch'] = plot_every_epoch
    params['training_filename'] = phsp_filename
    params['validation_filename'] = validation_dataset
    params['validation_every_epoch'] = validation_every_epoch
    start = datetime.datetime.now()
    params['start date'] = start.strftime(gaga.date_format)
    params['hostname'] = socket.gethostname()
    # the wasserstein options have float defaults (1e5, 1e2): cast to int
    params['dump_wasserstein_every'] = int(w_e)
    params['w_n'] = int(w_n)
    params['w_l'] = int(w_l)
    params['w_p'] = int(w_p)
    params['start_pth'] = start_pth

    # the epoch param in the json file may be overwritten by the option
    # ('is not None' so that an explicit value of 0 is not silently dropped)
    if epoch is not None:
        params['epoch'] = epoch

    # read input training dataset
    print(Fore.CYAN +"Loading training dataset ... "+phsp_filename+Style.RESET_ALL)
    x, read_keys, _ = phsp.load(phsp_filename)

    # consider only some keys: the --keys option takes precedence over the
    # 'keys' entry of the json file, and is honoured even when the json file
    # has no such entry; without either, all the keys of the file are learned
    if keys != '' or 'keys' in params:
        selected_keys = keys if keys != '' else params['keys']
        selected_keys = phsp.str_keys_to_array_keys(selected_keys)
        # 'angleXY' is added to the data from the X and Y columns on request
        if 'angleXY' in selected_keys:
            x, read_keys = phsp.add_angle(x, read_keys, 'X', 'Y')
        x = phsp.select_keys(x, read_keys, selected_keys)
    else:
        selected_keys = read_keys

    params['training_size'] = len(x)
    params['keys'] = selected_keys
    params['x_dim'] = len(selected_keys)

    # normalisation (per column); mean and std are kept in the parameters,
    # presumably so that generated samples can be de-normalised later
    # NOTE(review): a constant column gives x_std == 0, hence NaN/inf after
    # the division below -- confirm whether such inputs can occur
    x_mean = np.mean(x, 0, keepdims=True)
    x_std = np.std(x, 0, keepdims=True)
    params['x_mean'] = x_mean
    params['x_std'] = x_std
    x = (x-x_mean)/x_std

    # print parameters (entries whose name starts with '#' are not shown)
    for name, value in params.items():
        if not name.startswith('#'):
            print('   {:22s} {}'.format(name, str(value)))

    # train
    print(Fore.CYAN +'Building the GAN model ...'+Style.RESET_ALL)
    gan = gaga.Gan(params,x)
    print(Fore.CYAN +'Start training ...'+Style.RESET_ALL)
    optim = gan.train()

    # save: parameters, the value returned by the training, and a deep copy
    # of the state of both networks (G and D)
    stop = datetime.datetime.now()
    params['end date'] = stop.strftime(gaga.date_format)
    output = {
        'params': params,
        'optim': optim,
        'g_model_state': copy.deepcopy(gan.G.state_dict()),
        'd_model_state': copy.deepcopy(gan.D.state_dict()),
    }

    torch.save(output, output_filename)


# --------------------------------------------------------------------------
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    gaga_train()
