#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gatetools.phsp as phsp
import numpy as np
from matplotlib import pyplot as plt
import gaga
from scipy.stats import gaussian_kde

# click context settings: accept '-h' in addition to '--help'
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


def _scatter_with_density(axis, x, y, bound):
    """
    Draw a scatter plot of 2D points colored by their estimated density.

    The density is a Gaussian kernel density estimate fitted on the points
    themselves and then evaluated at those same points.

    :param axis: matplotlib axes to draw on
    :param x: abscissas of the points
    :param y: ordinates of the points (same length as x)
    :param bound: both axes are limited to [-bound, bound]
    """
    points = np.vstack([x, y])
    density = gaussian_kde(points)(points)
    axis.scatter(x, y, c=density, marker='.')
    # equal aspect is requested first, then the fixed view limits are applied
    axis.axis('equal')
    axis.set_xlim([-bound, bound])
    axis.set_ylim([-bound, bound])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('phsp_filename')
@click.argument('pth_filename')
@click.option('--n', '-n', default=1e4, help='Number of samples to generate')
@click.option('--epoch', '-e', default=-1, help='Load the G net at the given epoch (-1 for last stored epoch)')
def gaga_gauss_plot(phsp_filename, pth_filename, n, epoch):
    """
    \b
    Plot 2D mixture of Gaussian ref in phsp, gan in pth

    \b
    <PHSP_FILENAME>   : input reference phase space file PHSP file (.npy)
    <PTH_FILENAME>    : input GAN PTH file (.pth)
    """
    # NOTE: the docstring above is the help text click prints for -h/--help,
    # so it is kept verbatim.

    # The option default is the float 1e4, so the value arrives as a float
    # (which also lets users type e.g. "-n 1e5"); convert it to an int count.
    n = int(n)

    # Reference samples: at most n entries, shuffled. The keys and the third
    # returned value are not needed here.
    real, _r_keys, _m = phsp.load(phsp_filename, nmax=n, shuffle=True)

    # GAN: only the parameters and the G/D nets are needed to generate samples.
    params, G, D, _optim, _dtypef = gaga.load(pth_filename, epoch=epoch)

    # Generated samples.
    # NOTE(review): the trailing positional arguments are presumably the batch
    # size, a "normalize" flag and a "to numpy" flag -- confirm against gaga.
    fake = gaga.generate_samples2(params, G, D, n, int(1e5), False, True)

    # Only the first two columns of each array are plotted.
    # NOTE(review): this presumes the phase space file and the GAN share the
    # same column order for those two dimensions -- confirm with the data.
    fig, (left, right) = plt.subplots(1, 2, figsize=(10, 5))

    # Half-width of the square view shown in both panels.
    bnd = 10

    # Left panel: reference, right panel: GAN.
    _scatter_with_density(left, real[:, 0], real[:, 1], bnd)
    _scatter_with_density(right, fake[:, 0], fake[:, 1], bnd)

    # Title the GAN panel explicitly instead of relying on pyplot's
    # "current axes" (which is the last subplot created, i.e. this one).
    right.set_title(pth_filename)
    plt.show()


# --------------------------------------------------------------------------
# Script entry point: click parses the command line and runs the command.
if __name__ == "__main__":
    gaga_gauss_plot()
