#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from shutil import copyfile
import click
import torch
from matplotlib import pyplot as plt
import gaga

# Click context settings: accept the short '-h' as well as '--help'.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('pth_filename', nargs=-1)
@click.option('--plot/--no-plot', default=False)
@click.option('--add_energy', default=float(-1),
              help='Add the key Ekine with the given value in the parameters of the pth file')
@click.option('--sfig/--no-sfig', default=False, help='special plot for figure')
def gaga_info(pth_filename, plot, add_energy, sfig):
    """
    \b
    Print information about a trained GAN phase space.
    If --plot option: plot the loss wrt epoch (figure also saved as a.pdf)
    If --add_energy option: store the value as key Ekine in the params of
    the pth file (one file only; the original is kept as <file>.save)

    \b
    <PTH_FILENAME> : input PTH file(s) (.pth)
    """

    # -1 is the "option not given" sentinel of --add_energy.
    energy_requested = add_energy != -1

    # Validate before doing any work: the check used to run only after all
    # files were loaded (and after the blocking plot window), and it ended
    # with exit(0), i.e. it reported success. UsageError exits non-zero.
    if energy_requested and len(pth_filename) != 1:
        raise click.UsageError('Cannot add_energy to several pth_filename')

    if plot:
        # A single axes shared by every input file.
        fig, ax = plt.subplots(1, 1, figsize=(16, 5))  # or 1,3

    params = None
    for f in pth_filename:
        # G, D and dtypef are returned by gaga.load but not needed here.
        params, G, D, optim, dtypef = gaga.load(f, fatal_on_unknown_keys=False)
        gaga.print_info(params, optim)
        if plot:
            if sfig:
                gaga.plot_epoch_fig(ax, params, optim, f)
            else:
                # FIXME carried over from the original code (reason not
                # documented there). gaga.plot_epoch and
                # gaga.plot_epoch_wasserstein were previously used here.
                gaga.plot_epoch2(ax, params, optim, f)

    if plot:
        plt.tight_layout()
        plt.savefig('a.pdf', dpi=fig.dpi)
        plt.show()

    if not energy_requested:
        return

    # Exactly one file at this point (checked above), so `params` holds the
    # parameters of that file.
    f = pth_filename[0]
    params['Ekine'] = add_energy
    if params['current_gpu']:
        checkpoint = torch.load(f)
    else:
        # Keep every storage where it was deserialized instead of moving it
        # to the device it was saved from.
        checkpoint = torch.load(f, map_location=lambda storage, loc: storage)
    checkpoint['params'] = params

    # Back up the original file before overwriting it in place.
    copyfile(f, f + '.save')
    torch.save(checkpoint, f)


# --------------------------------------------------------------------------
if __name__ == '__main__':
    # Called without arguments: gaga_info is wrapped by click.command, so the
    # call itself parses the command line and fills in the parameters.
    gaga_info()
