#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gaga
import copy
import torch
from shutil import copyfile
from matplotlib import pyplot as plt

# Click context settings: accept '-h' in addition to the default '--help'.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('pth_filename', nargs=-1)
@click.option('--plot/--no-plot', default=False)
@click.option('--add_energy', default=float(-1),
              help='Add the key Ekine with the given value in the parameters of the pth file')
def gaga_info(pth_filename, plot, add_energy):
    """
    \b
    Print information about a trained GAN phase space.
    If --plot option: plot the loss wrt epoch

    \b
    <PTH_FILENAME> : input PTH file (.pth)
    """
    # NOTE: the docstring above is rendered by click as the --help text; the
    # '\b' lines are literal backspace characters (click's "do not rewrap"
    # marker), so the docstring must stay a non-raw string.
    #
    # Parameters (filled in by click):
    #   pth_filename: tuple of paths to .pth files (may be empty, nargs=-1).
    #   plot:         when True, plot every file on a shared figure, save it
    #                 to 'a.pdf' and display it.
    #   add_energy:   value stored under params['Ekine'] in the pth file;
    #                 -1 (the default) is the sentinel for "do not modify".
    #
    # Raises click.UsageError (non-zero exit status) when --add_energy is
    # given without exactly one pth file.

    update_energy = add_energy != -1

    # Validate the invocation before any loading/plotting so that a misuse is
    # reported immediately and with a failing exit status.
    if update_energy and len(pth_filename) != 1:
        raise click.UsageError(
            '--add_energy requires exactly one pth_filename '
            '(got {})'.format(len(pth_filename)))

    if plot:
        # One row of three axes shared by all the input files.
        fig, ax = plt.subplots(1, 3, figsize=(16, 5))

    for f in pth_filename:
        # Only the parameters and the optimisation history are needed here;
        # the other values returned by gaga.load are ignored.
        params, _, _, optim, _ = gaga.load(f)
        gaga.print_info(params, optim)
        if plot:
            gaga.plot_epoch(ax, optim, f)
            gaga.plot_epoch_wasserstein(ax, optim, f)

    if plot:
        plt.tight_layout()
        plt.savefig('a.pdf', dpi=fig.dpi)
        plt.show()

    if update_energy:
        # Exactly one file was given (checked above), so 'params' holds the
        # parameters loaded from that file.
        f = pth_filename[0]
        params['Ekine'] = add_energy
        if params['current_gpu']:
            nn = torch.load(f)
        else:
            # Keep every storage where it was deserialized (i.e. on CPU).
            nn = torch.load(f, map_location=lambda storage, loc: storage)
        nn['params'] = params
        # Keep a backup of the original file before overwriting it.
        copyfile(f, f + '.save')
        torch.save(nn, f)


# --------------------------------------------------------------------------
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    gaga_info()
