#!/usr/bin/env python
"""
mbank_compute_volume
--------------------

A script to compute the volume of a given parameter space. If --mm is given, the volume is also expressed in units of template volume (i.e. is a number of templates).
The volume is computed with importance sampling.

To compute the volume:

	mbank_compute_volume --options-you-like
	
You can also load (some) options from an ini-file:

	mbank_compute_volume --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_compute_volume --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import json
from tqdm import tqdm

from mbank import cbc_bank, tiling_handler, variable_handler, cbc_metric
from mbank.utils import plot_tiles_templates, avg_dist, load_PSD, plot_colormap
from mbank.placement import place_random_flow, place_stochastically_flow, place_geometric_flow
from mbank.flow import STD_GW_Flow
import mbank.parser

import argparse
import os

import torch 

##### Creating parser
# The module docstring is the description of the help page. It must be passed
# by keyword: the first positional parameter of ArgumentParser is `prog`, so a
# positional docstring would become the program name in the usage line.
# RawDescriptionHelpFormatter keeps the line breaks of the docstring as written.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

# General, metric and range options are added by the helpers in mbank.parser
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)
mbank.parser.add_range_options(parser) #only for the flow

# Options specific to this executable
parser.add_argument(
	"--volume-file", type = str,
	help="JSON file to store the result of the volume computation")

parser.add_argument(
	"--flow-file", default = None, type = str,
	help="File to load the normalizing flow, used for importance sampling estimation of the volume ratio. If not given, the samples will be drawn from a uniform distribution in the coordinates.")

parser.add_argument(
	"--mm", required = False, type = float, default = None,
	help="Typical match between neighbouring templates, to compute the template volume")

parser.add_argument(
	"--n-points", default = 1000, type = int,
	help="Number of points to evaluate the volume ratio with")

parser.add_argument(
	"--n-experiments", default = 1, type = int,
	help="Number of times the volume computation is performed (to compute the error of the mean). Default means no standard error is computed")

# Command line tokens that the parser does not recognize are collected in
# `filenames` and treated below as ini files
args, filenames = parser.parse_known_args()

# Updating the arguments from the ini file(s), if any were given
for ini_file in filenames:
	args = mbank.parser.updates_args_from_ini(ini_file, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######
# --variable-format is mandatory
if args.variable_format is None:
	raise ValueError("The argument --variable-format must be set!")

var_handler = mbank.variable_handler()
# Explicit raise rather than assert: assert statements are stripped when
# python runs with -O, which would let an invalid format through
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))
D = var_handler.D(args.variable_format)

# Default run directory, used when --run-dir is missing or empty
if not args.run_dir:
	args.run_dir = './out_{}/'.format(args.run_name)

# The run directory is later concatenated with file names, so it must end with '/'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
# exist_ok avoids a failure if the folder appears between a check and its creation
os.makedirs(args.run_dir, exist_ok = True)

# File names without any '/' are understood as relative to the run directory
if args.flow_file and '/' not in args.flow_file:
	args.flow_file = args.run_dir+args.flow_file
if args.volume_file:
	if '/' not in args.volume_file: args.volume_file = args.run_dir+args.volume_file
else:
	# Default name for the output file
	args.volume_file = args.run_dir+'volume_{}.json'.format(args.run_name)

plot_folder = args.run_dir if args.plot else None

# Storing the arguments used for this run inside the run directory
mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

	######
	#	Initializing metric generation
	######

		#Loading boundaries
bk = mbank.parser.boundary_keeper(args)
def boundaries_checker(theta):
	"""
	Evaluate the boundary keeper ``bk`` on ``theta``, in the variable format set by the user.

	Thin wrapper that fixes the second argument of ``bk`` to ``args.variable_format``
	and returns whatever ``bk`` returns. In this script the output is used to index the sampled points.
	"""
	variable_format = args.variable_format
	return bk(theta, variable_format)

	#Loading metrics & flow
# Keyword arguments of the metric, all taken from the command line options
metric_kwargs = dict(
	PSD = load_PSD(args.psd, args.asd, args.ifo, df = args.df),
	approx = args.approximant,
	f_min = args.f_min,
	f_max = args.f_max,
)
m = cbc_metric(args.variable_format, **metric_kwargs)

# The flow is optional: it is None whenever --flow-file is not given
flow = STD_GW_Flow.load_flow(args.flow_file) if args.flow_file else None

	######
	#	Computing the volume
	######

# The variable format is set on the boundary keeper before reading its cached box
bk.set_variable_format(args.variable_format)
box = bk.b_cache
vol_coordinates = bk.volume_box(args.variable_format)
D = box.shape[-1]

# With a flow, the sampling volume element is stored point by point (list filled later);
# without it, it is the constant 1/vol_coordinates
if flow:
	sampling_volume_element = []
else:
	sampling_volume_element = 1/vol_coordinates

# Container for everything that gets dumped to the json file.
# 'volume' and 'std_err_mean' are placeholders, set after all the experiments
out_dict = dict(
	variable_format = args.variable_format,
	n_experiments = args.n_experiments,
	n_points = args.n_points,
	theta = [],
	volume_element = [],
	volume_element_sampling = sampling_volume_element,
	volume_list = [],
	volume = np.nan,
	std_err_mean = np.nan,
)

	#Volume ratios
experiments = tqdm(range(args.n_experiments), desc = 'Evaluating the volume over multiple experiments', disable = not args.verbose)
for _ in experiments:

	# Drawing the points: from the flow (together with their log probability)
	# or uniformly inside the box
	if flow:
		with torch.no_grad():
			samples, log_pdf = flow.sample_and_log_prob(args.n_points)
		samples = samples.numpy()
		log_pdf = np.squeeze(log_pdf.numpy())
	else:
		samples = np.random.uniform(*box, (args.n_points, box.shape[1]))

	# Only the points selected by the boundaries checker contribute to the sum
	inside = boundaries_checker(samples)
	samples_inside = samples[inside]

	# Volume element of the metric at the selected points
	metric_vol_elem = m.get_volume_element(samples_inside)

	# Volume element associated to the sampling distribution
	if flow:
		log_const = flow.constant.detach().numpy()
		sampling_vol_elem = np.exp(log_pdf[inside] + log_const)
	else:
		sampling_vol_elem = 1/vol_coordinates

	# The sum runs over the selected points but the average is over all the drawn points
	vol_estimate = np.sum(metric_vol_elem/sampling_vol_elem)/args.n_points

	# exp(constant) was included in the sampling volume element above:
	# multiplying it back leaves the ratio between the metric volume element and exp(log_prob)
	if flow: vol_estimate = vol_estimate * np.exp(log_const)

	# Storing the outcome of this experiment
	out_dict['volume_list'].append(float(vol_estimate))
	out_dict['theta'].extend(samples_inside.tolist())
	if flow: out_dict['volume_element_sampling'].extend(sampling_vol_elem.tolist())
	out_dict['volume_element'].extend(metric_vol_elem.tolist())

	if args.mm and args.verbose: print('\tNumber of templates: ',float(vol_estimate)/avg_dist(args.mm, D)**D)

	#Volume estimation

# The volume is the mean over the experiments
volume_list = out_dict['volume_list']
out_dict['volume'] = np.mean(volume_list)

# The standard error of the mean needs more than one experiment:
# otherwise it keeps its initial value and only the volume is printed
if args.n_experiments>1:
	out_dict['std_err_mean'] = np.std(volume_list, ddof = 1)/np.sqrt(args.n_experiments)
	print('Volume of the space: {} +/- {}'.format(out_dict['volume'], out_dict['std_err_mean']))
else:
	print('Volume of the space: {}'.format(out_dict['volume']))

# With --mm, the volume is also expressed in units of avg_dist(mm, D)**D
if args.mm:
	template_volume = avg_dist(args.mm, D)**D
	print("Number of templates: {}".format(out_dict['volume']/template_volume))

# Saving the results to the json file
with open(args.volume_file, 'w') as out_file:
	json.dump(out_dict, out_file, indent = 2)

