#!/usr/bin/env python
"""
mbank_place_templates
---------------------

A script to place the templates given a tiling file.

To generate a bank:

	mbank_place_templates --options-you-like
	
You can also load (some) options from an ini-file:

	mbank_place_templates --some-options other_options.ini

Make sure that mbank is properly installed.
To know which options are available:

	mbank_place_templates --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys

from mbank import cbc_bank, tiling_handler, variable_handler
from mbank.utils import plot_tiles_templates, parse_from_file, avg_dist
from mbank.utils import updates_args_from_ini, int_tuple_type
from mbank.flow.utils import compare_probability_distribution

import argparse
import os

#TODO: add f_max in save_bank

##### Creating parser
	# The module docstring is passed as `description`: the first positional argument of
	# ArgumentParser is `prog`, which would otherwise turn the whole docstring into the
	# program name shown in the usage line. RawDescriptionHelpFormatter keeps the
	# docstring layout (indented example commands) intact in --help.
parser = argparse.ArgumentParser(description = __doc__, formatter_class = argparse.RawDescriptionHelpFormatter)
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--tiling-file", required = False, type = str,
	help="The input file with a tiling. It must be generated by a tiling_handler object. If no path to the file is provided, it is understood it is located in run-dir")
parser.add_argument(
	"--mm", required = False, type = float,
	help="Minimum match for the bank (a.k.a. average distance between templates)")
parser.add_argument(
	"--plot", action='store_true',
	help="Whether to plot the bank. Plot will be saved in run-dir")
parser.add_argument(
	"--run-dir", default = None,
	help="Output directory in which the bank will be saved. If not given, the bank will be saved to the same directory as the tiling file")
parser.add_argument(
	"--run-name", default = 'cbc_mbank',
	help="Name for the bank output file")
parser.add_argument(
	"--placing-method", default = 'geometric', type = str, choices = cbc_bank('Mq_nonspinning').placing_methods,
	help="Which placing method to use for each tile")
parser.add_argument(
	"--n-livepoints", default = 10000, type = float,
	help="Parameter to control the number of livepoints to use in the `random` and `pruning` placing methods. For `random` (or related), it represents the number of livepoints to use for the estimation of the coverage fraction. For `pruning`, it amounts to the ratio between the number of livepoints and the number of templates placed by the ``uniform`` placing method.")
parser.add_argument(
	"--covering-fraction", default = 0.01, type = float,
	help="Parameter to control the maximum fraction of livepoints alive before terminating the bank generation with the `random` and `pruning` placing methods. The smaller the threshold, the higher the number of templates in the final bank.")
parser.add_argument(
	"--empty-iterations", default = 100, type = float,
	help="Number of consecutive rejected proposals after which the `stochastic` placing method stops.")
parser.add_argument(
	"--use-ray", action='store_true', default = False,
	help="Whether to use ray package to parallelize the metric computation")
parser.add_argument(
	"--show", action='store_true', default = False,
	help="Whether to show the plots")
parser.add_argument(
	"--f-max",  default = 1024., type=float,
	help="Final frequency for the templates stored in the bank (only for the xml format)")

parser.add_argument(
	"--flow-file", required = False, type = str,
	help="An optional file with weights for a normalizing flow model. If provided, it will be used to interpolate the metric during the template placing. If --train-flow is set, the flow model will be trained before the template placing and its weights will be saved to this file (default: run-dir/flow_<run-name>.zip). If no path to the file is provided, it is understood it is located in run-dir")
parser.add_argument(
	"--train-flow", action='store_true', default = False,
	help="Whether to train a normalizing flow model for the tiling. It will be used for metric interpolation")
parser.add_argument(
	"--n-layers", default = 2, type = int,
	help="Number of layers for the flow model to train (applicable only if --train-flow)")
parser.add_argument(
	"--hidden-features", default = 4, type = int,
	help="Number of hidden features for the masked autoregressive flow to train. (applicable only if --train-flow)")
parser.add_argument(
	"--n-epochs", default = 1000, type = int,
	help="Number of training epochs for the flow (applicable only if --train-flow)")

	# Every command-line token not recognized by the parser is treated as the path
	# of an ini file from which further options are read
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for f in filenames:
	args = updates_args_from_ini(f, args, parser)

##################################################
	######
	#	Interpreting the parser and initializing variables
	######
if (args.mm is None) or (args.variable_format is None) or (args.tiling_file is None):
	raise ValueError("The arguments --mm, --variable-format and --tiling-file must be set!")

	# Explicit check instead of `assert`: assertions are stripped when Python runs with -O
if args.variable_format not in variable_handler().valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

	# Default run-dir is the directory holding the tiling file ('./' if it has no directory part).
	# os.path.dirname is used rather than removing the basename with str.replace, which would
	# also strip any other occurrence of the basename inside the directory part of the path.
if args.run_dir is None: args.run_dir = os.path.dirname(args.tiling_file)
if args.run_dir == '': args.run_dir = './'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
os.makedirs(args.run_dir, exist_ok = True)

	# A bare file name (no '/') is understood to be located in run-dir
if '/' not in args.tiling_file: args.tiling_file = args.run_dir+args.tiling_file

if isinstance(args.flow_file, str):
	if '/' not in args.flow_file: args.flow_file = args.run_dir+args.flow_file
elif args.train_flow:
		#giving a default name in case the --flow-file is not given but the training is required
	args.flow_file = args.run_dir+'flow_{}.zip'.format(args.run_name)

	# Plots are saved to run-dir only when --plot is set
plot_folder = None
if args.plot: plot_folder = args.run_dir

	######
	#	Generating objs and tiling
	######

bank = cbc_bank(args.variable_format)

	# The tiling is loaded from file exactly once
t_obj = tiling_handler(args.tiling_file)

	#Loading the flow, if it's the case:
	#	--train-flow: train a new flow on the tiling and save its weights to args.flow_file
	#	otherwise, if a flow file is given: load the pre-trained weights from it
if args.train_flow:
	t_obj.train_flow(N_epochs = args.n_epochs, n_layers = args.n_layers, hidden_features = args.hidden_features, verbose = True)
	t_obj.flow.save_weigths(args.flow_file)
elif isinstance(args.flow_file, str): t_obj.load_flow(args.flow_file)

	# Place the templates tile by tile with the chosen method; the livepoint/coverage/iteration
	# options are forwarded as-is (their relevance depends on the placing method)
bank.place_templates(t_obj, args.mm, placing_method = args.placing_method, N_livepoints = args.n_livepoints, covering_fraction = args.covering_fraction, empty_iterations = args.empty_iterations, verbose = True)

	######
	#	Plotting & saving
	######

n_templates = len(bank.templates)
print("Generated bank with {} templates".format(n_templates))
print("Saving bank to {}".format(args.run_dir))

	# Same bank written twice: a plain-text .dat file and a compressed xml (which also takes f_max)
bank_stem = args.run_dir + 'bank_{}'.format(args.run_name)
bank.save_bank(bank_stem + '.dat')
bank.save_bank(bank_stem + '.xml.gz', args.f_max)

if args.plot:
		# No ellipses are drawn around the templates. A possible value would be
		# avg_dist(args.mm, bank.D) for 2D banks (currently disabled).
	plot_tiles_templates(bank.templates, args.variable_format,
		dist_ellipse = None, save_folder = plot_folder, show = args.show)

		# If a flow is attached to the tiling, compare 5000 flow samples with 5000 tiling samples
	if t_obj.flow:
		flow_samples = t_obj.flow.sample(5000).detach().numpy()
		tiling_samples = t_obj.sample_from_tiling(5000)
		compare_probability_distribution(flow_samples, data_true = tiling_samples,
			variable_format = args.variable_format,
			title = None, hue_labels = ['flow', 'tiling'],
			savefile = '{}/flow.png'.format(plot_folder), show = args.show)



















	
	
	
