#!/usr/bin/env python
"""
mbank_run
---------

A script to run mbank for generating a bank

To generate a bank:

	mbank_run --options-you-like

You can also load (some) options from an ini-file:

	mbank_run --some-options other_options.ini

Make sure that mbank is properly installed.

To create a sub file and run in condor:

	mbank_run --make-sub options.ini

To know which options are available:

	mbank_run --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings
from ligo.lw.utils import load_filename

from mbank import variable_handler, cbc_metric, cbc_bank
from mbank.parser import get_boundary_box_from_args, make_sub_file
from mbank.utils import load_PSD, avg_dist, plot_tiles_templates, get_boundaries_from_ranges
from mbank.flow.utils import compare_probability_distribution
import mbank.parser

import argparse
import os

##### Creating parser
	#The first positional argument of ArgumentParser is `prog`, not the description:
	#the module docstring must be given as `description`, so that the usage line shows the script name.
	#RawDescriptionHelpFormatter keeps the line breaks and the indentation of the docstring in --help.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

	#Options shared with the other mbank executables, added in the same order as before
for add_options in (
		mbank.parser.add_general_options,
		mbank.parser.add_metric_options,
		mbank.parser.add_range_options,
		mbank.parser.add_tiling_generation_options,
		mbank.parser.add_template_placement_options,
		mbank.parser.add_flow_options):
	add_options(parser)

	#Options specific to this script
parser.add_argument(
	"--bank-file", required = False, type = str,
	help="Path to the file to save the bank. If no path to the file is provided, it is understood it is located in run-dir. If not given, a suitable default name will be set.")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the code execution")
parser.add_argument(
	"--train-flow", action='store_true', default = False,
	help="Whether to train a normalizing flow model for the tiling. It will be used for metric interpolation")
parser.add_argument(
	"--make-sub",  default = False, action='store_true',
	help="If set, it will make a condor submit file that the user can use to launch the job through condor.")

	#Anything which is not a known option is treated as the name of an ini file
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for ini_file in filenames:
	args = mbank.parser.updates_args_from_ini(ini_file, args, parser)

####################################################################################################
	######
	#	Interpreting the parser and initializing variables
	######

	#These options have no usable default: they must come from the command line or from an ini file
if (args.psd is None) or (args.mm is None) or (args.variable_format is None):
	raise ValueError("The arguments mm, psd and variable_format must be set!")

var_handler = variable_handler()
	#Explicit raise rather than assert: user input must be validated also when python runs with -O
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	#Resolving the placeholder default of run-dir and making sure it ends with '/',
	#since file names are appended to it by plain string concatenation
if args.run_dir == './out_$(run_name)':
	args.run_dir = './out_{}/'.format(args.run_name)
if not args.run_dir.endswith('/'):
	args.run_dir = args.run_dir+'/'
	#exist_ok avoids the race between checking for the folder and creating it
os.makedirs(args.run_dir, exist_ok = True)

	#A PSD file given without any folder is understood to be located in run-dir
if '/' not in args.psd:
	args.psd = args.run_dir+args.psd

	#Default grid: a single cell along each of the dimensions of the variable format
if args.grid_size is None:
	args.grid_size = (1,)*var_handler.D(args.variable_format)

	#A bank file given without any folder is understood to be located in run-dir
if args.bank_file:
	if '/' not in args.bank_file:
		args.bank_file = args.run_dir+args.bank_file
else:
	args.bank_file = args.run_dir+'bank_{}.xml.gz'.format(args.run_name)

	#With --make-sub, only the condor submit file is created and nothing else is run
if args.make_sub:
	make_sub_file(args.run_dir, args.run_name)
	sys.exit() #sys.exit is always available, unlike the quit() helper installed by the site module

	#Plots (if any) are saved in run-dir
plot_folder = args.run_dir if args.plot else None

	#Keeping track of the options used for this run
mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

format_info = var_handler.format_info[args.variable_format]

	######
	#	Boundaries of the parameter space: shape (2,D)
	######
boundaries = get_boundary_box_from_args(args)

	######
	#	Reading the PSD from file and building the metric object
	######
loaded_psd = load_PSD(args.psd, args.asd, args.ifo, df = args.df)
m = cbc_metric(
	args.variable_format,
	PSD = loaded_psd,
	approx = args.approximant,
	f_min = args.f_min,
	f_max = args.f_max,
)

print("## Running: ", args.run_name)

	######
	#	Running and saving the output
	######
bank = cbc_bank(args.variable_format)

if format_info['mass_format'] == 'm1m2':
	raise NotImplementedError("Currently no template placement is implemented for the mass format m1m2.")

	#Generating the bank: the templates are stored in `bank`, the returned object is the tiling
t_obj = bank.generate_bank(m, minimum_match = args.mm, boundaries = boundaries,
	max_depth = args.max_depth, tolerance = args.tile_tolerance,
	placing_method = args.placing_method, metric_type = args.metric_type,
	grid_list = args.grid_size, train_flow = args.train_flow,
	N_livepoints = args.n_livepoints, covering_fraction = args.covering_fraction, empty_iterations = args.empty_iterations,
	use_ray = args.use_ray, n_layers = args.n_layers, hidden_features = args.hidden_features, N_epochs = args.n_epochs)

	#The tiling is always saved; a flow file is given only if a flow was requested with --train-flow
tiling_file = args.run_dir+'tiling_{}.npy'.format(args.run_name)
flow_file = args.run_dir+'flow_{}.zip'.format(args.run_name) if args.train_flow else None
t_obj.save(tiling_file, flow_file)

if bank.templates is None:
	print("No templates were added to the bank: the bank is not saved")
	sys.exit() #sys.exit is always available, unlike the quit() helper installed by the site module

print("Generated bank with {} templates and {} tiles".format(len(bank.templates), len(t_obj)))
	#Printing the actual output file: --bank-file may point outside run-dir
print("Saving bank to {}".format(args.bank_file))
bank.save_bank(args.bank_file, args.f_max)

if args.plot:
		#No distance ellipse is drawn around the templates
	plot_tiles_templates(bank.templates, args.variable_format,
		dist_ellipse = None, save_folder = plot_folder, show = args.show)

		#If the tiling holds a flow, comparing samples from the flow with samples from the tiling
	if t_obj.flow:
		flow_samples = t_obj.flow.sample(5000).detach().numpy()
		tiling_samples = t_obj.sample_from_tiling(5000)
		compare_probability_distribution(flow_samples, data_true = tiling_samples,
			variable_format = args.variable_format,
			title = None, hue_labels = ['flow', 'tiling'],
			savefile = os.path.join(plot_folder, 'flow.png'), show = args.show)
		
	













