#!/usr/bin/env python
"""
mbank_validate_metric
---------------------

A script to make some plots to validate the metric at a given point in space.
It makes the following plots:
	- Match vs distance parabolae
	- Constant match ellipse
	- Metric accuracy histogram
	- epsilon validation


To validate the metric:

	mbank_validate_metric --options-you-like
	
You can also load (some) options from an ini-file:

	mbank_validate_metric --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_validate_metric --help
"""

import numpy as np
import matplotlib.pyplot as plt
import sys
from scipy.spatial import Rectangle
import warnings
from tqdm import tqdm

from mbank.metric import cbc_metric
from mbank.handlers import tiling_handler, tile, variable_handler
from mbank.utils import plot_tiles_templates, load_PSD, get_boundaries_from_ranges
import mbank.parser

import argparse
import os

#########################################################################

def plot_parabolae(center, metric_hessian, parabolae):
	"""
	Plot the mismatch along each eigen-direction of a metric, together with polynomial fits.

	One subplot is drawn for each eigenvalue of ``metric_hessian``. Every subplot shows:
		- the measured mismatch ``1 - match`` as a function of the step (green dots)
		- a quadratic and a quartic fit of the mismatch (only even powers of the step are fitted)
		- the parabola ``eigval * step**2`` predicted by the corresponding eigenvalue

	The figure is left as the current pyplot figure: the caller decides whether to save or show it.

	Parameters
	----------
		center: np.ndarray
			Point of the parameter space; it is only used for the figure title

		metric_hessian: np.ndarray
			Square matrix: its eigenvalues and eigenvectors label the subplots

		parabolae: list
			One array per subplot. Column 0 is read as the step and column 1 as the match.
			The i-th array is paired with the i-th eigenvalue returned by ``np.linalg.eig``
			(presumably the ordering used by the code that generated the parabolae - to be confirmed).
			The arrays are not modified.
	"""
	eigvals, eigvecs = np.linalg.eig(metric_hessian)

	# squeeze = False always returns a 2D array of axes, even when there is a single subplot
	fig, axes = plt.subplots(len(eigvals), 1, sharex = True, squeeze = False, figsize = (10, 2*len(eigvals)))
	axes = axes[:,0]
	plt.suptitle("center = {}".format(center))

	for d, (ax, eigval, parabola) in enumerate(zip(axes, eigvals, parabolae)):
		ax.set_title("{} | Eigvec {}".format(d, np.round(eigvecs[:,d],3)))

		# Working on new arrays: the caller's data must not be overwritten
		parabola = np.asarray(parabola)
		steps = parabola[:,0]
		mismatch = 1 - parabola[:,1] #making the match a distance
		ax.plot(steps, mismatch, 'o', c = 'g')

		# Fits are linear/quadratic in steps**2, i.e. even polynomials of degree 2 and 4 in the step
		p_2nd = np.polyfit(steps**2, mismatch, 1) #parabolic fit
		p_4th = np.polyfit(steps**2, mismatch, 2) #quartic fit

		x = np.sort(steps)
		ax.plot(x, np.polyval(p_2nd, x**2), '--', label = '2nd fit: {}'.format(np.format_float_scientific(p_2nd[0],2)))
		ax.plot(x, np.polyval(p_4th, x**2), ':', label = '4th fit: {} | {}'.format(np.format_float_scientific(p_4th[1],2), np.format_float_scientific(p_4th[0],2)))
		ax.plot(x, eigval*x**2, '-', label = 'hessian: {}'.format(np.format_float_scientific(eigval,2)))
		ax.set_ylabel(r"$1-\mathcal{M}$")
		ax.legend()

	# Doubling the upper x limit (the x axis is shared among all the subplots)
	min_x, max_x = axes[-1].get_xlim()
	axes[-1].set_xlim((min_x, max_x*2.))

	axes[-1].set_xlabel(r"$\epsilon$")
	plt.tight_layout()

#########################################################################


##### Creating parser
# The module docstring is the description of the program. It must be passed by keyword:
# the first positional argument of ArgumentParser is `prog`, which would place the whole docstring in the usage line.
# The raw formatter keeps the line breaks of the docstring in the --help output.
parser = argparse.ArgumentParser(description = __doc__, formatter_class = argparse.RawDescriptionHelpFormatter)

# Options shared with the other mbank executables
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)

# Options specific to this script
parser.add_argument(
	"--theta", required = False, type = float, nargs = '+',
	help="The point at which the metric is evaluated at")
parser.add_argument(
	"--match", required = False, default = 0.97, type = float,
	help="Match for plotting the constant match ellipse and metric accuracy histogram")
parser.add_argument(
	"--overlap",  default = False, action='store_true',
	help="Whether to compute the metric based on the overlap rather than the match")
parser.add_argument(
	"--symphony-match", default = False, action='store_true',
	help="Whether to compute the match with the symphony match")
parser.add_argument(
	"--on-boundaries",  default = False, action='store_true',
	help="Whether the sampled points for validation should be drawn on the boundaries of the constant match ellipse. If False (default) points will be drawn inside.")
parser.add_argument(
	"--epsilon",  default = 1e-6, type=float,
	help="Infinitesimal step for the numerical gradients")
parser.add_argument(
	"--N-points",  default = 500, type=int,
	help="Number of points to sample inside constant match ellipse")

# Anything which is not a known option is treated as the name of an ini file
args, filenames = parser.parse_known_args()

# Updating the arguments from the ini file(s), if any was given
for ini_file in filenames:
	args = mbank.parser.updates_args_from_ini(ini_file, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######
# These options are not marked as required in the parser, since they may come from an ini file:
# their presence can only be checked after the ini files have been read
if (args.theta is None) or (args.variable_format is None) or (args.psd is None):
	raise ValueError("The arguments --theta, --variable-format and --psd must be set!")

if args.run_dir is not None:
	# The trailing slash is required: output file names are appended to run_dir by plain string concatenation
	if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
	os.makedirs(args.run_dir, exist_ok = True)

######
#	Initializing the metric object
######

f, PSD = load_PSD(args.psd, args.asd, args.ifo, df = args.df)
m_obj = cbc_metric(args.variable_format,
		PSD = (f,PSD),
		approx = args.approximant,
		f_min = args.f_min, f_max = args.f_max)

# Point of the parameter space at which the metric is validated
center = np.array(args.theta)
# Explicit exception (rather than an assert) so that the check survives `python -O`
if center.shape[0] != m_obj.D:
	raise ValueError("Wrong dimensionality for the center: expected {} but given {}".format(m_obj.D, center.shape[0]))
if not m_obj.var_handler.is_theta_ok(center, args.variable_format):
	raise ValueError("Unacceptable value for the center was provided!")
		

	######
	# Computing the metric in all its forms
	######

# Match used for the parabolic fit of the hessian
target_match = 0.9

# Every metric below is evaluated at the same point, with the step given by --epsilon
metric_hessian = m_obj.get_hessian(center, overlap = args.overlap, order = None, epsilon = args.epsilon)
metric_symphony = m_obj.get_hessian_symphony(center, overlap = args.overlap, order = None, epsilon = args.epsilon)
metric_fisher = m_obj.get_fisher_matrix(center, overlap = args.overlap, order = None, epsilon = args.epsilon)
metric_block_diagonal = m_obj.get_block_diagonal_hessian(center,  overlap = args.overlap, epsilon = args.epsilon, order = None)
# Placeholder: the numerical hessian is not computed and the identity is reported in its place
metric_numerical = np.eye(center.shape[0])
metric_parabolae, parabolae, original_parabola_metric = m_obj.get_parabolic_fit_hessian(center, overlap = args.overlap, target_match = target_match,
			N_epsilon_points = 10, log_epsilon_range = (-5, 1), full_output = True)

# Metric selected by the user with --metric-type: it is the one used for the ellipse plots
metric = m_obj.get_metric(center, overlap = args.overlap, metric_type = args.metric_type, epsilon = args.epsilon)

print("Variable format: ", args.variable_format)
print("Approximant: ", args.approximant)
print("Frequency range: ", args.f_min, args.f_max)
print("center: ", center)

# Eigenvalues and determinant of each metric
metrics_to_report = (
	("hessian", metric_hessian),
	("symphony", metric_symphony),
	("Fisher", metric_fisher),
	("block diagonal", metric_block_diagonal),
	("numerical hessian", metric_numerical),
	("parabola metric", metric_parabolae),
)
for metric_label, metric_matrix in metrics_to_report:
	print("  Eig {}: ".format(metric_label), np.linalg.eig(metric_matrix)[0], np.linalg.det(metric_matrix))

# Nothing to save and nothing to show: the printed summary is all the user asked for.
# sys.exit is used since quit() is a helper injected by the `site` module for interactive sessions.
if args.run_dir is None and not args.plot: sys.exit()

# Plots were requested but there is nowhere to save them: they must be shown
if args.plot and not args.run_dir: args.show = True

######
# Sampling from the ellipse and computing match
######
points = m_obj.get_points_on_ellipse(args.N_points, center, match = args.match, metric = metric, inside = (not args.on_boundaries), overlap = args.overlap)

# The match stays NaN for the sampled points that are not acceptable values of theta
matches = np.full((args.N_points,), np.nan)
ids_inside = m_obj.var_handler.is_theta_ok(points, args.variable_format, raise_error = False)

if np.any(ids_inside):
	matches[ids_inside] = m_obj.match(points[ids_inside, :], center,
		overlap = args.overlap, symphony = args.symphony_match, antenna_patterns = (1,0))

	######
	# Plotting the parabolae
	######
plot_parabolae(center, original_parabola_metric, parabolae)
if args.run_dir is not None: plt.savefig(args.run_dir+"parabolae.png")

######
# Plotting the ellipses and histogram
######
t_obj = tiling_handler()
var_handler = variable_handler()

# The tiling is made of a single small tile around the center, which carries the metric under validation
tile_half_width = 1e-4
boundaries = np.stack([center-tile_half_width, center+tile_half_width], axis = 0)
t_obj.extend([tile(Rectangle(*boundaries), metric)]) #mbank metric

with warnings.catch_warnings():
	warnings.simplefilter("ignore")
	plot_tiles_templates(center[None,:] , args.variable_format, t_obj,
		injections = points, inj_cmap = matches,
		dist_ellipse = np.sqrt(1-args.match),
		save_folder = None, show = False) 
# The figure with the highest number is discarded: title and savefig below act on the figure that pyplot considers current afterwards
plt.close(plt.get_fignums()[-1])
plt.suptitle("{} - {}\ncenter = {}".format(args.variable_format, args.approximant, center), fontsize = 25)
if args.run_dir is not None: plt.savefig(args.run_dir+"ellipses.png")

	#histogram
plt.figure()
plt.title("Center: {} - Overlap {}".format(center, args.overlap))
hist_kwargs = {'bins': 20, 'histtype':'step', 'density':True, 'cumulative': True, 'color':'orange'}

# The matches are NaN for the points where the match was not computed: they cannot enter the histogram.
# If no finite match is left, the histogram range would not be finite: the histogram is skipped with a warning.
valid_matches = matches[np.isfinite(matches)]
has_valid_matches = valid_matches.size > 0
if has_valid_matches:
	plt.hist(valid_matches, label = 'mbank', **hist_kwargs)
else:
	warnings.warn("No valid match is available: the validation histogram is left empty")

# Red line: the match at which the points were sampled
plt.axvline(args.match, c = 'r')
plt.yscale('log')
if has_valid_matches: plt.legend()
if args.run_dir is not None: plt.savefig(args.run_dir+"validation_hist.png")

	
	######
	# Epsilon validation (plots and computation)
	######
# Grid of finite difference steps and of orders on which the metric determinant
# and the norm of the gradients are evaluated
epsilon_list = np.logspace(-10, -2, 10)
order_list = [1, 2, 4, 6, 8]

det_rows, grad_norm_rows = [], []

for eps_value in tqdm(epsilon_list, desc = 'Loop on epsilon', leave = False):
	# One step per dimension: the step of the first coordinate is ten times smaller than the others
	step = np.full(center.shape, eps_value)
	step[0] = step[0]/10.

	det_row, grad_norm_row = [], []
	for fd_order in order_list:
		metric_at_order = m_obj.get_metric(center, overlap = args.overlap, order = fd_order, epsilon = step, metric_type = args.metric_type)
		# presumably shape (D,K), as noted by the original author - to be confirmed
		wf_grads = m_obj.get_WF_grads(center, approx = m_obj.approx, order = fd_order, epsilon = step)

		det_row.append(np.linalg.det(metric_at_order))
		grad_norm_row.append(np.linalg.norm(wf_grads, axis = 0))

	det_rows.append(det_row)
	grad_norm_rows.append(grad_norm_row)

# First axis runs over epsilon_list, second axis over order_list
det_list = np.array(det_rows)
grad_norm_list = np.array(grad_norm_rows)

	#Plotting epsilon
# One color per dimension: only used when the determinants carry an extra trailing axis
color_list = ['r','g', 'b', 'orange', 'c', 'm', 'y', 'k', 'indigo', 'fuchsia']

###
# Determinant of the metric as a function of epsilon: one line per order
plt.figure()
plt.title("{} - theta = {}\napproximant = {}".format(args.variable_format, center, args.approximant))
for order_id, fd_order in enumerate(order_list):
	if det_list.ndim != 2:
		for dim_id in range(det_list.shape[-1]):
			plt.loglog(epsilon_list, det_list[:,order_id,dim_id],
				label = 'dim = {} - order = {}'.format(dim_id, fd_order), c = color_list[dim_id])
		continue
	plt.loglog(epsilon_list, det_list[:,order_id], label = 'order = {}'.format(fd_order))
plt.legend()
plt.xlabel(r'$\epsilon$')
plt.ylabel(r'$|M|$')
if args.run_dir is not None:
	plt.savefig(args.run_dir+"epsilon_det.png")

###
# Norm of the gradients as a function of epsilon: one subplot per coordinate, one line per order
fig, axes = plt.subplots(len(center), 1, sharex = True)
plt.suptitle("{} - theta = {}\napproximant = {}".format(args.variable_format, center, args.approximant))
for coord_id, coord_ax in enumerate(axes):
	coord_label = var_handler.labels(args.variable_format)[coord_id]
	coord_ax.set_title("Grad norm: d = {} - {}".format(coord_id, coord_label))
	coord_ax.set_ylabel(r'$||\partial_{{{}}} h||$'.format(coord_label))

	for order_id, fd_order in enumerate(order_list):
		if det_list.ndim != 2:
			for dim_id in range(det_list.shape[-1]):
				coord_ax.loglog(epsilon_list, grad_norm_list[:,order_id,dim_id,coord_id],
					label = 'dim = {} - order = {}'.format(dim_id, fd_order), c = color_list[dim_id])
			continue
		coord_ax.loglog(epsilon_list, grad_norm_list[:,order_id,coord_id], label = 'order = {}'.format(fd_order))
# Legend and x label are set through pyplot, hence on the current axes only
plt.legend()
plt.xlabel(r'$\epsilon$')
plt.tight_layout()
if args.run_dir is not None:
	plt.savefig(args.run_dir+"epsilon_grad_norm.png")

if args.show:
	plt.show()







	
	
	
