#!/usr/bin/env python
"""
mbank_place_templates
---------------------

A script to place the templates.
It operates in two modes:
	- If a tiling file is given, the templates are placed according to the tiling file (plus an optional flow).
	- If only a flow is given, the templates are placed according to the flow. In this case a dataset needs to be provided as a source of livepoints.

To generate a bank:

	mbank_place_templates --options-you-like
	
You can also load (some) options from an ini-file:

	mbank_place_templates --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_place_templates --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys

from mbank import cbc_bank, tiling_handler, variable_handler, cbc_metric
from mbank.utils import plot_tiles_templates, avg_dist, load_PSD
from mbank.placement import place_random_flow, place_stochastically_flow, place_geometric_flow
from mbank.flow.utils import compare_probability_distribution
from mbank.flow import STD_GW_Flow
from mbank.parser import make_sub_file
import mbank.parser

import argparse
import os

##### Creating parser
# The module docstring is passed as the *description*. The first positional
# argument of argparse.ArgumentParser is `prog`, so passing __doc__ positionally
# would print the whole docstring as the program name in the usage line and in
# every error message. The raw formatter keeps the docstring's layout intact.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

# Option groups shared with the other mbank executables
mbank.parser.add_general_options(parser)
mbank.parser.add_metric_options(parser)
mbank.parser.add_template_placement_options(parser)
mbank.parser.add_flow_options(parser)
mbank.parser.add_range_options(parser) #only for the flow

# Options specific to this script
parser.add_argument(
	"--bank-file", required = False, type = str,
	help="Path to the file to save the bank. If no path to the file is provided, it is understood it is located in run-dir. If not given, a suitable default name will be set.")
parser.add_argument(
	"--seed-bank", required = False, type = str,
	help="Path to a seed bank: it will be used as a starting point for the stochastic method")
parser.add_argument(
	"--tiling-file", required = False, type = str,
	help="The input file with a tiling. It must be generated by a tiling_handler object. If no path to the file is provided, it is understood it is located in run-dir")
parser.add_argument(
	"--train-flow", action='store_true', default = False,
	help="Whether to train a normalizing flow model for the tiling. It will be used for metric interpolation (only applicable if a tiling is provided)")
parser.add_argument(
	"--dry-run", action = 'store_true', default = False,
	help="If set, it will only compute the number of templates, without actually placing the templates. Useful for volume studies. (only applicable for volume estimation purposes)")
parser.add_argument(
	"--make-sub",  default = False, action='store_true',
	help="If set, it will make a condor submit file that the user can use to launch the job through condor.")

# Any argument the parser does not recognise is treated as an ini file;
# each one is used to update the parsed options, in command-line order.
args, filenames = parser.parse_known_args()

for f in filenames:
	args = mbank.parser.updates_args_from_ini(f, args, parser)

##################################################
######
#	Interpreting the parser and initializing variables
######

# Input validation. Explicit exceptions are used instead of `assert`,
# because asserts are stripped when python runs with -O.
if (args.mm is None) or (args.variable_format is None):
	raise ValueError("The arguments --mm, --variable-format must be set!")
if not (args.tiling_file or args.flow_file):
	raise ValueError("At least one argument between --tiling-file and --flow-file must be set!")
if not (0 < args.mm < 1):
	raise ValueError("The minimum match must be lower than 1 and greater than 0")

var_handler = mbank.variable_handler()
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))
D = var_handler.D(args.variable_format)

# Default run dir: './out_<run_name>/'. If the tiling file is given with a
# folder, use that folder instead. Only when no tiling file is given, use the
# folder of the flow file.
if not args.run_dir:
	args.run_dir = './out_{}/'.format(args.run_name)
	input_file = args.tiling_file or args.flow_file
	if input_file and ('/' in input_file):
		# os.path.dirname, unlike str.replace(basename, ''), does not also strip
		# directory components that contain the file name (e.g. 'tiling/tiling')
		args.run_dir = os.path.dirname(input_file)

if args.run_dir == '': args.run_dir = './'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
os.makedirs(args.run_dir, exist_ok = True)

# File names given without any folder are understood to live in run_dir
if args.tiling_file and ('/' not in args.tiling_file):
	args.tiling_file = args.run_dir+args.tiling_file

if args.flow_file:
	if '/' not in args.flow_file: args.flow_file = args.run_dir+args.flow_file
elif args.train_flow:
	# Default name when training is requested but --flow-file is not given
	args.flow_file = args.run_dir+'flow_{}.zip'.format(args.run_name)

if not args.bank_file:
	args.bank_file = args.run_dir+'bank_{}.xml.gz'.format(args.run_name)
elif '/' not in args.bank_file:
	args.bank_file = args.run_dir+args.bank_file

# With --make-sub, only write the condor submit file and stop
if args.make_sub:
	make_sub_file(args.run_dir, args.run_name)
	sys.exit()

plot_folder = args.run_dir if args.plot else None

if not args.dry_run: mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

	######
	#	Generating objs and tiling
	######

bank = cbc_bank(args.variable_format)

if args.seed_bank:
	if args.seed_bank.find('/') <0: args.seed_bank = args.run_dir+args.seed_bank
	seed_bank = cbc_bank(args.variable_format, args.seed_bank)
	if args.verbose: print('Loading seed bank from: ', args.seed_bank)

if args.tiling_file:
	t_obj = tiling_handler(args.tiling_file)

		#Loading the flow, if it's the case
	t_obj = tiling_handler(args.tiling_file)

	if args.train_flow:
		t_obj.train_flow(N_epochs = args.n_epochs, n_layers = args.n_layers, hidden_features = args.hidden_features, verbose = True)
		t_obj.flow.save_weigths(args.flow_file)
	elif args.flow_file: t_obj.load_flow(args.flow_file)

	bank.place_templates(t_obj, args.mm, placing_method = args.placing_method, N_livepoints = args.n_livepoints, covering_fraction = args.covering_fraction, empty_iterations = args.empty_iterations, verbose = True)
	
	if args.dry_run:
		print("Placed {} templates".format(len(bank.templates)))
		quit()

else:

	#TODO: this part here should be re-arranged in such a way that MC estimation of the livepoints makes sense
	#You will need sample the flow for some livepoints and evaluate the metric for each of them
	#Mybe you can also do importance sampling to estimate the covering fraction more seriously... :D

		#Loading boundaries
	bk = mbank.parser.boundary_keeper(args)
	def boundaries_checker(theta):
		return bk(theta, args.variable_format)

		#Loading metric
	m = cbc_metric(args.variable_format,
		PSD = load_PSD(args.psd, args.asd, args.ifo, df = args.df),
		approx = args.approximant,
		f_min = args.f_min, f_max = args.f_max)

	flow = STD_GW_Flow.load_flow(args.flow_file)

	if args.placing_method == 'random':
		new_templates = place_random_flow(args.mm, flow, m, args.n_livepoints, boundaries_checker, covering_fraction = args.covering_fraction, dry_run = args.dry_run, importance_sampling = False, metric_type = args.metric_type, verbose = args.verbose)
	elif args.placing_method == 'stochastic':
		new_templates = place_stochastically_flow(args.mm, flow, m, boundaries_checker, empty_iterations = args.empty_iterations,
			seed_bank = seed_bank.templates if args.seed_bank else None,
			dry_run = args.dry_run, verbose = args.verbose)
	elif args.placing_method == 'geometric' or args.placing_method == 'qmc':
		new_templates = place_geometric_flow(args.mm, flow, m, args.n_livepoints, boundaries_checker,
			covering_fraction = args.covering_fraction, qmc = (args.placing_method == 'qmc'), dry_run = args.dry_run, verbose = args.verbose)
	else:
		raise ValueError("Input method {} not implemented for the normalizing flow model".format(args.placing_method))
	
	if args.dry_run:
		print("Placed {} templates".format(new_templates))
		quit()
	bank.add_templates(new_templates)


	######
	#	Plotting & saving
	######

print("Generated bank with {} templates".format(len(bank.templates)))
print("Saving bank to {}".format(args.bank_file))
bank.save_bank(args.bank_file, args.f_max)

if args.plot:
	dist = None #avg_dist(args.mm, bank.D) if bank.D == 2 else None
	plot_tiles_templates(bank.templates, args.variable_format,
			dist_ellipse = dist, save_folder = plot_folder, show = args.show)

	if args.tiling_file:
		if t_obj.flow:
			compare_probability_distribution(t_obj.flow.sample(5000).detach().numpy(), data_true = t_obj.sample_from_tiling(5000),
				variable_format = args.variable_format,
				title = None, hue_labels = ['flow', 'tiling'],
				savefile = '{}/flow.png'.format(plot_folder), show = args.show)



















	
	
	
