#!/usr/bin/env python
"""
mbank_injections
----------------

A script to randomly draw injections inside a bank. Injections can be also read from an xml file.
It needs a tiling file so that the match can be computed with the metric approximation.

To perform injections:

	mbank_injections --options-you-like

You can also load (some) options from an ini-file:

	mbank_injections --some-options other_options.ini

Make sure that the mbank is properly installed.
To know which options are available:

	mbank_injections --help
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches
import sys
import warnings

from ligo.lw import utils as lw_utils
from ligo.lw import ligolw
from ligo.lw import table as lw_table
from ligo.lw import lsctables

from mbank import variable_handler, cbc_metric, cbc_bank, tiling_handler
from mbank.utils import ray_compute_injections_match, compute_injections_match, compute_injections_metric_match, load_PSD, plot_tiles_templates, save_injs, get_antenna_patterns, get_random_sky_loc, plot_match_histogram, save_inj_stat_dict, initialize_inj_stat_dict
from mbank.utils import updates_args_from_ini, read_xml

from itertools import combinations, permutations

import argparse
import os

@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""
	Content handler for LIGO_LW xml documents, decorated with ``lsctables.use_in``.

	It adds nothing to the base class ``ligolw.LIGOLWContentHandler``.
	It is not referenced anywhere else in the visible part of this script.
	"""

##################################
##### Creating parser
	#The module docstring is given as `description`: the first positional argument of ArgumentParser is `prog`,
	#i.e. the program name printed in the usage line, which is not the place for a multi-line docstring.
	#RawDescriptionHelpFormatter keeps the line breaks and the indentation of the docstring in the --help output.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

	######
	#	Options controlling the injections and the match computation
	######
parser.add_argument(
	"--n-injs", type = int, default = None,
	help="Number of injections. If inj-file is specified, they will be read from it; otherwise they will be randomly drawn from the tiling and saved to file. If None, all the injections will be read from inj-file and it will throw an error if such file is not provided.")
parser.add_argument(
	"--fixed-sky-loc-polarization", type = float, nargs = 3, default = None,
	help="Sky localization and polarization angles for the signal injections. They must be a tuple of float in the format (longitude,latitude,polarization). If None, the angles will be loaded from the injection file, if given, or uniformly drawn from the sky otherwise.")
parser.add_argument(
	"--inj-file", type = str, default = None,
	help="An xml injection file to load the injections from. If not provided, the injections will be performed at random in each tile (injs-per-tile). If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--stat-dict", type = str, default = None,
	help="The name of the file in which the results of the injection study will be saved (either json or pkl). If None, a suitable default will be provided.")
parser.add_argument(
	"--bank-file", required = False, type = str,
	help="Path to the file of a bank. If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--tiling-file", required = False, type = str,
	help="The input file with a tiling. It must be generated by a tiling_handler object. If no path to the file is provided, it is understood it is located in run-dir.")
	#NOTE(review): a default location 'tiling_${RUN_NAME}.npy' in run-dir was apparently planned for the tiling file,
	#but no such default is set in this script: without --tiling-file the metric match is not computed
parser.add_argument(
	"--flow-file", required = False, type = str,
	help="An optional file with weigths for a normalizing flow model. If provided it will be used to interpolate the metric when computing the metric injection match.")
parser.add_argument(
	"--mchirp-window", type = float, default = 0.1,
	help="The window in relative chirp mass inside which the templates are considered for full match (if --full-match is specified)")
parser.add_argument(
	"--full-match", action='store_true', default = False,
	help="Whether to perform the full standard match computation. If False, a metric approximation to the match will be used")
parser.add_argument(
	"--full-symphony-match", action='store_true', default = False,
	help="Whether to perform the full symphony match computation. If False, a metric approximation to the match will be used")
parser.add_argument(
	"--seed", type = int, default = None,
	help="Random seed for extracting the random injections (if it applies)")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the match computation")

	######
	#	Options describing the bank, the PSD, the output and the plots
	######
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--psd",  required = False,
	help="The input file for the PSD: it can be either a txt either a ligolw xml file. Only applies if full-match is True")
parser.add_argument(
	"--asd",  default = False, action='store_true',
	help="Whether the input file has an ASD (sqrt of the PSD)")
parser.add_argument(
	"--ifo", default = 'L1', type=str, choices = ['L1', 'H1', 'V1'],
	help="Interferometer name: it can be L1, H1, V1. This is a field for the xml PSD file")
parser.add_argument(
	"--plot", action='store_true',
	help="Whether to make some plots. They will be store in run-dir")
parser.add_argument(
	"--show", action='store_true',
	help="Whether to show the plots. (It authomatically sets the --plot option)")
parser.add_argument(
	"--verbose", action='store_true',
	help="Whether to print the output.")
parser.add_argument(
	"--run-dir", default = None,
	help="Output directory in which the output will be saved. If None, it will be the same folder of --bank-file. Unless explicitly stated, every input file will be understood to be in this run-dir")
parser.add_argument(
	"--f-min",  default = 10., type=float,
	help="Minium frequency for the scalar product")
parser.add_argument(
	"--f-max",  default = 1024., type=float,
	help="Maximum frequency for the scalar product")
parser.add_argument(
	"--mm", type = float,
	help="Minimum match for the bank - for plotting purposes")
parser.add_argument(
	"--approximant", default = 'IMRPhenomPv2',
	help="LAL approximant for the bank generation")

#Every command line token which is not recognized by the parser is treated as the path to an ini file
args, filenames = parser.parse_known_args()

	#Each ini file is handed, in the order given, to updates_args_from_ini together with the current namespace:
	#the namespace it returns replaces the current one
for ini_file in filenames:
	args = updates_args_from_ini(ini_file, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######

def _place_in_run_dir(path, run_dir):
	"""
	Returns `run_dir+path` if `path` is a string without any '/', i.e. a bare file name.
	Anything else (a string with a directory part, or a non-string such as None) is returned unchanged.
	`run_dir` is expected to end with '/'.
	"""
	if isinstance(path, str) and '/' not in path:
		return run_dir+path
	return path

var_handler = variable_handler()

	#The bank and its variable format are always required
if (args.bank_file is None) or (args.variable_format is None):
	raise ValueError("The arguments bank-file and variable_format must be set!")
	#The injections are either drawn from a tiling or read from a file: at least one of the two is needed
if (args.tiling_file is None) and (args.inj_file is None):
	raise ValueError("If the argument --tiling-file is not set, you must provide an injection file through argument --inj-file")

	#An explicit raise is used instead of an assert: asserts are stripped when python runs with -O
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	#Defining some convenience variables
full_match = args.full_match or args.full_symphony_match
metric_match = isinstance(args.tiling_file, str)

	#FIXME: understand whether you want the default --run-dir to be './'
	#The default run-dir is the folder of the bank file.
	#os.path.dirname is used rather than removing the basename with str.replace: the latter removes *every* occurrence
	#of the basename, hence it also mangles a folder which contains the file name (e.g. 'bank/bank' would give '/')
if args.run_dir is None: args.run_dir = os.path.dirname(args.bank_file)
if args.run_dir == '': args.run_dir = './'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
os.makedirs(args.run_dir, exist_ok = True)

	#setting default locations for injection file, tiling file and bank file:
	#a bare file name is understood to be inside run-dir
args.inj_file = _place_in_run_dir(args.inj_file, args.run_dir)
if isinstance(args.stat_dict, str):
	args.stat_dict = _place_in_run_dir(args.stat_dict, args.run_dir)
else:
		#default name for the output, built from the name of the bank file (up to its first '.')
	args.stat_dict = args.run_dir+os.path.basename(args.bank_file).split('.')[0]+'-injections_stat_dict.json'
args.flow_file = _place_in_run_dir(args.flow_file, args.run_dir)
args.tiling_file = _place_in_run_dir(args.tiling_file, args.run_dir)
args.bank_file = _place_in_run_dir(args.bank_file, args.run_dir)
args.psd = _place_in_run_dir(args.psd, args.run_dir)

	#--show implies --plot, as stated in the help message of --show
if args.show: args.plot = True

plot_folder = None
if args.plot: plot_folder = args.run_dir

	######
	#	Loading the PSD and creating an instance of the metric (only needed for the full match)
	######
if full_match:
		#load_PSD returns two objects, which are handed to the metric packed in a tuple
	freq_grid, psd_values = load_PSD(args.psd, args.asd, args.ifo)

		#the metric is used further below by the full match computation
	metric = cbc_metric(
		PSD = (freq_grid, psd_values),
		approx = args.approximant,
		f_min = args.f_min,
		f_max = args.f_max,
		variable_format = args.variable_format,
	)

	######
	#	Loading the bank
	######
bank = cbc_bank(args.variable_format, args.bank_file)
if args.verbose:
	print("Loaded bank {} with {} templates".format(args.bank_file, bank.templates.shape[0]))

	######
	#	Loading the tiling (and the flow, if a flow file is given)
	######
if metric_match:
	t_obj = tiling_handler(args.tiling_file)
	if isinstance(args.flow_file, str):
		t_obj.load_flow(args.flow_file)

	######
	#	Generating injections
	#	They are either drawn from the tiling (and saved to file) or loaded from an xml injection file
	######

if args.inj_file is None:

		#An explicit raise is used instead of an assert: asserts are stripped when python runs with -O
	if not isinstance(args.n_injs, int):
		raise ValueError("Argument --n-injs must be specified if an injection file is not given")

		#generating tiles injections
	projected_injs = t_obj.sample_from_tiling(args.n_injs, seed = args.seed)

		#FIXME: an option to use the templates of the bank as injections (projected_injs = bank.templates) is planned

	injections_full = var_handler.get_BBH_components(projected_injs, args.variable_format)
	del projected_injs #storing projected_injs is a waste of memory that can scale super badly

		#dealing with sky location
		#NOTE(review): with --fixed-sky-loc-polarization, sky_locs has shape (3,), while the random draw is stacked
		#along axis 1 (presumably one row per injection): confirm that save_injs and initialize_inj_stat_dict accept both
	if args.fixed_sky_loc_polarization:
		sky_locs = np.array(args.fixed_sky_loc_polarization)
	else:
		sky_locs = np.stack([*get_random_sky_loc(args.n_injs, args.seed)], axis = 1)

		#Storing injections to file
	inj_filename = args.run_dir+'injections.xml'
	save_injs(inj_filename, injections_full, 0, 10*args.n_injs, time_step = 10,
			approx = args.approximant,
			sky_locs = sky_locs, luminosity_distance = 100,
			f_min = args.f_min, f_max = args.f_max)
else:
		#loading injections from file
		#Only the xml format (possibly gzipped) is supported: json is recognized but not implemented yet
	if args.inj_file.endswith(('xml', 'xml.gz')):
		warnings.warn('When loading injections from an xml file, eccentricity and meanano are currently not supported')

		injections_full, sky_locs_xml = read_xml(args.inj_file, lsctables.SimInspiralTable, args.n_injs)
			#n_injs is overwritten with the number of injections actually read
		args.n_injs = injections_full.shape[0]
			#the sky_locs of the file are used only if args.fixed_sky_loc_polarization is not set
		sky_locs = np.array(args.fixed_sky_loc_polarization) if args.fixed_sky_loc_polarization else sky_locs_xml

	elif args.inj_file.endswith('json'):
		raise NotImplementedError("json format not implemented yet")
	else:
		raise ValueError("Format for the input file for the injections must be an xml")

	if args.verbose: print("Loaded {} injections from file: {}".format(args.n_injs, args.inj_file))

	#The sky locations are only passed on when a full match is computed
if not full_match: sky_locs = None

	######
	#	Computing the match of the injections with the metric approximation (if a tiling is given)
	######
injection_stat_dict = initialize_inj_stat_dict(injections_full, sky_locs)

if metric_match:
	injection_stat_dict = compute_injections_metric_match(
		injection_stat_dict, bank, t_obj, verbose = args.verbose)

	######
	#	Recomputing the match with the actual match and with the actual injections (non-projected)
	#	(if it's the case)
	######
if full_match:
		#The match function takes the stat_dict and returns an updated version of it;
		#it is either the ray-parallelized one or the serial one, both called with the same arguments
	match_function = ray_compute_injections_match if args.use_ray else compute_injections_match

		#The metric works with the format 'm1m2_fullspins_emeanano_iotaphi' during the match computation:
		#the variable format of the bank is set back right after
	metric.set_variable_format('m1m2_fullspins_emeanano_iotaphi')
	injection_stat_dict = match_function(
		injection_stat_dict, bank, metric,
		mchirp_window = args.mchirp_window,
		symphony_match = args.full_symphony_match,
		verbose = args.verbose)
	metric.set_variable_format(args.variable_format)

	######
	#	Writing to file
	######
save_inj_stat_dict(args.stat_dict, injection_stat_dict)

	######
	#	Making some plots & printing percentiles
	######

	#Gathering the matches
	#Non-positive metric matches are set to 0.01: this modifies in place the array stored in injection_stat_dict,
	#after the dict has already been written to file
matches_metric = injection_stat_dict['metric_match']
if metric_match: matches_metric[matches_metric<=0] = 0.01
matches = injection_stat_dict['match']

if args.verbose:
	if metric_match:
		print("Metric match percentiles [50, 5, 1]: ", np.percentile(matches_metric, [50, 5,1]))
	if full_match:
		print("Match percentiles [50, 5, 1]: ", np.percentile(matches, [50, 5,1]))

	#The full match, when computed, takes precedence over the metric match
plotted_matches = matches if full_match else matches_metric
if plotted_matches is None:
	warnings.warn("You chose not to compute neither the match nor the metric match: there's nothing to do here.")
		#sys.exit is used in place of quit: the latter is added by the site module for the interactive shell
		#and it is not guaranteed to exist (e.g. when python runs with -S)
	sys.exit()

if args.plot:

		###
		# Histogram
		###
	plot_match_histogram(matches_metric, matches,
			mm = args.mm, bank_name = os.path.basename(args.bank_file),
			save_folder = args.run_dir)

		###
		# Colored plots
		###
		#projecting over the template manifold: the projected injections are only needed for this plot
	projected_injs = var_handler.get_theta(injections_full, args.variable_format) #(N,D)

		#If a minimum match is given, only the injections with a match below it are plotted
	if isinstance(args.mm, (float, int)):
		ids_to_plot = np.where(plotted_matches<args.mm)[0]
	else:
		ids_to_plot = range(len(plotted_matches))

	plot_tiles_templates(bank.templates, args.variable_format, #t_obj,
				injections = projected_injs[ids_to_plot,:], inj_cmap =plotted_matches[ids_to_plot] ,
				dist_ellipse = None, save_folder = args.run_dir, show = False)
	if args.show: plt.show()
