#!/usr/bin/env python
"""
mbank_generate_flow_dataset
---------------------------

Executable to generate a dataset useful for training a normalizing flow model.
Each row of the dataset consists of :math:`\\theta, M_{ij}, LL`, where the log likelihood LL is:

.. math::

	LL = \\log{\\sqrt{|M(\\theta)|}}

The script will create a dataset with random points :math:`\\theta` drawn within the space of interest.

To generate a dataset:

	mbank_generate_flow_dataset --options-you-like

You can also load (some) options from an ini-file:

	mbank_generate_flow_dataset --some-options other_options.ini

Make sure that mbank is properly installed.

To know which options are available:

	mbank_generate_flow_dataset --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings
from ligo.lw.utils import load_filename

import torch
from torch import optim

from mbank import variable_handler, cbc_metric
from mbank.utils import load_PSD, avg_dist, plot_tiles_templates, get_boundaries_from_ranges
from mbank.parser import get_boundary_box_from_args, boundary_keeper, make_sub_file
import mbank.parser

from mbank.flow import STD_GW_Flow, GW_Flow
from mbank.flow.utils import early_stopper

import emcee
from tqdm import tqdm
import argparse
import os
from lal import MTSUN_SI

############################################################################
def log_prob_initial_dataset(mceta, b_checker, sample_format, f_min = None):
	"""
	Log probability used by the sampler to draw the mass coordinates of the dataset.
	For the points accepted by ``b_checker``, it evaluates

		-(log(2048/(5*pi)) + 13/3*log(pi*f_min) + 10/3*log(x_0) + 8/5*log(x_1))

	where ``x_0`` and ``x_1`` are the first and second column of ``mceta``
	(presumably chirp mass and symmetric mass ratio, as the name suggests - TODO confirm).
	Points rejected by ``b_checker`` get the constant value -100000.

	Parameters
	----------
		mceta: array-like
			Point(s) to evaluate, shape (2,) or (N,2)

		b_checker: callable
			Called as ``b_checker(mceta, sample_format)`` with a 2D ``mceta``;
			must return a boolean mask selecting the points inside the boundaries

		sample_format: str
			Variable format of the samples, only forwarded to ``b_checker``

		f_min: float
			Frequency entering the expression above. If None, the global ``args.f_min`` is used

	Returns
	-------
		res: float or np.ndarray
			Log probability: a scalar for a 1D input, an array of shape (N,) otherwise
	"""
	if f_min is None: f_min = args.f_min
	squeeze = (np.asarray(mceta).ndim == 1)
	mceta = np.atleast_2d(mceta)
		#The fill value must be a float: with an integer fill value np.full creates an integer array
		#and the log probabilities assigned below would be silently truncated
	res = np.full(mceta[:,0].shape, -100000.)
	ids_inside = b_checker(mceta, sample_format)
	if np.any(ids_inside):
		res[ids_inside] = -(np.log(2048/(5*np.pi)) + 13/3 * np.log(np.pi*f_min) + 10/3 * np.log(mceta[ids_inside,0]) + 8/5 * np.log(mceta[ids_inside,1]))
	if squeeze: res = res[0]
	return res
	

def get_random_masses(boundaries):
	"""
	Draws ``args.n_datapoints`` points for the two mass coordinates, using an emcee ensemble sampler
	with ``log_prob_initial_dataset`` as target log probability.
	It relies on the module-level ``args``, ``var_handler`` and ``sample_format``.

	Parameters
	----------
		boundaries: np.ndarray
			shape (2,2) -
			Lower and upper limits (first and second row) of the box where the starting points are drawn

	Returns
	-------
		new_masses: np.ndarray
			Sampled points, as returned by ``var_handler.convert_theta`` when going from ``sample_format``
			to the non-spinning version of the mass format of ``args.variable_format``
	"""
	assert boundaries.shape == (2,2)

	n_walkers = 100
	mass_format = var_handler.format_info[args.variable_format]['mass_format']
	out_variable_format = '{}_nonspinning'.format(mass_format)

	boundaries_checker = boundary_keeper(args)

		#Rejection sampling of the walkers starting positions:
		#batches of 100 uniform points are drawn until n_walkers of them pass the boundary check
	accepted_points = []
	while len(accepted_points) < n_walkers:
		candidates = np.random.uniform(*boundaries, (100, 2))
		candidates = var_handler.convert_theta(candidates, out_variable_format, sample_format)
		is_inside = boundaries_checker(candidates, sample_format)
		candidates = candidates[is_inside]
		if len(candidates) > 0:
			accepted_points.extend(candidates)
	initial_state = np.array(accepted_points)[:n_walkers]

	sampler = emcee.EnsembleSampler(n_walkers, 2, log_prob_initial_dataset,
		args=[boundaries_checker, sample_format])

		#Burn-in
	state = sampler.run_mcmc(initial_state, 400, tune=True, store=False,
		progress=args.verbose, progress_kwargs={'desc': 'Burn-in for mass samples'})

		#Every round advances the chain by 10 steps and stores the positions of all the walkers:
		#the number of rounds is enough to collect at least args.n_datapoints points
	n_rounds = args.n_datapoints//n_walkers+1
	chunks = []
	for _ in tqdm(range(n_rounds), desc = 'Generating mass samples', disable = not args.verbose):
		state = sampler.run_mcmc(state, 10, tune=False, store=False, progress=False)
		chunks.append(state.coords)

	samples = np.concatenate(chunks, axis = 0)[:args.n_datapoints]
	return var_handler.convert_theta(samples, sample_format, out_variable_format)

##### Creating parser
	#The module docstring is the description of the program: it must be passed by keyword, since the
	#first positional argument of ArgumentParser is `prog` (the program name shown in usage and errors).
	#RawDescriptionHelpFormatter keeps the line breaks of the docstring in the --help output.
parser = argparse.ArgumentParser(description = __doc__, formatter_class = argparse.RawDescriptionHelpFormatter)

	#Options shared with the other mbank executables
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)
mbank.parser.add_range_options(parser)

	#Options specific to this program
parser.add_argument(
	"--n-datapoints", default = 10000, type = int,
	help="Number of datapoints to be stored in the dataset")
parser.add_argument(
	"--dataset", default = None, type = str,
	help="Files with the datasets (more than one are accpted). If the file does not exist, the dataset will be created.")
parser.add_argument(
	"--datapoints-file", default = None, type = str,
	help="Csv with a list of points to build the dataset: for each of them the likelihood will be computed and stored in the dataset")
parser.add_argument(
	"--only-ll", default = False, action = 'store_true',
	help="Whether to store only the metric LL in the dataset file (to save memory). If not set, the metric will be stored, together with the LL.")
parser.add_argument(
	"--make-sub",  default = False, action='store_true',
	help="If set, it will make a condor submit file that the user can use to launch the job through condor.")

	#The arguments not recognized by the parser are interpreted as ini files
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for f in filenames:
	args = mbank.parser.updates_args_from_ini(f, args, parser)

####################################################################################################
	######
	#	Interpreting the parser and initializing variables
	######

	#The three options below have no usable default
mandatory_options = (args.psd, args.dataset, args.variable_format)
if any(option is None for option in mandatory_options):
	raise ValueError("The arguments --dataset, --psd and --variable-format must be set!")

var_handler = variable_handler()
assert args.variable_format in var_handler.valid_formats, "Wrong value {} for variable-format".format(args.variable_format)
D = var_handler.D(args.variable_format)

	#The run folder defaults to ./out_RUNNAME/, it always ends with '/' and it is created if missing
if not args.run_dir:
	args.run_dir = './out_{}/'.format(args.run_name)
if not args.run_dir.endswith('/'):
	args.run_dir += '/'
if not os.path.exists(args.run_dir):
	os.makedirs(args.run_dir)

	#File names given without any '/' are understood as relative to the run folder
if '/' not in args.psd:
	args.psd = args.run_dir+args.psd
if '/' not in args.dataset:
	args.dataset = args.run_dir+args.dataset
if args.datapoints_file and '/' not in args.datapoints_file:
	args.datapoints_file = args.run_dir+args.datapoints_file

	#With --make-sub only the condor submit file is created and the program stops here
if args.make_sub:
	make_sub_file(args.run_dir, args.run_name)
	quit()

plot_folder = args.run_dir if args.plot else None
if args.verbose:
	print("## Running dataset generation: ", args.run_name)

	#The arguments of the run are stored in the run folder
args_file = args.run_dir+'args_{}.json'.format(os.path.basename(__file__))
mbank.parser.save_args(args, args_file)

boundaries = get_boundary_box_from_args(args)

	#Variable format in which the masses are sampled
sample_format = 'mceta_nonspinning'

	######
	#	Loading PSD and initializing metric
	######

	#The PSD is loaded first and then handed to the metric object,
	#which is used below to evaluate the metric at each datapoint
PSD_data = load_PSD(args.psd, args.asd, args.ifo, df = args.df)

m = cbc_metric(
	args.variable_format,
	PSD = PSD_data,
	approx = args.approximant,
	f_min = args.f_min,
	f_max = args.f_max,
)

	######
	#	Drawing initial points (either from file or from sampling)
	#	Mass boundaries are complicated, while the rest is just a rectangle in the coordinates
	######
if args.datapoints_file:
		#At most n_datapoints rows and only the first D columns of the file are used.
		#ndmin = 2 keeps the array two dimensional also for a file with a single row,
		#otherwise the slicing below would fail
	datapoints = np.loadtxt(args.datapoints_file, ndmin = 2)[:args.n_datapoints,:D]
else:
		#All the coordinates are first drawn uniformly inside the boundary box...
	datapoints = np.random.uniform(*boundaries, (args.n_datapoints, D))

		#...then the first two columns (the masses) are replaced by the output of the dedicated sampler
	datapoints[:,:2] = get_random_masses(boundaries[:,:2])

	######
	#	LL computation
	######

	#For each datapoint the metric M is computed and LL = 0.5*log|det M| is stored,
	#preceded by the flattened metric unless --only-ll is set
metric_list = []
for d in tqdm(datapoints, disable = not args.verbose, desc = 'Generating dataset'):
	try:
		metric = m.get_metric(d, metric_type = args.metric_type)
			#slogdet returns (sign, log|det|): log|det| is read directly,
			#which avoids the overflow/underflow of computing the determinant first
		ll = 0.5*np.linalg.slogdet(metric)[1]
		if args.only_ll:
			dataset_row = [ll]
		else:
			dataset_row = np.array([*metric.flatten(), ll])
		metric_list.append(dataset_row)
	except KeyboardInterrupt:
			#The rows computed so far are kept and saved below
		warnings.warn("KeyboardInterrupt: interrupting dataset generation and saving the datapoints computed so far")
		break

if metric_list:
		#Only the datapoints with a computed metric are saved
	to_save = np.concatenate([datapoints[:len(metric_list)], metric_list], axis = 1)
	if args.only_ll:
		header = 'variable format: {0}\ntheta ({1},) | log_pdf (1,) '.format(args.variable_format, D)
	else:
		header = 'variable format: {0}\ntheta ({1},) | metric ({1}*{1},) | log_pdf (1,) '.format(args.variable_format, D)
	np.savetxt(args.dataset, to_save, header = header)

	if args.verbose: print("Dataset saved")
else:
		#Nothing to concatenate (no datapoints or interruption before the first metric was computed)
	warnings.warn("No datapoint was processed: the dataset file was not written")

	######
	#	Plotting
	######

if args.plot:
		#At most the first 10000 datapoints are plotted
		#NOTE(review): save_folder is None here, while plot_folder (set earlier from args.run_dir) is never used:
		#confirm whether the figure is meant to be saved
	points_to_plot = datapoints[:10000, :D]
	plot_tiles_templates(
		points_to_plot, args.variable_format,
		tiling = None, dist_ellipse = None, save_folder = None, show = False,
		title = 'Validation data',
	)
	if args.show:
		plt.show()
