#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gaga
import gatetools.phsp as phsp
from torch.autograd import Variable
import torch
import numpy as np

# Context settings handed to click.command below: register '-h' as a help
# flag alongside '--help'.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('phsp_filename')
@click.argument('pth_filename')
@click.option('--n', '-n', default=1e4, help='Number of samples to generate')
@click.option('--l', '-l', default=1e2, help='Number of projections')
@click.option('--p', '-p', default=1, help='Wasserstein distance power p=1 default')
@click.option('--keys', '-k', help='Use the given keys (as a str list such as "X Y Z")', default='')
@click.option('--toggle/--no-toggle', default=False, help='Convert XY to angle')
@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)')
@click.option('--normalize/--no-normalize', default=True,
              help='Normalize dimensions')
def gaga_wasserstein(phsp_filename, pth_filename, n, l, p, keys, toggle, radius, normalize):
    '''
    \b
    Compute sliced Wasserstein between real and GAN generated distributions

    \b
    <PHSP_FILENAME>   : input phase space file PHSP file (.npy)
    <PTH_FILENAME>    : input GAN PTH file (.pth)

    \b
    Prints: PTH filename, distance, keys, number of samples, number of projections
    '''
    # NOTE: this docstring doubles as the click --help text; the '\b' markers
    # keep click from re-wrapping the paragraph that follows them.
    # NOTE(review): 'toggle' is accepted from the command line but is not used
    # anywhere in this function.

    # The float defaults (1e4, 1e2) make click parse both options as floats,
    # while both are counts: convert them to integers before use.
    n = int(n)
    l = int(l)

    # Load n samples of the reference phase space and its column keys
    # (the third returned value is not used here).
    real, r_keys, m = phsp.load(phsp_filename, n)

    # Turn the user-provided key string into a list of keys.
    keys = phsp.str_keys_to_array_keys(keys)

    # Load the trained GAN; fall back to the keys stored with it when the
    # user did not request specific keys.
    params, G, D, optim, dtypef = gaga.load(pth_filename)
    f_keys = params['keys']
    if len(keys) == 0:
        keys = f_keys

    # Generate n fake samples (do NOT normalize yet).
    # Positional flags are presumably normalize=False, to_numpy=True, with
    # 1e5 as batch size -- TODO confirm against gaga.generate_samples2.
    fake = gaga.generate_samples2(params, G, D, n, 1e5, False, True)

    # Restrict both data sets to the same keys, in the same order.
    real = phsp.select_keys(real, r_keys, keys)
    fake = phsp.select_keys(fake, f_keys, keys)

    # Normalize both sets with the mean/std stored in the GAN parameters
    # so that they are compared on the same scale.
    if normalize:
        x_mean = params['x_mean']
        x_std = params['x_std']
        # add_missing_angle returns the (possibly extended) values and the
        # matching key list, which is then used to select the wanted keys.
        x_mean, mean_keys = phsp.add_missing_angle(x_mean, f_keys, keys, radius)
        x_std, std_keys = phsp.add_missing_angle(x_std, f_keys, keys, radius)
        # The original code selected both with the key list returned by the
        # x_std call; kept as is to preserve behavior.
        x_mean = phsp.select_keys(x_mean, std_keys, keys)
        x_std = phsp.select_keys(x_std, std_keys, keys)
        real = (real - x_mean) / x_std
        fake = (fake - x_mean) / x_std

    # Convert to pytorch tensors of the type returned by gaga.load.
    fake = Variable(torch.from_numpy(fake)).type(dtypef)
    real = Variable(torch.from_numpy(real)).type(dtypef)

    # Sliced Wasserstein distance with l projections and power p.
    d = gaga.sliced_wasserstein(real, fake, l, p)
    print(pth_filename, d, keys, n, l)
    

# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Invoke the click command only when this file is executed as a script;
    # importing the module must not trigger it.
    gaga_wasserstein()

