#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gaga
import gatetools.phsp as phsp
import torch
import numpy as np
from torch.autograd import Variable
from matplotlib import pyplot as plt

# Context settings handed to click.command: register both -h and --help
# as names of the help option.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('phsp_filename')
@click.argument('pth_filename')
@click.option('--n', '-n', default=1e4, help='Number of samples to generate')
@click.option('--nb_bins', '-b', default=int(200), help='Number of bins')
@click.option('--toggle/--no-toggle', default=False, help='convert angle to XY')
@click.option('--quantile', '-q', default=float(0), help='Restrict histogram to quantile')
@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)')
@click.option('--plot2d',
              type=(str, str),
              help='Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ', multiple=True)
def gaga_plot(phsp_filename, pth_filename, n, nb_bins,
              toggle, radius, quantile, plot2d):
    '''
    \b
    Plot marginal distributions from a GAN-PHSP

    \b
    <PHSP_FILENAME>   : input phase space file PHSP file (.npy)
    <PTH_FILENAME>    : input GAN PTH file (.pth)
    '''
    # NOTE: the docstring above is what click prints for --help. The "\b"
    # markers are backspace characters (so the string must stay non-raw);
    # click uses them to keep the following paragraph from being re-wrapped.
    #
    # Parameters (all supplied by click):
    #   phsp_filename : path passed to phsp.load (reference samples).
    #   pth_filename  : path passed to gaga.load; also used as figure title.
    #   n             : number of samples loaded and generated. The 1e4
    #                   default makes click parse it as a float (so "1e5"
    #                   is accepted); it is cast to int below.
    #   nb_bins       : forwarded to the gaga plotting helpers.
    #   toggle        : when set, keys go through phsp.keys_toggle_angle.
    #   radius        : forwarded to phsp.add_missing_angle.
    #   quantile      : q in [0, 1]; every 1D panel is given the range
    #                   (quantile(q), quantile(1 - q)) of the reference data.
    #   plot2d        : tuple of (key1, key2) pairs, one per --plot2d.

    # Fail fast on a bad quantile instead of letting np.quantile raise after
    # the data has been loaded and the samples generated.
    if not 0.0 <= quantile <= 1.0:
        raise click.BadParameter('must be in the range [0, 1]',
                                 param_hint='--quantile')

    # nb of values
    n = int(n)

    # load phsp (third returned value is not needed here)
    real, r_keys, _ = phsp.load(phsp_filename, n)

    # load pth (only the parameters and the generator are used)
    params, G, _, _, _ = gaga.load(pth_filename)
    f_keys = params['keys']
    keys = f_keys.copy()

    # generate samples
    # NOTE(review): the meaning of the trailing arguments (int(1e5), False,
    # True) is defined by gaga.generate_samples2 -- presumably a batch size
    # and two flags; confirm against the gaga package.
    fake = gaga.generate_samples2(params, G, n, int(1e5), False, True)

    # Keep X,Y or convert to toggle
    if toggle:
        keys = phsp.keys_toggle_angle(keys)

    real, r_keys = phsp.add_missing_angle(real, r_keys, keys, radius)
    fake, f_keys = phsp.add_missing_angle(fake, f_keys, keys, radius)

    # From here on, column j of both arrays corresponds to keys[j].
    real = phsp.select_keys(real, r_keys, keys)
    fake = phsp.select_keys(fake, f_keys, keys)

    # curate keys_2d: keep only the pairs whose two keys are available, and
    # tell the user about the ones that are dropped rather than silently
    # ignoring them.
    keys_2d = []
    for pair in (plot2d or ()):
        if pair[0] in keys and pair[1] in keys:
            keys_2d.append(pair)
        else:
            click.echo('Warning: ignoring --plot2d {} {}: unknown key '
                       '(available keys: {})'.format(pair[0], pair[1], keys),
                       err=True)

    # fig panel
    nb_fig = len(keys) + len(keys_2d)
    nrow, ncol = phsp.fig_get_nb_row_col(nb_fig)
    _, ax = plt.subplots(nrow, ncol, figsize=(25, 10))

    # Histogram range of every key, always computed from the real data so
    # that real and fake marginals are drawn over the same interval.
    ranges = [(np.quantile(real[:, j], quantile),
               np.quantile(real[:, j], 1.0 - quantile))
              for j in range(len(keys))]

    # plot all keys: real data first (green), then fake data (red)
    for data, color in ((real, 'g'), (fake, 'r')):
        for i, k in enumerate(keys):
            gaga.fig_plot_marginal(data, k, keys, ax, i, nb_bins, color,
                                   ranges[i])

    # plot 2D distributions in the panels that follow the 1D marginals,
    # again real first and fake second on the same panel
    if len(keys) > 1:
        first_2d = len(keys)
        for data, color in ((real, 'g'), (fake, 'r')):
            for offset, (key_x, key_y) in enumerate(keys_2d):
                gaga.fig_plot_marginal_2d(data, key_x, key_y, keys, ax,
                                          first_2d + offset, nb_bins, color)

    # remove empty plot
    phsp.fig_rm_empty_plot(nb_fig, ax)

    plt.suptitle(pth_filename)
    plt.tight_layout()
    # leave room above the panels for the suptitle
    plt.subplots_adjust(top=0.9)
    plt.show()

    #output_filename = 'aa.png'
    #plt.savefig(output_filename)
    #plt.close()

# --------------------------------------------------------------------------
# Run the click command only when the file is executed as a script;
# click fills in the arguments from the command line.
if __name__ == "__main__":
    gaga_plot()
