#!/usr/bin/env python
"""
mbank_mcmc
----------

A script to sample points from the PDF defined by the metric.
It uses the `emcee` package, which is not among the package dependencies, so make sure it is installed.

To sample from:

	mbank_mcmc --options-you-like

You can also load (some) options from an ini-file:

	mbank_mcmc --some-options other_options.ini

Make sure that the mbank is properly installed.
To know which options are available:

	mbank_mcmc --help
"""
import numpy as np

from mbank import cbc_metric, cbc_bank
from mbank.handlers import tile, variable_handler, tiling_handler
from mbank.utils import load_PSD, plot_tiles_templates, get_boundaries_from_ranges
from mbank.parser import get_boundary_box_from_args
import mbank.parser

from scipy.spatial import Rectangle
import argparse
import os

##### Creating parser
# The module docstring is used as the description of the program in the --help output.
# It must be passed by keyword: the first positional argument of ArgumentParser is `prog`,
# so passing it positionally would print the whole docstring as the program name.
# The raw formatter keeps the layout of the docstring (e.g. the indented usage examples).
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

	#Options specific to this script
parser.add_argument(
	"--n-samples", required = False, type = int,
	help="Number of samples to be drawn from the MCMC")
parser.add_argument(
	"--n-walkers", required = False, type = int, default =1,
	help="Number of walkers for the MCMC. If a chain-file is set, the file should has enough walkers")
parser.add_argument(
	"--chain-file",  required = False,
	help="Path to a chain file. If given, this will be the input of the MCMC and the burn-in phase will not be performed.")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the metric computation")

	#General, metric and range options are added by the helpers in mbank.parser
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)
mbank.parser.add_range_options(parser)

	#The command line arguments not recognized by the parser are treated as ini files
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for f in filenames:
	args = mbank.parser.updates_args_from_ini(f, args, parser)

####################################################################################################
	######
	#	Interpreting the parser and initializing variables
	######

	#These options are not flagged as required in the parser (they may be set by an ini file),
	#hence their presence is checked here
if (args.psd is None) or (args.n_samples is None) or (args.variable_format is None):
	raise ValueError("The arguments n-samples, psd and variable-format must be set!")

var_handler = variable_handler()
	#An explicit raise is used instead of an assert, since asserts are stripped when running with -O
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	#'./out_$(run_name)' is presumably the placeholder default of run-dir (TODO confirm in mbank.parser):
	#it is replaced by a folder named after the run
if args.run_dir == './out_$(run_name)':
	args.run_dir = './out_{}/'.format(args.run_name)
	#The paths below are built by string concatenation, so run_dir must end with '/'
if not args.run_dir.endswith('/'):
	args.run_dir = args.run_dir+'/'
	#exist_ok avoids the race between checking for the folder and creating it
os.makedirs(args.run_dir, exist_ok = True)

	#File names given without any folder are understood to be inside the run folder
if '/' not in args.psd:
	args.psd = args.run_dir+args.psd
if isinstance(args.chain_file, str) and '/' not in args.chain_file:
	args.chain_file = args.run_dir+args.chain_file

	#Plots (if any) are saved in the run folder
plot_folder = args.run_dir if args.plot else None

	#Saving the arguments used for this run to a json file named after this script
mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

format_info = var_handler.format_info[args.variable_format]

	######
	#	Setting boundaries: shape (2,D)
	######
	#The boundaries of the sampling region are read from the range options
boundaries = get_boundary_box_from_args(args)

	######
	#	Loading PSD and initializing metric
	######
	#The PSD is loaded first and then handed to the metric object
loaded_PSD = load_PSD(args.psd, args.asd, args.ifo, df = args.df)
m = cbc_metric(
	args.variable_format,
	PSD = loaded_PSD,
	approx = args.approximant,
	f_min = args.f_min,
	f_max = args.f_max)

print("## Running: ", args.run_name)

	######
	#	Running and saving the output
	######
	#An empty bank is filled with the samples drawn by the MCMC
bank = cbc_bank(args.variable_format)
#bank.load(args.run_dir+'bank_{}.dat'.format(args.run_name))

	#The chain is saved in the run folder: its name holds the run name and the number of walkers
chain_out_file = '{}chain_{}_{}.dat'.format(args.run_dir, args.run_name, args.n_walkers)

bank.generate_bank_mcmc(
	m, args.n_samples,
	boundaries = boundaries,
	n_walkers = args.n_walkers,
	thin_factor = None, #FIXME: should you expose this to the user?
	load_chain = args.chain_file,
	save_chain = chain_out_file,
	use_ray = args.use_ray,
	verbose = True)

n_saved_samples = bank.templates.shape[0]
print("Saving {} samples to bank file {}".format(n_saved_samples, args.run_dir))

	#The bank is saved twice: as a dat file and as a xml.gz file (the latter also gets the ifo)
bank_file_stem = '{}bank_{}'.format(args.run_dir, args.run_name)
bank.save_bank(bank_file_stem+'.dat')
bank.save_bank(bank_file_stem+'.xml.gz', args.ifo)

if False: #DEBUG plot
	import matplotlib.pyplot as plt
	plt.figure()
	plt.scatter(*np.random.multivariate_normal([10,10], np.eye(bank.templates.shape[1]), 100000).T, s = 1, c= 'r', alpha = 0.5)
	plt.scatter(*bank.templates.T, s = 2)
	plt.show()

if args.plot:
	t_obj = tiling_handler()
	t_obj.extend([tile(Rectangle(boundaries[0,:], boundaries[1,:]),
			m.get_metric(np.mean(boundaries, axis =0)))])
	plot_tiles_templates(bank.templates, args.variable_format,
		tiling = None, dist_ellipse = None, save_folder = plot_folder, show = args.show)













