#!/usr/bin/env python
"""
mbank_run
---------

A script to run mbank for generating a bank

To generate a bank:

	mbank_run --options-you-like

You can also load (some) options from an ini-file:

	mbank_run --some-options other_options.ini

Make sure that mbank is properly installed.

To create a condor submit file, to run the job through condor:

	mbank_run --make-sub options.ini

To know which options are available:

	mbank_run --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings
from ligo.lw.utils import load_filename

from mbank import variable_handler, cbc_metric, cbc_bank
from mbank.utils import updates_args_from_ini, int_tuple_type, load_PSD, avg_dist, plot_tiles_templates, get_boundaries_from_ranges
from mbank.flow.utils import compare_probability_distribution

import argparse
import os

def int_tuple_type(strings):
	"""Convert a string like ``"(2,3,1)"`` into a tuple of ints (used as an argparse ``type``).

	Parentheses, square brackets and whitespace around each item are ignored. Empty
	items are skipped, so that a Python-style one-element tuple such as ``"(3,)"`` or
	a trailing comma (``"2,3,"``) is accepted.

	Note: this definition shadows the ``int_tuple_type`` imported from ``mbank.utils``.

	Parameters
	----------
		strings: str
			Comma-separated integers, optionally wrapped in brackets

	Returns
	-------
		tuple
			Tuple of ints

	Raises
	------
		ValueError
			If an item is not an integer or if no item is given (argparse reports it as an invalid value)
	"""
	cleaned = strings.translate(str.maketrans('', '', '()[]'))
	#skipping empty items allows trailing commas, as in "(3,)"
	items = [item for item in cleaned.split(",") if item.strip()]
	if not items:
		raise ValueError("No integer found in '{}'".format(strings))
	return tuple(int(item) for item in items)

##### Creating parser
# The module docstring goes in `description`: the first positional argument of ArgumentParser is `prog`,
# which would print the whole docstring in place of the program name in every usage/error line.
# RawDescriptionHelpFormatter keeps the line breaks of the docstring in the --help output.
parser = argparse.ArgumentParser(description = __doc__, formatter_class = argparse.RawDescriptionHelpFormatter)
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--mm", required = False, type = float, default = 0.97,
	help="Minimum match for the bank (a.k.a. average distance between templates)")
parser.add_argument(
	"--psd",  required = False,
	help="The input file for the PSD: it can be either a txt or a ligolw xml file")
parser.add_argument(
	"--asd",  default = False, action='store_true',
	help="Whether the input file has an ASD (sqrt of the PSD)")
parser.add_argument(
	"--ifo", default = 'L1', type=str, choices = ['L1', 'H1', 'V1'],
	help="Interferometer name: it can be L1, H1, V1. This is a field for the xml files for the PSD and the bank")
parser.add_argument(
	"--plot", action='store_true',
	help="Whether to make some plots. They will be stored in run-dir")
parser.add_argument(
	"--show", action='store_true',
	help="Whether to show the plots.")
# The placeholder default is replaced by './out_<run-name>/' once the arguments are parsed
parser.add_argument(
	"--run-dir", default = './out_$(run_name)',
	help="Output directory in which the bank will be saved. If default is used, the bank name will be appended.")
parser.add_argument(
	"--run-name", default = 'cbc_mbank',
	help="Name for the bank and tiling output file")
parser.add_argument(
	"--grid-size", default = None, type=int_tuple_type,
	help="Number of grid points for each dimension. The number of entries must match the number of dimensions. If None, the grid size will be a tuple of ones")
parser.add_argument(
	"--f-min",  default = 10., type=float,
	help="Minimum frequency for the scalar product")
parser.add_argument(
	"--f-max",  default = 1024., type=float,
	help="Maximum frequency for the scalar product")
parser.add_argument(
	"--approximant", default = 'IMRPhenomPv2',
	help="LAL approximant for the bank generation")
# The valid choices are read from a throwaway cbc_bank instance
# (presumably the placing methods do not depend on the variable format - TODO confirm)
parser.add_argument(
	"--placing-method", default = 'geometric', type = str, choices = cbc_bank('Mq_nonspinning').placing_methods,
	help="Which placing method to use for each tile")
parser.add_argument(
	"--n-livepoints", default = 10000, type = float,
	help="Parameter to control the number of livepoints to use in the `random` and `pruning` placing method. For `random` (or related), it represents the number of livepoints to use for the estimation of the coverage fraction. For `pruning`, it amounts to the ratio between the number of livepoints and the number of templates placed by ``uniform`` placing method.")
parser.add_argument(
	"--covering-fraction", default = 0.01, type = float,
	help="Parameter to control the maximum fraction of livepoints alive before terminating the bank generation with the `random` and `pruning` placing method. The smaller the threshold, the higher the number of templates in the final bank.")
parser.add_argument(
	"--empty-iterations", default = 100, type = float,
	help="Number of consecutive rejected proposals after which the `stochastic` placing method stops.")
parser.add_argument(
	"--tile-tolerance", default = 0.1, type = float,
	help="Maximum tolerated variation of the relative difference of the metric determinant between parent and child in the iterative splitting procedure")
parser.add_argument(
	"--max-depth", default = 6, type = int,
	help="Maximum number of iterative splitting before terminating.")
parser.add_argument(
	"--metric-type", default = 'hessian', type = str, choices = ['hessian', 'parabolic_fit_hessian', 'symphony'],
	help="Method to use to compute the metric.")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the metric computation")

parser.add_argument(
	"--train-flow", action='store_true', default = False,
	help="Whether to train a normalizing flow model for the tiling. It will be used for metric interpolation")
parser.add_argument(
	"--n-layers", default = 2, type = int,
	help="Number of layers for the flow model to train (applicable only if --train-flow)")
parser.add_argument(
	"--hidden-features", default = 4, type = int,
	help="Number of hidden features for the masked autoregressive flow to train. (applicable only if --train-flow)")
parser.add_argument(
	"--n-epochs", default = 1000, type = int,
	help="Number of training epochs for the flow (applicable only if --train-flow)")
	
	#ranges for physical parameters
parser.add_argument(
	"--m-range", default = [10., 100], type=float, nargs = 2,
	help="Range values for the masses (in solar masses)")
parser.add_argument(
	"--mtot-range", default = [10., 100], type=float, nargs = 2,
	help="Range values for the total masses (in solar masses).")
parser.add_argument(
	"--q-range", default = [1., 10.], type=float, nargs = 2,
	help="Range values for the mass ratio.")
parser.add_argument(
	"--mc-range", default = [10., 100], type=float, nargs = 2,
	help="Range values for the total masses (in solar masses).")
parser.add_argument(
	"--eta-range", default = [.18, .25], type=float, nargs = 2,
	help="Range values for the mass ratio.")
parser.add_argument(
	"--s1-range", default = [-0.99,0.99], type=float, nargs = 2,
	help="Range values for magnitude of spin 1 (if applicable)")
parser.add_argument(
	"--s2-range", default = [-0.99,0.99], type=float, nargs = 2,
	help="Range values for magnitude of spin 1 (if applicable)")
parser.add_argument(
	"--chi-range", default = [-0.99,0.99], type=float, nargs = 2,
	help="Range values for effective spin parameter (if applicable)")
parser.add_argument(
	"--theta-range", default = [-np.pi, np.pi], type=float, nargs = 2,
	help="Range values for theta angles of spins (if applicable)")
parser.add_argument(
	"--phi-range", default = [-np.pi/2, np.pi/2], type=float, nargs = 2,
	help="Range values for phi angles of spins (if applicable)")
parser.add_argument(
	"--e-range", default = [0., 0.5], type=float, nargs = 2,
	help="Range values for the eccentricity (if applicable)")
parser.add_argument(
	"--meanano-range", default = [0., 1], type=float, nargs = 2,
	help="Range values for the mean anomaly (if applicable). TODO: find a nice default...")
parser.add_argument(
	"--iota-range", default = [0., np.pi], type=float, nargs = 2,
	help="Range values for iota (if applicable)")
parser.add_argument(
	"--ref-phase-range", default = [-np.pi, np.pi], type=float, nargs = 2,
	help="Range values for reference phase (if applicable)")

parser.add_argument(
	"--make-sub",  default = False, action='store_true',
	help="If set, it will make a condor submit file that the user can use to launch the job through condor.")

# Template of the condor submit file, filled only when --make-sub is given.
# Placeholders, in order: executable, arguments, then (run_dir, run_name) for
# the log, error and output files, and finally the number of requested CPUs.
sub_file_str = """Universe   = vanilla
Executable = {}
arguments = "{}"
getenv = true
Log = {}_mbank_run_{}.log
Error = {}_mbank_run_{}.err
Output = {}_mbank_run_{}.out
request_memory = 4GB
request_cpus = {}
request_disk = 4GB

queue
"""

# Any command line token not recognised by the parser is handed to updates_args_from_ini below
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case (later files are applied on top of earlier ones)
for ini_file in filenames:
	args = updates_args_from_ini(ini_file, args, parser)

####################################################################################################
	######
	#	Interpreting the parser and initializing variables
	######

	#psd and variable_format have no default: they must come from the command line or from an ini file
if (args.psd is None) or (args.mm is None) or (args.variable_format is None):
	raise ValueError("The arguments mm, psd and variable_format must be set!")

var_handler = variable_handler()
	#explicit check (not an assert), so that the validation is not stripped by `python -O`
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	#resolving the default run directory, which depends on the run name
if args.run_dir == './out_$(run_name)':
	args.run_dir = './out_{}/'.format(args.run_name)
if not args.run_dir.endswith('/'):
	args.run_dir = args.run_dir+'/'
os.makedirs(args.run_dir, exist_ok = True)
	#a bare PSD file name (without any '/') is looked for inside the run directory
if '/' not in args.psd:
	args.psd = args.run_dir+args.psd

	#by default, a single grid point along every dimension of the variable format
if args.grid_size is None:
	args.grid_size = tuple(1 for _ in range(var_handler.D(args.variable_format)))

if args.make_sub:
		#--make-sub must not be forwarded to the condor job, otherwise the job would only write the submit file again
	if '--make-sub' in sys.argv:
		sys.argv.remove('--make-sub')
	else:
		warnings.warn('The option --make-sub was given in the inifile. Make sure you remove it from your ini before launching the condor job!')
		#the number of requested CPUs equals the total number of grid points
	n_cpu = None if args.grid_size is None else int(np.prod(args.grid_size))
	sub_file_str = sub_file_str.format(sys.argv[0], ' '.join(sys.argv[1:]),
		args.run_dir, args.run_name,
		args.run_dir, args.run_name,
		args.run_dir, args.run_name,
		n_cpu)
	
	sub_file = args.run_dir+'mbank_run_{}.sub'.format(args.run_name)
	with open(sub_file, 'w') as f:
		f.write(sub_file_str)
	print('#####')
	print("Submit file generated @ {}\nSubmit the job with condor_submit {}\n".format(sub_file, sub_file))
		#same path as the `Error = {}_mbank_run_{}.err` entry of the submit file template
	print("Monitor it with: tail -f {}_mbank_run_{}.err".format(args.run_dir, args.run_name))
	print('#####')
	print(sub_file_str)
		#sys.exit is always available, unlike the site-module helper quit()
	sys.exit()

m_min, m_max = args.m_range
mtot_min, mtot_max = args.mtot_range
q_min, q_max = args.q_range
mc_min, mc_max = args.mc_range
eta_min, eta_max = args.eta_range
s1_min, s1_max = args.s1_range
s2_min, s2_max = args.s2_range
chi_min, chi_max = args.chi_range
theta_min, theta_max = args.theta_range
e_min, e_max = args.e_range
meanano_min, meanano_max = args.meanano_range
phi_min, phi_max = args.phi_range
iota_min, iota_max = args.iota_range
ref_phase_min, ref_phase_max = args.ref_phase_range

plot_folder = None
if args.plot: plot_folder = args.run_dir

format_info = var_handler.format_info[args.variable_format]

	######
	#	Setting boundaries: shape (2,D) (as stated by the original author - not checked here)
	######
	#setting mass boundaries: the two mass coordinates depend on the mass format
if format_info['mass_format'] == 'm1m2':
	var1_min, var1_max = m_min, m_max
	var2_min, var2_max = m_min, m_max
elif format_info['mass_format'] in ['Mq', 'logMq']:
	var1_min, var1_max = mtot_min, mtot_max
	var2_min, var2_max = q_min, q_max
elif format_info['mass_format'] == 'mceta':
	var1_min, var1_max = mc_min, mc_max
	var2_min, var2_max = eta_min, eta_max
else:
		#without this, var1_min & co would be undefined and the call below would fail with a NameError
	raise ValueError("Unsupported mass format '{}' for the variable format '{}'".format(format_info['mass_format'], args.variable_format))

	#setting the boundaries for all the other variables (spins, angles, eccentricity...)
boundaries = get_boundaries_from_ranges(args.variable_format,
	(var1_min, var1_max), (var2_min, var2_max), chi_range = (chi_min, chi_max),
	s1_range = (s1_min, s1_max), s2_range = (s2_min, s2_max), theta_range = (theta_min, theta_max), phi_range = (phi_min, phi_max),
	iota_range = (iota_min, iota_max), ref_phase_range = (ref_phase_min, ref_phase_max),
	e_range = (e_min, e_max), meanano_range = (meanano_min, meanano_max))

	######
	#	Loading PSD and initializing metric
	######
	#args.asd tells load_PSD whether the file stores an ASD rather than a PSD
psd_data = load_PSD(args.psd, args.asd, args.ifo)
m = cbc_metric(args.variable_format, PSD = psd_data, approx = args.approximant,
	f_min = args.f_min, f_max = args.f_max)

print("## Running: ", args.run_name)

	######
	#	Running and saving the output
	######
bank = cbc_bank(args.variable_format)

	#guard clause: template placement is not available for the m1m2 mass format
if format_info['mass_format'] == 'm1m2':
	raise NotImplementedError("Currently no template placement is implemented for the mass format m1m2.")

t_obj = bank.generate_bank(m,
	minimum_match = args.mm, boundaries = boundaries,
	max_depth = args.max_depth, tolerance = args.tile_tolerance,
	placing_method = args.placing_method, metric_type = args.metric_type,
	grid_list = args.grid_size, train_flow = args.train_flow,
	N_livepoints = args.n_livepoints, covering_fraction = args.covering_fraction,
	empty_iterations = args.empty_iterations, use_ray = args.use_ray,
	n_layers = args.n_layers, hidden_features = args.hidden_features, N_epochs = args.n_epochs)

	#the tiling is always saved; the flow only if it was trained
tiling_file = '{}tiling_{}.npy'.format(args.run_dir, args.run_name)
if args.train_flow:
	flow_file = '{}flow_{}.zip'.format(args.run_dir, args.run_name)
else:
	flow_file = None
t_obj.save(tiling_file, flow_file)

if bank.templates is None:
	print("No templates were added to the bank: the bank is not saved")
	quit()

print("Generated bank with {} templates and {} tiles".format(len(bank.templates), len(t_obj)))
print("Saving bank to {}".format(args.run_dir))
bank.save_bank(args.run_dir+'bank_{}.dat'.format(args.run_name))
bank.save_bank(args.run_dir+'bank_{}.xml.gz'.format(args.run_name), args.f_max, args.ifo); title = ''

if args.plot:
		#plotting the templates; no ellipse is drawn (dist_ellipse = None)
	plot_tiles_templates(bank.templates, args.variable_format,
		dist_ellipse = None, save_folder = plot_folder, show = args.show)
	
		#if a flow was trained, compare samples from the flow with samples drawn from the tiling
		#explicit None check: the truthiness of a model object is not a reliable "was trained" flag
	if t_obj.flow is not None:
		compare_probability_distribution(t_obj.flow.sample(5000).detach().numpy(), data_true = t_obj.sample_from_tiling(5000),
			variable_format = args.variable_format,
			title = None, hue_labels = ['flow', 'tiling'],
			savefile = os.path.join(plot_folder, 'flow.png'), show = args.show)
		
	













