#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gaga
import torch
import os
import json
from torch.autograd import Variable

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('pth_filename')
@click.option('--output', '-o', default='auto', help="If 'auto', use filename.pt and filename.json")
@click.option('--gpu/--no-gpu', default=False)
@click.option('--verbose/--no-verbose', '-v', default=False)
@click.option('--keys', '-k', type=(str, float), multiple=True,
              help="Add or override a parameter in the output json: <name> <value> (may be repeated)")
def gaga_convert_pth_to_pt(pth_filename, output, gpu, verbose, keys):
    '''
    \b
    Convert a .pth file from gaga to a .pt and a .json files (used by libtorch in Gate)

    \b
    <PTH_FILENAME> : input GAN PTH file (.pth)
    '''
    # NOTE: the docstring above is the text click prints for --help; the
    # '\b' markers tell click not to re-wrap the following paragraph.
    #
    # Parameters (all supplied by click):
    #   pth_filename : path of the input .pth file.
    #   output       : base name of the outputs ('.pt' and '.json' are
    #                  appended); 'auto' reuses the input name without its
    #                  extension.
    #   gpu          : boolean flag, forwarded to gaga.load as 'true'/'false'.
    #   verbose      : boolean flag, forwarded to gaga.load and enabling the
    #                  'Writing' messages.
    #   keys         : tuple of (name, value) pairs written into the json
    #                  parameters, replacing any existing entry of that name.

    # FIXME Harmonize with create_pt for GARF

    # output filenames
    if output == 'auto':
        base, _ = os.path.splitext(pth_filename)
    else:
        base = output
    output_pt = base + '.pt'
    output_json = base + '.json'

    # gaga.load receives the gpu mode as a string, not as a bool
    gpu_mode = 'true' if gpu else 'false'

    # load pth
    params, G, D, optim, dtypef = gaga.load(pth_filename, gpu_mode, verbose)

    # Example input for tracing: a (batch_size, z_dim) tensor of random
    # samples, cast to the tensor type returned by gaga.load.
    batch_size = 10
    z_dim = params['z_dim']
    z = Variable(torch.randn(batch_size, z_dim)).type(dtypef)

    # generate trace of the generator and save it
    traced_script_module = torch.jit.trace(G, z)
    traced_script_module.save(output_pt)
    if verbose:
        print('Writing ', output_pt)

    # Save the parameters dict into json. The entries below are converted to
    # plain lists / ints first (presumably because they are stored as array
    # or numpy scalar types that json cannot serialize -- TODO confirm).
    params["x_mean"] = params["x_mean"][0].tolist()
    params["x_std"] = params["x_std"][0].tolist()
    params['d_nb_weights'] = int(params['d_nb_weights'])
    params['g_nb_weights'] = int(params['g_nb_weights'])

    # user-supplied (name, value) pairs are added to / override the parameters
    for name, value in keys:
        params[name] = value

    if verbose:
        print('Writing ', output_json)
    # the context manager guarantees the file is flushed and closed, even if
    # json.dump raises
    with open(output_json, 'w') as outfile:
        json.dump(params, outfile)


# --------------------------------------------------------------------------
# Script entry point: invoke the click command only when this file is run
# directly (not when it is imported as a module).
if __name__ == '__main__':
    gaga_convert_pth_to_pt()

