#!python
# -*- coding: utf-8 -*-

import torch
import click
from torch.autograd import Variable
import json
import garf
import os

# -----------------------------------------------------------------------------
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('pth_filename', nargs=1)
@click.option('--output', '-o', default='auto', help="If 'auto', use filename.pt and filename.json")
@click.option('--gpu/--no-gpu', default=False)
@click.option('--verbose/--no-verbose', '-v', default=False)
def garf_convert_pth_to_pt(pth_filename, output, gpu, verbose):
    '''
    \b
    Convert a .pth file from garf to a .pt and a .json files (used by libtorch in Gate)

    \b
    <PTH_FILENAME> : input GARF PTH file (.pth)
    '''
    # NOTE: the docstring above is rendered by click as the --help text, so it
    # is kept short; implementation notes live in comments instead.
    #
    # Parameters (all supplied by click):
    #   pth_filename : path of the GARF .pth file to convert.
    #   output       : basename of the two output files; 'auto' reuses the
    #                  input path with its extension stripped.
    #   gpu          : load the model on the GPU and trace it with CUDA tensors.
    #   verbose      : print progress information to stdout.
    #
    # Side effects: writes <basename>.pt (traced model) and <basename>.json
    # (the 'model_data' dict) to disk.
    # Raises: KeyError if 'model_data', 'x_mean', 'x_std' or 'RR' is missing
    # from the loaded data; OSError if an output file cannot be written.

    # Both output files share a single basename.
    if output == 'auto':
        base_name, _ = os.path.splitext(pth_filename)
    else:
        base_name = output
    output_pt = base_name + '.pt'
    output_json = base_name + '.json'

    nn, model = garf.load_nn(pth_filename, gpu=gpu, verbose=False)
    params = nn['model_data']

    x_mean = params['x_mean']
    rr = params['RR']
    if verbose:
        print('Found Russian Roulette value: ', rr)

    dtypef = torch.FloatTensor
    if gpu:
        dtypef = torch.cuda.FloatTensor
        model.cuda()

    # Example input handed to the tracer: a small random batch whose width
    # equals the number of entries in x_mean. The values themselves are
    # arbitrary, only the tensor type and shape are chosen here.
    batch_size = 10
    dim = len(x_mean)
    z = torch.randn(batch_size, dim).type(dtypef)

    # generate trace
    traced_script_module = torch.jit.trace(model, z)
    traced_script_module.save(output_pt)
    if verbose:
        print('Write model to', output_pt)

    # Make the dict JSON-friendly: x_mean / x_std are converted with .tolist()
    # and 'loss_values' is replaced by an empty string before dumping
    # (presumably because it is not needed by the consumer -- TODO confirm).
    params["x_mean"] = params["x_mean"].tolist()
    params["x_std"] = params["x_std"].tolist()
    params["loss_values"] = ''
    if verbose:
        print(params["x_mean"])

    # The context manager guarantees the JSON is flushed and the handle is
    # closed, even if serialisation fails part-way through.
    with open(output_json, 'w') as outfile:
        if verbose:
            print('Write info  to', output_json)
        json.dump(params, outfile)

# -----------------------------------------------------------------------------
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if "__main__" == __name__:
    garf_convert_pth_to_pt()
