#!/usr/bin/env python
"""
mbank_mcmc
----------

A script to sample points from the PDF defined by the metric.
It uses the `emcee` package, which is not among the package dependencies, so make sure it is installed before running.

To draw samples:

	mbank_mcmc --options-you-like

You can also load (some) options from an ini-file:

	mbank_mcmc --some-options other_options.ini

Make sure that the mbank is properly installed.
To know which options are available:

	mbank_mcmc --help
"""
import numpy as np

from mbank import cbc_metric, cbc_bank
from mbank.handlers import tile, variable_handler, tiling_handler
from mbank.utils import updates_args_from_ini, int_tuple_type, load_PSD, plot_tiles_templates, get_boundaries_from_ranges

from scipy.spatial import Rectangle
import argparse
import os

def int_tuple_type(strings):
	"""
	Convert a string such as "(1,2,3)" or "1,2,3" into a tuple of integers.

	Round brackets are ignored wherever they appear. A single trailing comma
	is accepted, so that the Python spelling of a one-element tuple "(1,)" is
	parsed as (1,). An empty specification ("" or "()") gives an empty tuple.

	Parameters
	----------
		strings: str
			Comma separated integers, optionally enclosed in round brackets

	Returns
	-------
		tuple
			The parsed integers

	Raises
	------
		ValueError
			If any of the comma separated tokens is not a valid integer
	"""
	cleaned = strings.replace("(", "").replace(")", "").strip()
		#a trailing comma would otherwise leave an empty token, on which int() fails
	if cleaned.endswith(","):
		cleaned = cleaned[:-1].strip()
	if not cleaned:
		return ()
	return tuple(int(token) for token in cleaned.split(","))


##### Creating parser
	#The module docstring is given as `description`: the first positional argument of ArgumentParser is `prog`,
	#so passing the docstring positionally would print it as the program name in the usage line.
	#The raw formatter keeps the line breaks of the docstring in the help message.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--n-samples", required = False, type = int,
	help="Number of samples to be drawn from the MCMC")
parser.add_argument(
	"--n-walkers", required = False, type = int, default =1,
	help="Number of walkers for the MCMC. If a chain-file is set, the file should have enough walkers")
parser.add_argument(
	"--chain-file",  required = False,
	help="Path to a chain file. If given, this will be the input of the MCMC and the burn-in phase will not be performed.")
parser.add_argument(
	"--psd",  required = False,
	help="The input file for the PSD: it can be either a txt or a ligolw xml file")
parser.add_argument(
	"--asd",  default = False, action='store_true',
	help="Whether the input file has an ASD (sqrt of the PSD)")
parser.add_argument(
	"--ifo", default = 'L1', type=str, choices = ['L1', 'H1', 'V1'],
	help="Interferometer name: it can be L1, H1, V1. This is a field for the xml files for the PSD and the bank")
parser.add_argument(
	"--plot", action='store_true',
	help="Whether to make some plots. They will be stored in run-dir")
parser.add_argument(
	"--show", action='store_true',
	help="Whether to show the plots.")
parser.add_argument(
	"--run-dir", default = './out_$(run_name)',
	help="Output directory in which the bank will be saved. If default is used, the bank name will be appended.")
parser.add_argument(
	"--run-name", default = 'cbc_mbank',
	help="Name for the bank and tiling output file")
parser.add_argument(
	"--f-min",  default = 10., type=float,
	help="Minimum frequency for the scalar product")
parser.add_argument(
	"--f-max",  default = 1024., type=float,
	help="Maximum frequency for the scalar product")
parser.add_argument(
	"--approximant", default = None,
	help="LAL approximant for the bank generation")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the metric computation")

	#ranges for physical parameters: every option takes two floats (min, max)
	#Each entry is (option name, default range, help message)
range_options = [
	("--m-range", [10., 100], "Range values for the masses (in solar masses)"),
	("--mtot-range", [10., 100], "Range values for the total masses (in solar masses)."),
	("--q-range", [1., 10.], "Range values for the mass ratio."),
	("--mc-range", [10., 100], "Range values for the chirp mass (in solar masses)."),
	("--eta-range", [.18, .25], "Range values for the symmetric mass ratio."),
	("--s1-range", [-0.99,0.99], "Range values for magnitude of spin 1 (if applicable)"),
	("--s2-range", [-0.99,0.99], "Range values for magnitude of spin 2 (if applicable)"),
	("--chi-range", [-0.99,0.99], "Range values for effective spin parameter (if applicable)"),
	("--theta-range", [-np.pi, np.pi], "Range values for theta angles of spins (if applicable)"),
	("--phi-range", [-np.pi/2, np.pi/2], "Range values for phi angles of spins (if applicable)"),
	("--e-range", [0., 0.5], "Range values for the eccentricity (if applicable)"),
	("--meanano-range", [0., 1], "Range values for the mean anomaly (if applicable). TODO: find a nice default..."),
	("--iota-range", [0., np.pi], "Range values for iota (if applicable)"),
	("--ref-phase-range", [-np.pi, np.pi], "Range values for reference phase (if applicable)"),
]
for option_name, default_range, help_msg in range_options:
	parser.add_argument(
		option_name, default = default_range, type=float, nargs = 2,
		help=help_msg)

args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for f in filenames:
	args = updates_args_from_ini(f, args, parser)

####################################################################################################
	######
	#	Interpreting the parser and initializing variables
	######

if (args.psd is None) or (args.n_samples is None) or (args.variable_format is None):
	raise ValueError("The arguments n-samples, psd and variable-format must be set!")

var_handler = variable_handler()
assert args.variable_format in var_handler.valid_formats, "Wrong value {} for variable-format".format(args.variable_format)

if args.run_dir == './out_$(run_name)':	args.run_dir = './out_{}/'.format(args.run_name)
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
if not os.path.exists(args.run_dir): os.makedirs(args.run_dir)
if args.psd.find('/') <0: args.psd = args.run_dir+args.psd
if isinstance(args.chain_file, str):
	if args.chain_file.find('/') <0: args.chain_file = args.run_dir+args.chain_file

m_min, m_max = args.m_range
mtot_min, mtot_max = args.mtot_range
q_min, q_max = args.q_range
mc_min, mc_max = args.mc_range
eta_min, eta_max = args.eta_range
s1_min, s1_max = args.s1_range
s2_min, s2_max = args.s2_range
chi_min, chi_max = args.chi_range
theta_min, theta_max = args.theta_range
e_min, e_max = args.e_range
meanano_min, meanano_max = args.meanano_range
phi_min, phi_max = args.phi_range
iota_min, iota_max = args.iota_range
ref_phase_min, ref_phase_max = args.ref_phase_range

plot_folder = None
if args.plot: plot_folder = args.run_dir

format_info = var_handler.format_info[args.variable_format]

	######
	#	Setting boundaries: shape (2,D)
	######
	#setting mass boundaries
if format_info['mass_format'] == 'm1m2':
	var1_min, var1_max = m_min, m_max
	var2_min, var2_max = m_min, m_max
elif format_info['mass_format'] in ['Mq', 'logMq']:
	var1_min, var1_max = mtot_min, mtot_max
	var2_min, var2_max = q_min, q_max
elif format_info['mass_format'] == 'mceta':
	var1_min, var1_max = mc_min, mc_max
	var2_min, var2_max = eta_min, eta_max

	#setting spin boundaries
boundaries = get_boundaries_from_ranges(args.variable_format,
	(var1_min, var1_max), (var2_min, var2_max), chi_range = (chi_min, chi_max),
	s1_range = (s1_min, s1_max), s2_range = (s2_min, s2_max), theta_range = (theta_min, theta_max), phi_range = (phi_min, phi_max),
	iota_range = (iota_min, iota_max), ref_phase_range = (ref_phase_min, ref_phase_max),
	e_range = (e_min, e_max), meanano_range = (meanano_min, meanano_max))

	######
	#	Loading PSD and initializing metric
	######
m = cbc_metric(args.variable_format,
			PSD = load_PSD(args.psd, args.asd, args.ifo),
			approx = args.approximant,
			f_min = args.f_min, f_max = args.f_max)

print("## Running: ", args.run_name)

	######
	#	Running and saving the output
	######
bank = cbc_bank(args.variable_format)
#bank.load(args.run_dir+'bank_{}.dat'.format(args.run_name))

bank.generate_bank_mcmc(m, args.n_samples, boundaries = boundaries, n_walkers = args.n_walkers,
		thin_factor = None, #FIXME: should you expose this to the user?
		load_chain = args.chain_file,
		save_chain = args.run_dir+'chain_{}_{}.dat'.format(args.run_name, args.n_walkers),
		use_ray = args.use_ray, verbose = True)
	

print("Saving {} samples to bank file {}".format(bank.templates.shape[0], args.run_dir))
bank.save_bank(args.run_dir+'bank_{}.dat'.format(args.run_name))
bank.save_bank(args.run_dir+'bank_{}.xml.gz'.format(args.run_name), args.ifo)

if False: #DEBUG plot
		#Disabled sanity check: overlays the bank templates on random points drawn from a gaussian
	import matplotlib.pyplot as plt
	n_dims = bank.templates.shape[1]
	gaussian_points = np.random.multivariate_normal([10,10], np.eye(n_dims), 100000)
	plt.figure()
	plt.scatter(*gaussian_points.T, s = 1, c= 'r', alpha = 0.5)
	plt.scatter(*bank.templates.T, s = 2)
	plt.show()

if args.plot:
		#A tiling made of a single tile: a rectangle built from the two rows of boundaries,
		#with the metric evaluated at the mean of the two rows
	single_tile_tiling = tiling_handler()
	full_rectangle = Rectangle(boundaries[0,:], boundaries[1,:])
	center_metric = m.get_metric(np.mean(boundaries, axis =0))
	single_tile_tiling.extend([tile(full_rectangle, center_metric)])
		#NOTE(review): the tiling built above is not given to the plotting function (tiling = None) - confirm this is intended
	plot_tiles_templates(
		bank.templates, args.variable_format,
		tiling = None,
		dist_ellipse = None,
		save_folder = plot_folder,
		show = args.show)













