#!/usr/bin/env python
"""
mbank_injections
----------------

A script to randomly draw injections inside a bank. Injections can also be read from an xml file.
To draw injections, a tiling file or a normalizing flow file is needed; a tiling file also allows the match to be computed with the metric approximation.

To perform injections:

	mbank_injections --options-you-like

You can also load (some) options from an ini-file:

	mbank_injections --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_injections --help
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches
import sys
import warnings

from ligo.lw import utils as lw_utils
from ligo.lw import ligolw
from ligo.lw import table as lw_table
from ligo.lw import lsctables

from mbank import variable_handler, cbc_metric, cbc_bank, tiling_handler
from mbank.utils import ray_compute_injections_match, compute_injections_match, compute_injections_metric_match, load_PSD, plot_tiles_templates, save_injs, get_antenna_patterns, get_random_sky_loc, plot_match_histogram, save_inj_stat_dict, initialize_inj_stat_dict, read_xml
import mbank.parser
from mbank.flow import STD_GW_Flow

from itertools import combinations, permutations

import argparse
import os

@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""Content handler for LIGO_LW xml documents, decorated with ``lsctables.use_in``.

	The decorator presumably registers the lsctables table classes on the handler so they are used
	when parsing (see the ligo.lw documentation).
	NOTE(review): this class is not referenced elsewhere in this script (xml injections are read via read_xml).
	"""

##################################
##### Creating parser
	#The first positional argument of argparse.ArgumentParser is `prog` (the program name shown in the
	#usage line), not the description: the docstring is passed as `description`, and the raw formatter
	#keeps its line breaks when printing --help
parser = argparse.ArgumentParser(description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)
mbank.parser.add_injections_options(parser)
mbank.parser.add_range_options(parser) #only to generate injections for the flow...

#FIXME: it would be nice if the flow object had incorporated the boundaries, maybe in the form of some non-trainable weights...

parser.add_argument(
	"--bank-file", required = False, type = str,
	help="Path to the file of a bank. If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--tiling-file", required = False, type = str,
	help="The input file with a tiling. It must be generated by a tiling_handler object. If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--flow-file", required = False, type = str,
	help="An optional file with weigths for a normalizing flow model. If provided it will be used to interpolate the metric when computing the metric injection match.")
parser.add_argument(
	"--mm", type = float,
	help="Minimum match for the bank - for plotting purposes")

args, filenames = parser.parse_known_args()

	#Every argument left over by the parser is treated as an ini file, whose options update `args`
	#(see mbank.parser.updates_args_from_ini)
for f in filenames:
	args = mbank.parser.updates_args_from_ini(f, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######

var_handler = variable_handler()

if (args.bank_file is None) or (args.variable_format is None):
	raise ValueError("The arguments bank-file and variable_format must be set!")
if (args.inj_file is None) and (args.tiling_file is None) and (args.flow_file is None):
	raise ValueError("If the argument --inj-file is not set, you must provide a way to draw injections through arguments --tiling-file or --flow-file")

	#An explicit exception rather than `assert`: asserts are stripped when running with `python -O`
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	#Defining some convenience variables
full_match = args.full_match or args.full_symphony_match
metric_match = isinstance(args.tiling_file, str)

	#FIXME: understand whether you want the default --run-dir to be './'
	#By default the run dir is the folder of the bank file. os.path.dirname is used instead of removing
	#the basename with str.replace, which would also strip any other occurrence of it (e.g. 'out/bank/bank')
if args.run_dir is None: args.run_dir = os.path.dirname(args.bank_file)
if args.run_dir == '': args.run_dir = './'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
os.makedirs(args.run_dir, exist_ok = True)

def _in_run_dir(path):
	"""Return `path` unchanged if it contains a '/', otherwise the bare file name prefixed with the run dir."""
	return path if '/' in path else args.run_dir+path

	#setting default locations for injection file, stat dict, flow file, tiling file, bank file and PSD
if isinstance(args.inj_file, str): args.inj_file = _in_run_dir(args.inj_file)
if isinstance(args.stat_dict, str):
	args.stat_dict = _in_run_dir(args.stat_dict)
else:
	args.stat_dict = args.run_dir+os.path.basename(args.bank_file).split('.')[0]+'-injections_stat_dict.json'
if isinstance(args.flow_file, str): args.flow_file = _in_run_dir(args.flow_file)
if metric_match: args.tiling_file = _in_run_dir(args.tiling_file)
args.bank_file = _in_run_dir(args.bank_file)
if isinstance(args.psd, str): args.psd = _in_run_dir(args.psd)

mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

	######
	#	Loading the PSD and building the metric: both are only needed to compute the full match
	######
if full_match:
	psd_freqs, psd_values = load_PSD(args.psd, args.asd, args.ifo, df = args.df)
	metric = cbc_metric(
		PSD = (psd_freqs, psd_values),
		approx = args.approximant,
		f_min = args.f_min,
		f_max = args.f_max,
		variable_format = args.variable_format,
	)

	######
	#	Loading the bank
	######
bank = cbc_bank(args.variable_format, args.bank_file)
if args.verbose:
	print("Loaded bank {} with {} templates".format(args.bank_file, bank.templates.shape[0]))

	#With a tiling available, the (optional) flow is attached to it; otherwise the flow is loaded on its own
if not metric_match:
	if isinstance(args.flow_file, str):
		flow = STD_GW_Flow.load_flow(args.flow_file)
else:
	t_obj = tiling_handler(args.tiling_file)
	if isinstance(args.flow_file, str):
		t_obj.load_flow(args.flow_file)

	######
	#	Generating injections (or loading them from file).
	#	Afterwards `injections_full` holds the injections in BBH components and `sky_locs`
	#	their sky location & polarization (None when the full match is not computed).
	######

if args.inj_file is None:
		#An explicit exception rather than `assert`: asserts are stripped when running with `python -O`
	if not isinstance(args.n_injs, int):
		raise ValueError("Argument --n-injs must be specified if an injection file is not given")

		#drawing injections in the bank coordinates: from the tiling if available, otherwise from the flow
		#(t_obj exists exactly when metric_match is True)
	if metric_match:
		projected_injs = t_obj.sample_from_tiling(args.n_injs, seed = args.seed)
	else:
		bk = mbank.parser.boundary_keeper(args)
		def boundaries_checker(theta):
			"""Evaluate the boundary keeper built from the range options on `theta` (in the bank variable format)."""
			return bk(theta, args.variable_format)
		projected_injs, _ = flow.sample_within_boundaries(args.n_injs, boundaries_checker, seed = args.seed)

	#projected_injs = bank.templates; print("######### Bank injections!! ###########") #FIXME: this will become a bank-injs option!! Or something similar

	injections_full = var_handler.get_BBH_components(projected_injs, args.variable_format)
	del projected_injs #storing projected_injs is a waste of memory that can scale super badly

		#dealing with sky location
	if args.fixed_sky_loc_polarization:
			#NOTE(review): a single value for all injections; presumably save_injs broadcasts it - confirm
		sky_locs = np.array(args.fixed_sky_loc_polarization)
	else:
		sky_locs = np.stack([*get_random_sky_loc(args.n_injs, args.seed)], axis = 1)

		#Storing injections to file (time arguments 0 and 10*n_injs with time_step = 10: presumably one injection every 10 s)
	inj_filename = args.run_dir+'injections.xml'
	save_injs(inj_filename, injections_full, 0, 10*args.n_injs, time_step = 10,
			approx = args.approximant,
			sky_locs = sky_locs, luminosity_distance = 100,
			f_min = args.f_min, f_max = args.f_max)
else:
		#loading injections from file: only xml (optionally gzipped) is currently supported
	if args.inj_file.endswith(('xml', 'xml.gz')):
		warnings.warn('When loading injections from an xml file, eccentricity and meanano are currently not supported')

		injections_full, sky_locs_xml = read_xml(args.inj_file, lsctables.SimInspiralTable, args.n_injs)
		args.n_injs = injections_full.shape[0]
			#the sky_locs from the xml file are used only if --fixed-sky-loc-polarization is not set
		sky_locs = np.array(args.fixed_sky_loc_polarization) if args.fixed_sky_loc_polarization else sky_locs_xml
	elif args.inj_file.endswith('json'):
		raise NotImplementedError("json format not implemented yet")
	else:
		raise ValueError("Format for the input file for the injections must be an xml")

	if args.verbose: print("Loaded {} injections from file: {}".format(args.n_injs, args.inj_file))

	#the sky location is only used by the full match computation
if not full_match: sky_locs = None

	######
	#	Metric match (only with a tiling): according to the original notes, here the injections
	#	are projected onto the bank manifold and matched with the metric approximation
	######
injection_stat_dict = initialize_inj_stat_dict(injections_full, sky_locs)

if metric_match:
	injection_stat_dict = compute_injections_metric_match(
		injection_stat_dict, bank, t_obj, verbose = args.verbose)

	######
	#	Full match (if requested): the match is recomputed with the actual (non-projected) injections.
	#	The match functions return an updated version of the stat dict.
	######
if full_match:
		#the metric is switched to BBH components for the match computation and restored right after
	metric.set_variable_format('BBH_components')
	match_function = ray_compute_injections_match if args.use_ray else compute_injections_match
	injection_stat_dict = match_function(
		injection_stat_dict, bank, metric,
		mchirp_window = args.mchirp_window,
		symphony_match = args.full_symphony_match,
		verbose = args.verbose,
	)
	metric.set_variable_format(args.variable_format)

	######
	#	Writing the stat dict to file
	######
save_inj_stat_dict(args.stat_dict, injection_stat_dict)

	######
	#	Making some plots & printing percentiles
	######

	#histogram for the overall match
matches_metric = injection_stat_dict['metric_match']
	#non-positive metric matches are clipped to 0.01 (NOTE(review): presumably so they stay visible on a
	#log-scaled histogram - confirm). save_inj_stat_dict has already been called at this point, so this
	#in-place edit does not change the stored stat dict.
if metric_match: matches_metric[matches_metric<=0] = 0.01
matches = injection_stat_dict['match']

if args.verbose:
	if metric_match:
		print("Metric match percentiles [50, 5, 1]: ", np.percentile(matches_metric, [50, 5,1]))
	if full_match:
		print("Match percentiles [50, 5, 1]: ", np.percentile(matches, [50, 5,1]))
plotted_matches = matches if full_match else matches_metric
if plotted_matches is None:
	warnings.warn("You chose not to compute neither the match nor the metric match: there's nothing to do here.")
		#sys.exit is always available, while quit() is injected by the `site` module (missing e.g. under `python -S`)
	sys.exit()

if args.plot:

		###
		# Histogram
		###
	plot_match_histogram(matches_metric, matches,
			mm = args.mm, bank_name = os.path.basename(args.bank_file),
			save_folder = args.run_dir)

		###
		# Colored plots
		###
		#the injections are projected on the template manifold only here, since they are used only for
		#this plot (storing them otherwise is a waste of memory)
	projected_injs = var_handler.get_theta(injections_full, args.variable_format) #(N,D)
		#TODO: optionally plot only the injections with match below --mm (np.where(plotted_matches < args.mm))
	plot_tiles_templates(bank.templates, args.variable_format, #t_obj,
				injections = projected_injs, inj_cmap = plotted_matches,
				dist_ellipse = None, save_folder = args.run_dir, savetag = 'injection_study', show = False)
	if args.show: plt.show()