#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gaga
import gatetools.phsp as phsp
import os

# Shared click context: accept both -h and --help for the help screen.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


def _auto_output_path(pth_filename, params, n_label, output_folder):
    """
    Build the output path used when --output is AUTO.

    The file name is
    '<pth basename without extension>_<penalty_type>_<penalty_weight>_<n_label>.npy'.
    It is placed in output_folder, or in '.' when output_folder is None or empty.
    """
    basename = os.path.splitext(os.path.basename(pth_filename))[0]
    filename = f"{basename}_{params['penalty_type']}_{params['penalty_weight']}_{n_label}.npy"
    return os.path.join(output_folder or '.', filename)


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('pth_filename')
@click.option('--n', '-n', default='1e4', help='Number of samples to generate')
@click.option('--output', '-o', default='AUTO',
              help='Output .npy file. If AUTO, name it '
                   '<pth_basename>_<penalty_type>_<penalty_weight>_<n>.npy')
@click.option('--output_folder', '-f', default=None,
              help='Output folder (only used when --output is AUTO)')
@click.option('--toggle/--no-toggle', default=False, help='Convert XY to angle')
@click.option('--epoch', default=-1, help='Use G at this epoch')
@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)')
def gaga_generate(pth_filename, n, output, output_folder, toggle, radius, epoch):
    """
    Generate a PHSP from a (trained) GAN

    Draws N samples from the generator and writes them to a .npy file.

    \b
    <PTH_FILENAME>    : input GAN PTH file (.pth)
    """

    # Keep the user's spelling of n (e.g. '1e4') for the AUTO file name,
    # but accept float notation for the actual sample count.
    n_label = str(n)
    try:
        n = int(float(n))
    except (ValueError, OverflowError):
        # Report a click usage error instead of a raw traceback.
        raise click.BadParameter(f'expected a number, got {n_label!r}',
                                 param_hint="'--n'") from None
    if n < 1:
        raise click.BadParameter(f'must be at least 1, got {n_label!r}',
                                 param_hint="'--n'")

    # load pth (optimizer state and dtype are not needed here)
    params, G, D, _optim, _dtypef = gaga.load(pth_filename, 'auto', verbose=False, epoch=epoch)
    f_keys = list(params['keys_list'])

    # Generate the samples in batches. The batch size is a sample count,
    # so keep it an int.
    # NOTE(review): the two trailing positional flags are presumably
    # normalize=False, to_numpy=True; confirm against gaga.generate_samples2.
    batch_size = 100_000
    fake = gaga.generate_samples2(params, G, D, n, batch_size, False, True)

    # --toggle: derive the angle keys (this needs --radius) and presumably
    # keep only those columns; otherwise keep the generated keys as-is.
    if toggle:
        keys = phsp.keys_toggle_angle(f_keys)
        fake, f_keys = phsp.add_missing_angle(fake, f_keys, keys, radius)
        fake = phsp.select_keys(fake, f_keys, keys)
    else:
        keys = f_keys

    # write
    if output == 'AUTO':
        output = _auto_output_path(pth_filename, params, n_label, output_folder)
        print(output)
    phsp.save_npy(output, fake, keys)


# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Run as a script: click parses sys.argv and dispatches to the command.
    gaga_generate()
