#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gatetools.phsp as phsp
import numpy as np
from matplotlib import pyplot as plt
import gaga

# Shared click context: accept the short -h flag in addition to --help.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


# Command-line entry point. The function docstring below is rendered by click
# as the --help text, so it is kept short and user-oriented; implementation
# notes live in the comments.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('phsp_filename')
@click.argument('pth_filename')
@click.option('--n', '-n', default=1e4, help='Number of samples to generate')
@click.option('--nb_bins', '-b', default=int(200), help='Number of bins')
@click.option('--toggle/--no-toggle', default=False, help='convert angle to XY')
# The histogram range is (quantile(q), quantile(1 - q)): q must lie in [0, 1]
# for np.quantile, and above 0.5 the lower bound would exceed the upper one.
# Rejecting such values here fails fast, before any file is loaded.
@click.option('--quantile', '-q', type=click.FloatRange(0.0, 0.5), default=float(0),
              help='Restrict histogram to quantile')
# The radius is a length in mm, so accept any real value (an integer default
# alone would make click refuse e.g. "--radius 350.5").
@click.option('--radius', type=float, default=350.0, help='When convert angle, need the radius (in mm)')
@click.option('--plot2d',
              type=(str, str),
              help='Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ', multiple=True)
@click.option('--epoch', '-e', default=-1, help='Load the G net at the given epoch (-1 for last stored epoch)')
@click.option('--no-title', is_flag=True, default=False)
@click.option('--output', '-o', type=str, help='Do not plot, only output a pdf with the given name')
def gaga_plot(phsp_filename, pth_filename, n, nb_bins,
              toggle, radius, quantile, plot2d, epoch,
              output, no_title):
    """
    \b
    Plot marginal distributions from a GAN-PHSP

    \b
    <PHSP_FILENAME>   : input phase space file PHSP file (.npy)
    <PTH_FILENAME>    : input GAN PTH file (.pth)
    """

    # nb of values (click parses this option as a float because of its 1e4 default)
    n = int(n)

    # load phsp (the third returned value is not used here)
    real, r_keys, _ = phsp.load(phsp_filename, nmax=n, shuffle=False)

    # load pth (only the parameters and the two networks are needed)
    params, G, D, _, _ = gaga.load(pth_filename, epoch=epoch)
    f_keys = params['keys']
    if isinstance(f_keys, str):
        # when 'keys' is stored as a single string, the list form is in 'keys_list'
        f_keys = params['keys_list']
    keys = f_keys.copy()

    # generate samples
    fake = gaga.generate_samples2(params, G, D, n, int(1e5), False, True)

    # Keep X,Y or convert to toggle
    if toggle:
        keys = phsp.keys_toggle_angle(keys)

    real, r_keys = phsp.add_missing_angle(real, r_keys, keys, radius)
    fake, f_keys = phsp.add_missing_angle(fake, f_keys, keys, radius)

    # bring both datasets to the same columns, in the same order
    real = phsp.select_keys(real, r_keys, keys)
    fake = phsp.select_keys(fake, f_keys, keys)

    # curate keys_2d: keep only the pairs whose two keys are both available
    keys_2d = [pair for pair in (plot2d or ())
               if pair[0] in keys and pair[1] in keys]

    # fig panel: one sub-figure per key, then one per 2D pair
    nb_fig = len(keys) + len(keys_2d)
    nrow, ncol = phsp.fig_get_nb_row_col(nb_fig)
    fig, ax = plt.subplots(nrow, ncol, figsize=(25, 10))

    # histogram range of every key, computed on the real data only so that
    # real and fake marginals are drawn over the same range
    q = {}
    for index, k in enumerate(keys):
        d = real[:, index]
        q[k] = (np.quantile(d, quantile), np.quantile(d, 1.0 - quantile))

    # plot all keys, real data first (green) then fake data (red);
    # the series label is only passed when the figure title is suppressed
    for data, color, name in ((real, 'g', 'PHSP '), (fake, 'r', 'GAN ')):
        lab = name if no_title else ''
        for i, k in enumerate(keys):
            gaga.fig_plot_marginal(data, k, keys, ax, i, nb_bins, color, q[k], lab)

    # plot 2D distributions in the panels that follow the 1D marginals,
    # again real first then fake; only done when more than one key is available
    if len(keys) > 1:
        for data, color in ((real, 'g'), (fake, 'r')):
            for i, kk in enumerate(keys_2d, start=len(keys)):
                gaga.fig_plot_marginal_2d(data, kk[0], kk[1], keys, ax, i, nb_bins, color)

    # remove empty plot
    phsp.fig_rm_empty_plot(nb_fig, ax)

    if not no_title:
        plt.suptitle(pth_filename)
    plt.tight_layout()
    # leave room at the top of the figure (where the suptitle is drawn)
    plt.subplots_adjust(top=0.9)

    # either write the figure to the requested file or show it interactively
    if output:
        plt.savefig(output)
    else:
        plt.show()
    plt.close()


# --------------------------------------------------------------------------
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    gaga_plot()
