#!/usr/bin/env python

import argparse
import os
import sys

import h5py
import numpy as np
from scipy.stats import gaussian_kde
from scipy.stats import kde

# Small tolerance shared by the script: used as the strict lower/upper margin
# when selecting points to scatter, as padding of the upper y-limit of the
# plot, and as the tolerance when matching the requested temperature.
epsilon = 1.0e-8

def collect_data(gamma, weights, frequencies, t_index, cutoff, max_freq):
    """Flatten mode frequencies, lifetimes and weights into parallel arrays.

    ``weights``, ``frequencies`` and ``gamma`` are iterated in lockstep.
    For every entry, each gamma value ``g`` is converted to a lifetime
    ``1 / g / (2 * 2 * pi)``; non-positive ``g`` yields a lifetime of 0.

    Parameters
    ----------
    gamma : sequence of array
        Gamma values, one array per entry of ``frequencies``.
    weights : sequence
        One weight per entry; it is repeated for every mode kept from
        that entry.
    frequencies : sequence of array
        Mode frequencies, one array per entry.
    t_index :
        Not used by this function.
    cutoff : float or None
        When truthy, lifetimes that are not smaller than this value are
        replaced by 0.
    max_freq : float or None
        When not None, only modes with frequency strictly below this
        value are kept.

    Returns
    -------
    x, y, z : ndarray
        Frequencies, lifetimes and weights of the collected modes.
    """
    tau_divisor = 2 * 2 * np.pi
    all_freqs = []
    all_lifetimes = []
    all_weights = []

    for weight, mode_freqs, mode_gammas in zip(weights, frequencies, gamma):
        # Replace non-positive gammas by -1 so the division is safe; the
        # resulting negative lifetimes are zeroed on the next line.
        safe_gammas = np.where(mode_gammas > 0, mode_gammas, -1)
        lifetimes = 1.0 / safe_gammas / tau_divisor
        lifetimes = np.where(lifetimes > 0, lifetimes, 0)
        if cutoff:
            lifetimes = np.where(lifetimes < cutoff, lifetimes, 0)

        if max_freq is not None:
            below_max = mode_freqs < max_freq
            mode_freqs = np.extract(below_max, mode_freqs)
            lifetimes = np.extract(below_max, lifetimes)

        all_freqs.extend(mode_freqs)
        all_lifetimes.extend(lifetimes)
        all_weights.extend([weight] * len(mode_freqs))

    return np.array(all_freqs), np.array(all_lifetimes), np.array(all_weights)

def run_KDE(x, y, z, nbins):
    """Estimate the 2-D density of the points (x, y) with a Gaussian KDE.

    The density is first evaluated on an ``nbins`` x ``nbins`` grid that
    spans the data range.  The highest y grid row whose peak density
    exceeds one tenth of the global peak is located; ``indices`` lists that
    row and every row below it, in descending order.  The density is then
    re-evaluated on a grid with ``nbins ** 2 // len(indices)`` points along
    y, so the rows selected by ``indices`` are resolved more finely.

    http://stackoverflow.com/questions/19390320/scatterplot-contours-in-matplotlib

    Parameters
    ----------
    x, y : ndarray
        Coordinates of the data points.
    z :
        Not used by this function.
    nbins : int
        Number of grid points along x (and along y for the first pass).

    Returns
    -------
    xi, yi, zi : ndarray
        Grid coordinates and density values of the second, refined grid.
    indices : list of int
        Descending y-row indices of the first grid, from the highest row
        with significant density down to 0.
    """
    kernel = gaussian_kde([x, y])
    x_lo, x_hi = x.min(), x.max()
    y_lo, y_hi = y.min(), y.max()

    def evaluate(num_y):
        # Evaluate the kernel on an nbins x num_y grid over the data range.
        gx, gy = np.mgrid[x_lo:x_hi:nbins * 1j, y_lo:y_hi:num_y * 1j]
        density = kernel(np.vstack([gx.ravel(), gy.ravel()]))
        return gx, gy, density.reshape(gx.shape)

    xi, yi, zi = evaluate(nbins)

    # Peak density of each y row (axis 0 of zi runs along x).
    row_peaks = zi.max(axis=0)
    significant_rows = np.flatnonzero(row_peaks > zi.max() / 10)
    if significant_rows.size:
        top_row = int(significant_rows[-1])
    else:
        # No row passes the threshold (e.g. the density is NaN); keep the
        # whole grid rather than dividing by zero below.
        top_row = nbins - 1
    indices = list(range(top_row, -1, -1))

    ynbins = nbins ** 2 // len(indices)
    xi, yi, zi = evaluate(ynbins)

    return xi, yi, zi, indices

def plot(plt, xi, yi, zi, indices, nbins, xmax, ymax):
    """Draw the density map and scatter the data points on top of it.

    Besides its arguments, this function reads the module-level names
    ``x``, ``y`` and ``z`` (coordinates and colour values of the points to
    scatter), ``title`` and ``epsilon``.

    Parameters
    ----------
    plt : module
        The ``matplotlib.pyplot`` module used for all drawing calls.
    xi, yi, zi : ndarray
        Grid coordinates and density values; only the first ``nbins``
        columns are drawn.
    indices : list
        Only its length is used, to scale the upper y bound of the
        scattered points.
    nbins : int
        Number of grid columns to draw.
    xmax, ymax : float
        Upper bounds used for the x axis limit and for selecting points.

    Returns
    -------
    fig
        The figure created by ``plt.figure()``.
    """
    # Upper y bound for scattered points:
    # ymax * (len(indices) / nbins) * ((nbins - 1) / nbins).
    y_bound = ymax / nbins * len(indices)
    y_bound = y_bound / nbins * (nbins - 1)
    x_bound = xmax - epsilon

    # Keep only points strictly inside (epsilon, bound) on both axes.
    selected = [(xv, yv, zv) for xv, yv, zv in zip(x, y, z)
                if epsilon < yv < y_bound and epsilon < xv < x_bound]
    x_cut = [point[0] for point in selected]
    y_cut = [point[1] for point in selected]
    z_cut = [point[2] for point in selected]

    fig = plt.figure()
    shown = slice(None, nbins)
    plt.pcolormesh(xi[:, shown], yi[:, shown], zi[:, shown])
    plt.colorbar()

    plt.scatter(x_cut, y_cut, c=z_cut, s=0.1)

    plt.xlim(xmin=0, xmax=xmax)
    y_top = np.max(y_cut) + epsilon
    plt.ylim(ymin=0, ymax=y_top)
    if title:
        plt.title(title, fontsize=20)
    plt.xlabel('Phonon frequency (THz)', fontsize=18)
    plt.ylabel('Lifetime (ps)', fontsize=18)

    return fig

# Command-line interface: optional plot settings followed by input files.
parser = argparse.ArgumentParser(
    description="Plot property density with gaussian KDE")
parser.add_argument(
    "--fmax", type=float, default=None,
    help="Max frequency to plot")
parser.add_argument(
    "--cutoff", type=float, default=None,
    help=("Property (y-axis) below this value is included in "
          "data before running Gaussian-KDE"))
parser.add_argument(
    "--nbins", type=int, default=100,
    help=("Number of bins in which data are assigned, "
          "i.e., determining resolution of plot"))
parser.add_argument(
    "--temperature", type=float, default=300.0,
    help="Temperature to output data at")
parser.add_argument("--title", default=None, help="Plot title")
parser.add_argument("filenames", nargs="*")
args = parser.parse_args()

#
# Matplotlib setting
#
import matplotlib
# The backend has to be selected before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rc
# Text is rendered through LaTeX with a serif font.
rc('text', usetex=True)
rc('font', family='serif', serif='Liberation Serif')
# Alternatives that are left disabled:
#   rc('font', serif='Times New Roman')
#   plt.rcParams['pdf.fonttype'] = 42

#
# Initial setting
#
# 'filenames' is declared with nargs='*', so it may be empty; fail with a
# message instead of an IndexError in that case.
if not args.filenames:
    print("No input file is specified.")
    sys.exit(1)

filename = args.filenames[0]
if not os.path.isfile(filename):
    print("File %s doesn't exist." % filename)
    sys.exit(1)

# The file is only read from, so open it read-only explicitly rather than
# relying on the h5py default mode.
f = h5py.File(filename, 'r')

# An empty title is treated the same as no title.
title = args.title if args.title else None

#
# Set temperature
#
# Use the index of the first stored temperature that matches the requested
# one within epsilon.  When none matches, fall back to index 30 if that
# index exists, and to index 0 otherwise.
temperatures = f['temperature'][:]
for i, t in enumerate(temperatures):
    if np.abs(t - args.temperature) < epsilon:
        t_index = i
        break
else:
    # Index 30 is valid only when there are at least 31 temperatures.
    if len(temperatures) > 30:
        t_index = 30
    else:
        t_index = 0

#
# Set data
#
weights = f['weight'][:]
frequencies = f['frequency'][:]

# The 'gamma' dataset is always plotted; 'gamma_N' and 'gamma_U' are added
# only when the file contains them.  Each symbol becomes the suffix of the
# corresponding output file name.
gammas = [f['gamma'][t_index]]
symbols = ['']
for dataset_name, symbol in (('gamma_N', 'N'), ('gamma_U', 'U')):
    if dataset_name in f:
        gammas.append(f[dataset_name][t_index])
        symbols.append(symbol)

#
# Run
#
for gamma_at_t, suffix in zip(gammas, symbols):
    # x, y and z must keep these module-level names: plot() reads them as
    # globals rather than receiving them as arguments.
    x, y, z = collect_data(gamma_at_t, weights, frequencies,
                           t_index, args.cutoff, args.fmax)
    xi, yi, zi, indices = run_KDE(x, y, z, args.nbins)
    xmax = np.max(x)
    ymax = np.max(y)
    fig = plot(plt, xi, yi, zi, indices, args.nbins, xmax, ymax)
    # One image per gamma dataset; the plain 'gamma' dataset has no suffix.
    output_name = ("lifetime-%s.png" % suffix) if suffix else "lifetime.png"
    fig.savefig(output_name)
    plt.close(fig)
