#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gaga
import phsp
import torch
import os
import numpy as np
from torch.autograd import Variable
from matplotlib import pyplot as plt

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


def _clamp_negative_energy(fake, keys):
    '''
    Replace, in place, the negative values of the energy column of `fake`
    by a tiny positive value.

    The energy column is located by name in `keys`: 'E' is preferred and
    'Ekine' is accepted for retro-compatibility.

    Raises ValueError when `keys` contains neither name.
    '''
    for name in ('E', 'Ekine'):
        if name in keys:
            # Like the original code, this relies on fake[:, i] being a view
            # of `fake` (assumes a numpy array -- TODO confirm), so that the
            # masked assignment writes through to `fake` itself.
            energy = fake[:, keys.index(name)]
            energy[energy < 0] = 1e-11
            return
    raise ValueError(
        "no energy key ('E' or 'Ekine') found in {}".format(list(keys)))


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('pth_filename')
@click.option('--n', '-n', default=1e4, help='Number of samples to generate')
@click.option('--output', '-o', default='AUTO', help='If AUTO, use pth_filename.npy')
@click.option('--toggle/--no-toggle', default=False, help='Convert XY to angle')
@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)')
def gaga_plot(pth_filename, n, output, toggle, radius):
    '''
    \b
    Generate a PHSP from a GAN

    \b
    <PTH_FILENAME>    : input GAN PTH file (.pth)
    '''

    # The option default is the float 1e4 (so that '1e5' style values can be
    # given on the command line); the number of samples must be an integer.
    n = int(n)

    # load pth (only the parameters and the generator are used here)
    params, G, _, _ = gaga.load(pth_filename)
    f_keys = list(params['keys'])

    # generate samples, by batch
    # NOTE(review): the two trailing positional flags are passed unchanged
    # from the original call; presumably normalize=False, to_numpy=True --
    # confirm against gaga.generate_samples2.
    batch_size = 1e5
    fake = gaga.generate_samples2(params, G, n, batch_size, False, True)

    # Keep X,Y or convert them to angles
    if toggle:
        keys = phsp.keys_toggle_angle(f_keys)
        fake, f_keys = phsp.add_missing_angle(fake, f_keys, keys, radius)
        fake = phsp.select_keys(fake, f_keys, keys)
    else:
        keys = f_keys

    # special case (for retro-compatibility of the energy key name):
    # the energy must not be negative
    _clamp_negative_energy(fake, keys)

    # write
    if output == 'AUTO':
        # same path as the input file, with the extension replaced by .npy
        output = os.path.splitext(pth_filename)[0] + '.npy'
    phsp.save_npy(output, fake, keys)

# --------------------------------------------------------------------------
# Run the click command only when this file is executed as a script
# (not when it is imported as a module).
if __name__ == '__main__':
    gaga_plot()

