#!/usr/bin/env python
"""
mbank_compare_volumes
---------------------

A script to compare the volumes of different parameter spaces.
It computes the ratio :math:`\\frac{V^{other}}{V}` between the volumes computed with different PSDs, specified by `--psd-other` and `--psd`.
It outputs the results of the computation in a json file.

To compare two volumes:

	mbank_compare_volumes --options-you-like
	
You can also load (some) options from an ini-file:

	mbank_compare_volumes --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_compare_volumes --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import json
from tqdm import tqdm

from mbank import cbc_bank, tiling_handler, variable_handler, cbc_metric
from mbank.utils import plot_tiles_templates, avg_dist, load_PSD, plot_colormap
from mbank.placement import place_random_flow, place_stochastically_flow, place_geometric_flow
from mbank.flow import STD_GW_Flow
import mbank.parser

import argparse
import os

import torch 

##### Creating parser
	#The module docstring is the help description. It must be passed by keyword:
	#the first positional argument of ArgumentParser is `prog`, not `description`.
	#RawDescriptionHelpFormatter keeps the layout of the docstring (indented examples) as written.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

	#Option groups defined in mbank.parser
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)
mbank.parser.add_range_options(parser) #only for the flow

	#Output file and (optional) flow used for sampling
parser.add_argument(
	"--comparison-file", type = str,
	help="JSON file to store the result of the comparison between different volumes")

parser.add_argument(
	"--flow-file", default = None, type = str,
	help="File to load the normalizing flow, used for importance sampling estimation of the volume ratio. If not given, the samples will be drawn from a uniform distribution in the coordinates.")

	#Options describing the "other" space: when left unset, most of them fall back to the base space value later in the script
parser.add_argument(
	"--psd-other",  required = False,
	help="The input file for the PSD of the other space: it can be either a txt or a ligolw xml file")
parser.add_argument(
	"--asd-other",  default = False, action='store_true',
	help="Whether the input PSD file has an ASD (sqrt of the PSD)")
parser.add_argument(
	"--ifo-other", default = 'L1', type=str, choices = ['L1', 'H1', 'V1'],
	help="Interferometer name: it can be L1, H1, V1. This is a field for the xml files for the PSD and the bank")
parser.add_argument(
	"--f-min-other",  default = None, type=float,
	help="Minimum frequency for the scalar product of the other space (if not set, it will be set the same as the base space)")
parser.add_argument(
	"--f-max-other",  default = None, type=float,
	help="Maximum frequency for the scalar product of the other space (if not set, it will be set the same as the base space)")
parser.add_argument(
	"--df-other",  default = None, type=float,
	help="Spacing of the frequency grid where the PSD is evaluated for the other space (if not set, it will be set the same as the base space)")
parser.add_argument(
	"--approximant-other", default = None,
	help="LAL approximant for the metric computation in the other space (if not set, it will be set the same as the base space)")
parser.add_argument(
	"--metric-type-other", default = None, type = str, choices = ['hessian', 'parabolic_fit_hessian', 'symphony'],
	help="Method to use to compute the metric for the other space (if not set, it will be set the same as the base space)")


	#Labels used in the printed summary and in the plot legends
parser.add_argument(
	"--label", default = '1', type = str,
	help="Label for the parameter space characterized by --psd")
parser.add_argument(
	"--label-other", default = '2', type = str,
	help="Label for the other parameter space characterized by --psd-other")

parser.add_argument(
	"--n-points", default = 1000, type = int,
	help="Number of points to evaluate the volume ratio with")

	#Whatever the parser does not recognize is treated as the name of an ini file
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for ini_file in filenames:
	args = mbank.parser.updates_args_from_ini(ini_file, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######
if args.variable_format is None:
	raise ValueError("The argument --variable-format must be set!")

var_handler = mbank.variable_handler()
	#Explicit raise instead of an assert: asserts are stripped when python runs with -O
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))
D = var_handler.D(args.variable_format)

	#An unset (or empty) run dir defaults to a folder named after the run
if not args.run_dir:
	args.run_dir = './out_{}/'.format(args.run_name)
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
	#exist_ok: no failure if the folder already exists (or gets created in the meanwhile)
os.makedirs(args.run_dir, exist_ok = True)

def _in_run_dir(filename):
	"""
	Place a bare file name inside the run directory.

	Parameters
	----------
		filename: str
			File name given by the user

	Returns
	-------
		path: str
			`filename` unchanged if it contains a '/', otherwise `filename` prefixed with `args.run_dir`
	"""
	return filename if '/' in filename else args.run_dir+filename

	#Files given without a folder are looked for (or created) in the run dir
if args.flow_file: args.flow_file = _in_run_dir(args.flow_file)
if args.psd: args.psd = _in_run_dir(args.psd)
if args.psd_other: args.psd_other = _in_run_dir(args.psd_other)
if args.comparison_file:
	args.comparison_file = _in_run_dir(args.comparison_file)
else:
	args.comparison_file = args.run_dir+'volume_comparison_{}.json'.format(args.run_name)

	#Plots (if any) are saved in the run dir
plot_folder = args.run_dir if args.plot else None

	#Storing the arguments used for this run, in a file named after this script
mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

	######
	#	Initializing metric generation
	######

		#Boundary keeper, built from the parsed arguments
bk = mbank.parser.boundary_keeper(args)
def boundaries_checker(theta):
	"""
	Evaluate the boundary keeper on `theta`, using the variable format given by --variable-format.
	Returns whatever the boundary keeper returns.
	"""
	variable_format = args.variable_format
	return bk(theta, variable_format)

	#Metric of the base space
psd_base = load_PSD(args.psd, args.asd, args.ifo, df = args.df)
m = cbc_metric(args.variable_format, PSD = psd_base, approx = args.approximant,
	f_min = args.f_min, f_max = args.f_max)

	#Options of the other space that are unset (or falsy) are copied from the base space
	#NOTE(review): --ifo-other is not copied from --ifo when --psd-other is missing (it keeps its own default): confirm this is intended
if not args.psd_other:
	args.psd_other, args.asd_other = args.psd, args.asd
for option in ('df', 'approximant', 'f_max', 'f_min', 'metric_type'):
	other_option = option+'_other'
	if not getattr(args, other_option):
		setattr(args, other_option, getattr(args, option))

	#Metric of the other space
psd_other = load_PSD(args.psd_other, args.asd_other, args.ifo_other, df = args.df_other)
m_other = cbc_metric(args.variable_format, PSD = psd_other, approx = args.approximant_other,
	f_min = args.f_min_other, f_max = args.f_max_other)

	#The flow is optional: without it, flow is None
flow = STD_GW_Flow.load_flow(args.flow_file) if args.flow_file else None


	######
	#	Computing volume ratios
	######
	#Container for everything that gets dumped to the json file
out_dict = {
	'variable_format': args.variable_format,
	'theta': [],
	'volume_element_flow': [],
	'volume_element': [],
	'volume_element_other': [],
	'volume_ratio': 0,
}

	#Drawing the points: from the flow if there is one, otherwise from the boundary keeper
if flow:
	with torch.no_grad():
		points, log_prob = flow.sample_within_boundaries(args.n_points, boundaries_checker)
else:
	vol_coordinates, _ = bk.volume(args.n_points, args.variable_format, 20)
	points = bk.sample(args.n_points, args.variable_format)

progress_bar = tqdm(range(args.n_points), desc = 'Evaluating the metric ratio', disable = not args.verbose)
for i in progress_bar:
	theta = points[i]

		#Volume element of the two spaces at the same point
	vol_elem = m.get_volume_element(theta, metric_type = args.metric_type)
	vol_elem_other = m_other.get_volume_element(theta, metric_type = args.metric_type_other)

		#Value the samples are weighted with: from the flow log probability (plus the flow constant)
		#or the inverse of the coordinate volume, the same for every point
	if flow:
		log_density = log_prob[i].numpy() + flow.constant.detach().numpy()
		vol_elem_flow = np.exp(log_density)[0]
	else:
		vol_elem_flow = 1/vol_coordinates

	out_dict['theta'].append(theta.tolist())
	for key, value in (
		('volume_element_flow', vol_elem_flow),
		('volume_element_other', vol_elem_other),
		('volume_element', vol_elem),
	):
		out_dict[key].append(float(value))


	#Volume estimation

ve, ve_other, ve_flow = (
	np.array(out_dict[key])
	for key in ('volume_element', 'volume_element_other', 'volume_element_flow')
)

	#Each space gets the average of its volume elements divided by the per-sample flow value: the ratio of the two averages is the volume ratio
avg_other = np.mean(ve_other/ve_flow)
avg_base = np.mean(ve/ve_flow)
out_dict['volume_ratio'] = avg_other/avg_base

summary = 'Volume ratio between spaces "{}" vs "{}" is: {}'.format(args.label, args.label_other, out_dict['volume_ratio'])
print(summary)

	######
	#	Plotting & saving
	######

	#Dumping samples, volume elements and volume ratio to the comparison file
with open(args.comparison_file, 'w') as json_file:
	json.dump(out_dict, json_file, indent = 2)

if args.plot:
	points = np.array(out_dict['theta'])
		#Quantity shown in the scatter plot and in the colormap: log10 of the point-wise ratio of the volume elements
	log_ratio = np.log10(ve_other/ve)
	n_bins = int(np.sqrt(len(ve)))

		#Histogram of the log ratio between the volume element of each space and the value stored in 'volume_element_flow'
	plt.figure()
	for ve_space, label in ((ve, args.label), (ve_other, args.label_other)):
		plt.hist(np.log10(ve_space)-np.log10(ve_flow),
			histtype = 'step', bins = n_bins, density = True, label = label)
	plt.legend()
	plt.xlabel(r"$\log_{10}(M_{space}/M_{flow})$")
	plt.savefig(args.run_dir+'samples_accuracy.png')

		#Scatter plot of the samples, colored by the log ratio
	plot_tiles_templates(points, args.variable_format,
		injections=points, inj_cmap=log_ratio,
		save_folder = plot_folder, show = False)
		#Presumably plot_tiles_templates writes injections.png, hist.png and bank.png in save_folder (here the run dir):
		#the first two are renamed, the last one is discarded
	os.rename(args.run_dir+'injections.png', args.run_dir+'samples.png')
	os.rename(args.run_dir+'hist.png', args.run_dir+'samples_hist.png')
	os.remove(args.run_dir+'bank.png')

		#The label and title describe what is actually plotted (the log ratio of the volume elements, not a flow residual).
		#The number of bins is kept at least 1, so that very small values of --n-points do not give zero bins
	plot_colormap(points, log_ratio, args.variable_format, statistics = 'mean', bins = max(1, n_bins//2),
		savefile = args.run_dir+'colormap_volume_ratio.png', values_label = r'$\log_{10}\frac{dV_{other}}{dV}$', title = 'Volume element ratio')

	if args.show: plt.show()
	
