#!/usr/bin/env python
"""
mbank_injections
----------------

A script to randomly draw injections inside a bank. Injections can also be read from an xml or json file.
It needs a tiling file so that the match can be computed with the metric approximation.

To perform injections:

	mbank_injections --options-you-like

You can also load (some) options from an ini-file:

	mbank_injections --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_injections --help
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches
import sys
import pickle
import json
import warnings

from ligo.lw import utils as lw_utils
from ligo.lw import ligolw
from ligo.lw import table as lw_table
from ligo.lw import lsctables

from mbank import variable_handler, cbc_metric, cbc_bank, tiling_handler
from mbank.utils import ray_compute_injections_match, compute_injections_match, compute_injections_metric_match, load_PSD, plot_tiles_templates
from mbank.utils import updates_args_from_ini, read_xml

from itertools import combinations, permutations

import argparse
import os

@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""LIGO_LW xml content handler, decorated with `lsctables.use_in`.

	Presumably (per the ligo.lw convention) the decorator makes the lsctables table
	classes be used when parsing documents with this handler.

	NOTE(review): this class is not referenced anywhere else in this script: xml
	injections are read through `mbank.utils.read_xml`. Also, `lsctables.use_in`
	may not be available in recent python-ligo-lw releases -- confirm against the
	installed version before relying on it.
	"""
	pass

##### Creating parser
	#The module docstring is the description shown by --help. NB: the first positional argument
	#of ArgumentParser is `prog` (the program name in the usage line), hence the keyword.
	#RawDescriptionHelpFormatter preserves the line breaks of the docstring.
parser = argparse.ArgumentParser(description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)
parser.add_argument(
	"--n-injs", type = int, default = None,
	help="Number of injections. If inj-file is specified, they will be read from it; otherwise they will be randomly drawn from the tiling. If None, all the injections will be read from inj-file and it will throw an error if such file is not provided.")
parser.add_argument(
	"--inj-file", type = str, default = None,
	help="An xml (or xml.gz) or json injection file to load the injections from. If not provided, n-injs injections will be drawn at random from the tiling. If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--bank-file", required = False, type = str,
	help="Path to the file of a bank. If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--tiling-file", required = False, type = str,
	help="The input file with a tiling. It must be generated by a tiling_handler object. If no path to the file is provided, it is understood it is located in run-dir.")
parser.add_argument(
	"--flow-file", required = False, type = str,
	help="An optional file with weights for a normalizing flow model. If provided it will be used to interpolate the metric when computing the metric injection match.")
parser.add_argument(
	"--match-threshold", type = float, default = 0.9,
	help="Threshold for the match. For each injection, the templates with metric match higher than match-threshold will be stored and their full match with the injection will be computed (if --full-match is specified)")
parser.add_argument(
	"--full-match", action='store_true', default = False,
	help="Whether to perform the full standard match computation. If False, a metric approximation to the match will be used")
parser.add_argument(
	"--full-match-symphony", action='store_true', default = False,
	help="Whether to perform the full symphony match computation. If False, a metric approximation to the match will be used")
parser.add_argument(
	"--seed", type = int, default = None,
	help="Random seed for extracting the random injections (if it applies)")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the match computation")
parser.add_argument(
	"--cache", action='store_true', default = False,
	help="Whether to cache the computation of the (full) match. It may fill the memory.")

###########
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--psd",  required = False,
	help="The input file for the PSD: it can be either a txt or a ligolw xml file. Only applies if full-match or full-match-symphony is set")
parser.add_argument(
	"--asd",  default = False, action='store_true',
	help="Whether the input file has an ASD (sqrt of the PSD)")
parser.add_argument(
	"--ifo", default = 'L1', type=str, choices = ['L1', 'H1', 'V1'],
	help="Interferometer name: it can be L1, H1, V1. This is a field for the xml PSD file")
parser.add_argument(
	"--plot", action='store_true',
	help="Whether to make some plots. They will be stored in run-dir")
parser.add_argument(
	"--show", action='store_true',
	help="Whether to show the plots. (It automatically sets the --plot option)")
parser.add_argument(
	"--run-dir", default = None,
	help="Output directory in which the output will be saved. If None, it will be the same folder of --bank-file. Unless explicitly stated, every input file will be understood to be in this run-dir")
parser.add_argument(
	"--f-min",  default = 10., type=float,
	help="Minimum frequency for the scalar product")
parser.add_argument(
	"--f-max",  default = 1024., type=float,
	help="Maximum frequency for the scalar product")
parser.add_argument(
	"--mm", type = float,
	help="Minimum match for the bank - for plotting purposes")
parser.add_argument(
	"--approximant", default = None,
	help="LAL approximant for the bank generation")

	#Arguments not known to the parser are understood as ini files
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case.
	#A leftover that looks like an option (e.g. a misspelled flag) is reported as an error,
	#rather than being handed over to updates_args_from_ini as if it were a file name
for f in filenames:
	if f.startswith('-'):
		parser.error("Unrecognized argument: {}".format(f))
	args = updates_args_from_ini(f, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######

var_handler = variable_handler()

if (args.bank_file is None) or (args.tiling_file is None) or (args.variable_format is None):
	raise ValueError("The arguments bank-file, tiling-file and variable-format must be set!")

	#An explicit check (rather than an assert) keeps working when python runs with -O
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	#FIXME: understand whether you want the default --run-dir to be './'
	#By default, run-dir is the folder of the bank file ('./' if the bank file has no folder).
	#NB: os.path.dirname strips only the file name, whereas str.replace of the basename would also
	#remove any folder sharing the same name (e.g. 'bank/bank' would have become '/')
if args.run_dir is None: args.run_dir = os.path.dirname(args.bank_file)
if args.run_dir == '': args.run_dir = './'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
os.makedirs(args.run_dir, exist_ok = True)

	#setting default locations for injection, flow, tiling, bank and PSD files:
	#a bare file name (i.e. without any '/') is understood to be located in run-dir
if isinstance(args.inj_file, str) and '/' not in args.inj_file: args.inj_file = args.run_dir+args.inj_file
if isinstance(args.flow_file, str) and '/' not in args.flow_file: args.flow_file = args.run_dir+args.flow_file
if '/' not in args.tiling_file: args.tiling_file = args.run_dir+args.tiling_file
if '/' not in args.bank_file: args.bank_file = args.run_dir+args.bank_file
if isinstance(args.psd, str) and '/' not in args.psd: args.psd = args.run_dir+args.psd

plot_folder = None
if args.plot: plot_folder = args.run_dir


	######
	#	Loading PSD, and creating an instance of metric (if it applies)
	######
if args.full_match or args.full_match_symphony: #in this case we need to load the PSD and instantiate the metric
		#the full match cannot be computed without a PSD: fail early with a clear message
	if args.psd is None:
		raise ValueError("The argument --psd must be set to compute the full match")
		#load PSD: load_PSD returns a pair (f, PSD) that is handed to the metric as a whole
	f, PSD = load_PSD(args.psd, args.asd, args.ifo)

		#create metric
	metric = cbc_metric(PSD = (f, PSD), approx = args.approximant,
			f_min = args.f_min, f_max = args.f_max,
			variable_format = args.variable_format)

	######
	#	Loading the bank and the tiling (plus the optional normalizing flow)
	######
bank = cbc_bank(args.variable_format, args.bank_file)
n_templates = bank.templates.shape[0]
print("Loaded bank {} with {} templates".format(args.bank_file, n_templates))

	#The tiling is used below to compute the metric match of the injections
t_obj = tiling_handler(args.tiling_file)
	#The flow is loaded on top of the tiling only if a flow file was given
if isinstance(args.flow_file, str):
	t_obj.load_flow(args.flow_file)

	######
	#	Generating injections
	#	injections_full holds one row per injection. For json files the rows have 12 entries:
	#	m1, m2, s1 (x,y,z), s2 (x,y,z), e, meanano, inclination, phi
	#	(presumably the same layout for the other sources -- TODO confirm)
	######

if args.inj_file is None:
		#drawing random injections from the tiling
	if not isinstance(args.n_injs, int):
		raise ValueError("Argument --n-injs must be specified if an injection file is not given")
	projected_injs = t_obj.sample_from_tiling(args.n_injs, seed = args.seed)

	#projected_injs = bank.templates; print("######### Bank injections!! ###########") #FIXME: this will become a bank-injs option!! Or something similar

	injections_full = np.array(var_handler.get_BBH_components(projected_injs, args.variable_format)).T
	del projected_injs #storing projected_injs is a waste of memory that can scale super badly
else:
		#loading injections from file: they are projected on the bank manifold later on
		#The injection file can be either xml (or xml.gz) or json
	if args.inj_file.endswith(('xml', 'xml.gz')):
		warnings.warn('When loading injections from an xml file, eccentricity and meanano are currently not supported')

		injections_full = read_xml(args.inj_file, lsctables.SimInspiralTable, args.n_injs)
		args.n_injs = injections_full.shape[0]

	elif args.inj_file.endswith('json'):
		warnings.warn('When loading injections from a json file, eccentricity and meanano are currently not supported')

		with open(args.inj_file) as f:
			inj_json = json.load(f)
		inj_items = inj_json['items']

			#reading at most n_injs injections (all of them if n_injs is None); never a negative number
		if args.n_injs is None or args.n_injs > len(inj_items): args.n_injs = len(inj_items)
		args.n_injs = max(args.n_injs, 0)

			#eccentricity, mean anomaly and phi are not read from the file and are set to 0
		injections_full = np.array([
				[row['m1'], row['m2'],
				row['s_1x'], row['s_1y'], row['s_1z'], row['s_2x'], row['s_2y'], row['s_2z'],
				0., 0., row['inclination'], 0.]
			for row in inj_items[:args.n_injs]]).reshape(-1, 12) #(N,12), also when N = 0
	else:
		raise ValueError("Format for the input file for the injections must be either xml or json")

	print("Loaded {} injections from file: {}".format(args.n_injs, args.inj_file))

	######
	#	Computing the match of every injection with the metric approximation
	#	Here injections are projected to the bank manifold
	######
injection_stat_dict = compute_injections_metric_match(injections_full, bank,  t_obj, match_threshold = args.match_threshold, verbose = True)

	######
	#	Recomputing the match with the actual match and with the actual injections (non-projected)
	#	(only if a full match was requested)
	######
if args.full_match or args.full_match_symphony:
		#templates expressed as BBH components, one row per template
		#(presumably (N,12), matching the 'm1m2_fullspins_emeanano_iotaphi' format set below -- TODO confirm)
	templates_full = np.array(var_handler.get_BBH_components(bank.templates, args.variable_format)).T
		#the ray version parallelizes the same computation
	match_function = ray_compute_injections_match if args.use_ray else compute_injections_match

		#the metric works in the full format while computing the match, and is then switched back
	metric.set_variable_format('m1m2_fullspins_emeanano_iotaphi')
	injection_stat_dict = match_function(injection_stat_dict, templates_full, metric,
			symphony_match = args.full_match_symphony, cache = args.cache)
	metric.set_variable_format(args.variable_format)

	######
	#	Writing to file
	######
	#The output is named after the bank file, dropping everything from its first '.' onwards
bank_stem = os.path.basename(args.bank_file).split('.')[0]
save_name = bank_stem+'-injections_stat_dict.pkl'

with open(os.path.join(args.run_dir, save_name), 'wb') as filehandler:
	pickle.dump(injection_stat_dict, filehandler)

	#To load the result back:
	#	with open(args.run_dir+save_name, 'rb') as filehandler:
	#		injection_stat_dict = pickle.load(filehandler)

	######
	#	Making some plots & printing percentiles
	######

	#metric match of each injection. Non-positive values are floored to 0.01 so that they fit
	#in the log-spaced histogram bins (NB: the floored values also enter the percentiles)
matches_metric = np.array(injection_stat_dict['metric_match'])
matches_metric[matches_metric<=0] = 0.01

	#the full match is only meaningful if it has been computed. It is converted to an array so
	#that it can be compared with --mm and fancy-indexed by ids_to_plot below
full_match = args.full_match or args.full_match_symphony
matches = injection_stat_dict['match']
if full_match: matches = np.asarray(matches)

print("Metric match percentiles [50, 5, 1]: ", np.percentile(matches_metric, [50, 5,1]))
if full_match:
	print("Match percentiles [50, 5, 1]: ", np.percentile(matches, [50, 5,1]))
plotted_matches = matches if full_match else matches_metric

	#as stated in its help, --show implies --plot
if args.plot or args.show:
		#projecting the injections over the template manifold (only needed for plotting)
	projected_injs = var_handler.get_theta(injections_full, args.variable_format) #(N,D)

	fs = 15
		###
		# Histogram
		###
	plt.figure(figsize = (15,15))
	plt.title("Injection recovery for bank {}".format(os.path.basename(args.bank_file)), fontsize = fs+10)
		#at least two bin edges (i.e. one bin) are needed, even with very few injections
	N_bins = max(int(np.sqrt(len(matches_metric))), 2)
	plt.gca().tick_params(axis='x', labelsize=fs)
	plt.gca().tick_params(axis='y', labelsize=fs)
	logbins = np.logspace(np.log10(np.percentile(matches_metric, .5)),np.log10(max(matches_metric)), N_bins)
	plt.hist(matches_metric, bins = logbins, density = True,
				color = 'blue',	label = 'metric match',
				histtype = 'step' if full_match else 'bar')
	if full_match:
		logbins = np.logspace(np.log10(np.percentile(matches, .5)),np.log10(max(matches)), N_bins)
		plt.hist(matches, bins = logbins, histtype='bar',
				density = True, color = 'orange', label = 'match')

		#vertical line at the minimum match of the bank (if given)
	if isinstance(args.mm, (float, int)): plt.axvline(x = args.mm, c = 'r')
	plt.legend(fontsize = fs+3)

	plt.savefig(args.run_dir+'FF_hist.png', transparent = False)

		#if a minimum match is given, only the injections below it are shown in the colored plot
	if isinstance(args.mm, (float, int)):
		ids_to_plot = np.where(plotted_matches<args.mm)[0]
	else:
		ids_to_plot = np.arange(len(plotted_matches))

		###
		# Colored plots
		###
	plot_tiles_templates(bank.templates, args.variable_format,
				injections = projected_injs[ids_to_plot,:], inj_cmap = plotted_matches[ids_to_plot],
				dist_ellipse = None, save_folder = args.run_dir, show = False)
	if args.show: plt.show()

