#!/usr/bin/env python
"""
mbank_merge
-----------

It merges the mbank output products: banks, tilings or injection statistics.
Banks must be saved in dat or xml format.
Tiling objects must be saved in the npy format, as produced by `tiling_handler.save()`.
The injection statistics are dictionaries saved in json or pkl format, as produced by `mbank_injections`.

To merge 3 banks into a single file merged_banks.xml, you can type:

	mbank_merge --bank --variable-format your-format --out-name merged_banks.xml bank_1.xml bank_2.xml bank_3.xml

To merge 3 tilings, you can type:

	mbank_merge --tiling --out-name merged_tiling.npy tiling_1.npy tiling_2.npy tiling_3.npy

To merge 3 injection statistics files (either json or pkl) you can type:

	mbank_merge --injection-stat --out-name merged_injections_stat.json stat_1.pkl stat_2.json stat_3.pkl

Make sure that mbank is properly installed.
To know which options are available:

	mbank_merge --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings

from mbank import cbc_bank, tiling_handler, variable_handler
from mbank.utils import plot_match_histogram, plot_tiles_templates, save_inj_stat_dict, load_inj_stat_dict

import argparse
import os

##### Creating parser
# The module docstring is the --help description (the first positional argument of
# ArgumentParser is `prog`, which would turn the whole docstring into the program name).
# RawDescriptionHelpFormatter keeps the line breaks of the usage examples.
parser = argparse.ArgumentParser(description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)
parser.add_argument(
	"--tiling",  default = False, action='store_true',
	help="Whether the files to be merged are tiling files")
parser.add_argument(
	"--bank",  default = False, action='store_true',
	help="Whether the files to be merged are bank files")
parser.add_argument(
	"--injection-stat",  default = False, action='store_true',
	help="Whether the files to be merged are injection statistics files, generated by 'mbank_injections'")
parser.add_argument(
	"--variable-format", default = None, type = str,
	help="Variable format for the bank files to merge")
parser.add_argument(
	"--out-name", type = str,
	help="Name of the file to store the merged inputs")
parser.add_argument(
	"--plot",  default = False, action='store_true',
	help="Whether to produce some plots (only when merging injection statistics). They will be stored in the same folder as the --out-name file.")

# Every command line token that is not a recognised option is taken as an input file to merge
args, filenames = parser.parse_known_args()

######################################################
######################################################
######################################################

# Validate the command line. parser.error (print usage and exit) is used instead of assert:
# asserts are stripped when python runs with -O, which would let invalid options through
if sum([args.tiling, args.bank, args.injection_stat]) != 1:
	parser.error("You must specify exactly one of the options '--bank', '--tiling', '--injection-stat'")
# With no input files there is nothing to merge: stop rather than writing an empty output
if not filenames:
	parser.error("No input files given: please provide the files to merge")

# Set the default output name and create the (empty) object that will hold the merged inputs
if args.bank:
	if args.variable_format is None:
		parser.error("If --bank is set, a variable format must be set")
	if args.out_name is None: args.out_name = 'merged_bank.dat'
	input_file_type = 'bank'
	bank = cbc_bank(args.variable_format)
elif args.tiling:
	if args.out_name is None: args.out_name = 'merged_tiling.npy'
	input_file_type = 'tiling'
	tiling = tiling_handler()
elif args.injection_stat:
	if args.out_name is None: args.out_name = 'merged_injections_stat_dict.json'
	input_file_type = 'injection statistics'
	stat_dict_list = []

print("Saving merged {} to file {}".format(input_file_type, args.out_name))

##
# Loop over the given input files, merging each of them into the object created above

for f in filenames:
	if args.bank:
		temp_bank = cbc_bank(args.variable_format, f)
		bank.add_templates(temp_bank.templates)
		del temp_bank #drop the per-file bank before loading the next file
		print("After adding '{}' bank has {} templates".format(f, len(bank.templates)))
	if args.tiling:
		temp_tiling = tiling_handler(f)
		tiling.extend(temp_tiling)
		del temp_tiling
		#The tiling is list-like (see .extend): len() counts its elements, i.e. the tiles, not templates
		print("After adding '{}' merged tiling has {} tiles".format(f, len(tiling)))
	if args.injection_stat:
		#The dictionaries are merged together only after all of them are loaded
		stat_dict_list.append(load_inj_stat_dict(f))

# Folder of the output file: the plots, if requested, are stored there
save_dir = os.path.dirname(args.out_name)

# Banks and tilings are written straight away; injection statistics must be merged first
if args.bank:
	bank.save_bank(args.out_name)
if args.tiling:
	tiling.save(args.out_name)

def _inj_stat_values_equal(a, b):
	"""
	Checks whether two non-array entries (scalars, strings or None) of injection statistics dictionaries agree.

	Parameters
	----------
		a, b: object
			Values stored under the same key in two different dictionaries

	Returns
	-------
		equal: bool
			True if both are None, or if both are 0-dimensional and compare equal.
			False otherwise, including when only one of them is None or when one of them is an array.
	"""
	if a is None or b is None:
		return a is b
	if np.ndim(a) != 0 or np.ndim(b) != 0:
		return False
	return bool(a == b)

def _merge_inj_stat_dicts(stat_dict_list):
	"""
	Merges a list of injection statistics dictionaries into a single dictionary.

	For every key:
	- 0-dimensional values (numbers, strings) and None are treated as metadata: they must be identical in all the dictionaries and are stored once
	- a 1D 'sky_loc' (presumably a single sky location shared by all the injections) must be the same in all the dictionaries and is stored once
	- any other array is concatenated along the first axis, in the order given by `stat_dict_list`
	A key is stored the first time it is met, so a key missing from some dictionaries does not raise an error.

	Parameters
	----------
		stat_dict_list: list
			List of dictionaries, as returned by `load_inj_stat_dict`

	Returns
	-------
		merged_stat_dict: dict
			The merged dictionary

	Raises
	------
		ValueError
			If the metadata or the sky locations of the dictionaries are not compatible
	"""
	merged_stat_dict = {}
	for stat_dict in stat_dict_list:
		for k, v in stat_dict.items():
			if k not in merged_stat_dict:
				merged_stat_dict[k] = v
				continue
			old_v = merged_stat_dict[k]
			#Metadata must agree. This is checked before anything accesses array attributes, since v may be None.
			#np.ndim also catches numpy scalars (e.g. np.int64, not an int subclass) and strings, which cannot be concatenated.
			if v is None or old_v is None or np.ndim(v) == 0:
				if not _inj_stat_values_equal(v, old_v):
					raise ValueError("The given injection statistics dictionary are not compatible: are the metadata the same?")
				continue
			#A 1D sky location must be the same everywhere and is kept only once: concatenating it would duplicate it
			#and break the comparison with the next dictionary (shape mismatch)
			if k == 'sky_loc' and np.ndim(v) == 1:
				if np.shape(v) != np.shape(old_v) or not np.allclose(v, old_v):
					raise ValueError("The sky locations of the injection statistics dictionaries are not compatible")
				continue
			merged_stat_dict[k] = np.concatenate([old_v, v], axis = 0)
	return merged_stat_dict

if args.injection_stat:
	merged_stat_dict = _merge_inj_stat_dicts(stat_dict_list)
	save_inj_stat_dict(args.out_name, merged_stat_dict)


	#Making plots (stored in save_dir, the folder of the output file)
	if args.plot:

		plot_match_histogram(merged_stat_dict['metric_match'], matches = merged_stat_dict['match'],
			mm = None, bank_name = None,
			save_folder = save_dir)

		if args.variable_format:
			injs = merged_stat_dict['theta_inj']
			vh = variable_handler()
			#Color the injections by the actual match, falling back to the metric match when it is missing
			plotted_matches = merged_stat_dict['metric_match'] if merged_stat_dict['match'] is None else merged_stat_dict['match']

			theta_injs = vh.get_theta(injs, args.variable_format)
			plot_tiles_templates(theta_injs, args.variable_format,
				injections = theta_injs, inj_cmap =plotted_matches,
				dist_ellipse = None, save_folder = save_dir, show = False)
		plt.show()









