#!/usr/bin/env python
"""
mbank_train_flow
----------------

Executable to train a normalizing flow to reproduce the volume element of the probability distribution:

.. math::

	p(\\theta) \\propto \\sqrt{|M(\\theta)|}

The script trains a normalizing flow on a dataset of random points drawn within the parameter space.
A dataset can be generated with mbank_generate_flow_dataset. Multiple datasets are supported.

To train a flow:

	mbank_train_flow --options-you-like

You can also load (some) options from an ini-file:

	mbank_train_flow --some-options other_options.ini

Make sure that mbank is properly installed.

To know which options are available:

	mbank_train_flow --help
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings
from ligo.lw.utils import load_filename

import torch
from torch import optim

from mbank import variable_handler, cbc_metric
from mbank.utils import load_PSD, avg_dist, plot_tiles_templates, get_boundaries_from_ranges, plot_colormap
from mbank.parser import get_boundary_box_from_args, boundary_keeper, make_sub_file
import mbank.parser

from mbank.flow import STD_GW_Flow, GW_Flow
from mbank.flow.utils import early_stopper, plot_loss_functions

from tqdm import tqdm
import argparse
import os
from lal import MTSUN_SI

############################################################################
##### Creating parser
# Build the command line interface.
# The module docstring is given as `description`: the first positional argument of
# ArgumentParser is `prog`, so passing the docstring positionally would turn the whole text
# into the program name shown in usage and error messages.
# RawDescriptionHelpFormatter keeps the line breaks of the docstring in `--help`.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

# Option groups shared among the mbank executables (defined in mbank.parser)
mbank.parser.add_general_options(parser)
mbank.parser.add_range_options(parser)
mbank.parser.add_flow_options(parser)

# Options specific to this program
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--n-datapoints", default = None, type = int,
	help="Number of datapoints to be loaded from the dataset")
# NOTE(review): the help below states that a missing dataset file will be created, but this
# script only skips the files that do not exist - confirm which behaviour is the intended one
parser.add_argument(
	"--dataset", default = ['dataset.dat'], type = str, nargs = '+',
	help="Files with the datasets (more than one are accpted). If the file does not exist, the dataset will be created.")
parser.add_argument(
	"--ignore-boundaries", default = False, action = 'store_true',
	help="Whether to ignore the ranges given by command line. If set, the boudaries will be extracted from the dataset")
parser.add_argument(
	"--make-sub",  default = False, action='store_true',
	help="If set, it will make a condor submit file that the user can use to launch the job through condor.")

# Every token the parser does not recognize is collected in `filenames`
# and interpreted as the path of an ini file
args, filenames = parser.parse_known_args()

# Updating the arguments from the ini file(s), if any was given
for f in filenames:
	args = mbank.parser.updates_args_from_ini(f, args, parser)

####################################################################################################
	######
	#	Interpreting the parser and initializing variables
	######

# --variable-format is not marked as required in the parser (presumably so that it can be
# supplied through an ini file): its presence is checked here, after the ini files are merged
if args.variable_format is None:
	raise ValueError("The argument --variable-format must be set!")

var_handler = variable_handler()
# Explicit exception rather than `assert`, so that the check is not stripped by `python -O`
if args.variable_format not in var_handler.valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))
# D is later used as the number of leading columns of the dataset holding the coordinates
D = var_handler.D(args.variable_format)

# The run directory defaults to ./out_<run_name>/ and always ends with a slash,
# since the paths below are built by plain string concatenation
if args.run_dir is None:
	args.run_dir = './out_{}/'.format(args.run_name)
if not args.run_dir.endswith('/'):
	args.run_dir = args.run_dir+'/'
# exist_ok avoids the check-then-create race of a separate os.path.exists() test
os.makedirs(args.run_dir, exist_ok = True)
# A bare file name (no directory component) is placed inside the run directory
if '/' not in args.flow_file:
	args.flow_file = args.run_dir+args.flow_file

if args.make_sub:
	# Only the condor submit file is written: nothing else is run
	make_sub_file(args.run_dir, args.run_name)
	# sys.exit() instead of quit(): the latter is a helper of the `site` module meant for
	# interactive sessions and may be missing (e.g. with `python -S`)
	sys.exit()

# Same rule as for the flow file: bare dataset names are looked up inside the run directory
args.dataset = [f if '/' in f else args.run_dir+f for f in args.dataset]

plot_folder = args.run_dir if args.plot else None
if args.verbose: print("## Running training of the flow: ", args.run_name)

# The arguments of this run are stored in the run directory, in a json file named after this script
mbank.parser.save_args(args, args.run_dir+'args_{}.json'.format(os.path.basename(__file__)))

	######
	#	Loading the dataset(s) and performing train and test splitting
	######

# Each accepted file is loaded, validated and cleaned separately; the results are concatenated
dataset = []
err_msg = "Wrong format of the datset! Each row must be {} dimensional or {} dimensional, i.e. with format D | D**2 | 1 or D | 1".format(D*(D+1)+1, D+1)
for f in args.dataset:
	# Files that do not exist are skipped: an error is raised below only if none could be loaded
	if not os.path.isfile(f): continue

	if args.verbose: print("Loading dataset {}".format(f))

	# ndmin = 2 keeps a file with a single row two dimensional:
	# without it np.loadtxt squeezes the row into a 1D array, which would be wrongly rejected
	new_dataset = np.loadtxt(f, ndmin = 2)
	# Explicit exception rather than `assert`, so that the check is not stripped by `python -O`
	if new_dataset.shape[1] not in (D**2+D+1, D+1):
		raise ValueError(err_msg)

	# Removing the rows with at least one nan
	new_dataset = new_dataset[~np.isnan(new_dataset).any(axis = 1)]

	dataset.append(new_dataset)

if not dataset:
	raise RuntimeError("None of the given datasets exist")

dataset = np.concatenate(dataset, axis = 0)

# Unless the user asks to ignore them, the ranges given by command line are used both to
# discard the datapoints outside the boundaries and, later, as an input for the training
if args.ignore_boundaries:
	boundaries = None
else:
	boundaries = get_boundary_box_from_args(args)
	boundaries_checker = boundary_keeper(args)
	ids_ = boundaries_checker(dataset[:,:D], args.variable_format)
	dataset = dataset[ids_]

if args.n_datapoints: dataset = dataset[:args.n_datapoints]
if args.verbose: print('Loaded dataset(s) with {} entries'.format(len(dataset)))

# Train/validation split: the first N_train rows are used for training, the others for validation.
# The first D columns hold the coordinates, the last column the target value for each point.
N_train = int(args.train_fraction*len(dataset))
#max_ll = np.max(dataset[:, -1])
#dataset[:, -1] = dataset[:, -1] - max_ll + 1
train_data, validation_data = dataset[:N_train,:D], dataset[N_train:,:D]
train_ll, validation_ll = dataset[:N_train, -1], dataset[N_train:, -1]

if args.verbose: print('Train {} | Validation {}'.format(len(train_data), len(validation_data)))

	######
	#	Training the flow and saving the output
	######

if args.load_flow:
	# An existing flow is read from file: no training happens, hence there is no loss history
	history = None
	flow = STD_GW_Flow.load_flow(args.flow_file)
else:
	# File handed to the early stopper as `temp_file`; it is deleted once the training is over
	checkpoint_file = args.flow_file+'.checkpoint'

	flow = STD_GW_Flow(D, n_layers = args.n_layers, hidden_features = args.hidden_features, has_constant = True)

	early_stopper_callback = early_stopper(
		patience = args.patience, min_delta = args.min_delta,
		temp_file = checkpoint_file, verbose = args.verbose)
	optimizer = optim.Adam(flow.parameters(), lr=args.learning_rate)
	# The learning rate is halved when the monitored quantity stops decreasing
	scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', threshold = .02, factor = 0.5, patience = 4)

	training_kwargs = dict(
		N_epochs = args.n_epochs,
		train_data = train_data, train_weights = train_ll,
		validation_data = validation_data, validation_weights = validation_ll,
		optimizer = optimizer, batch_size = args.batch_size, validation_step = 100,
		callback = early_stopper_callback, lr_scheduler = scheduler,
		boundaries = boundaries, verbose = args.verbose,
	)
	history = flow.train_flow(args.loss_function, **training_kwargs)

	if os.path.isfile(checkpoint_file):
		os.remove(checkpoint_file)
	flow.save_weigths(args.flow_file)

# The shift is the additive constant between the flow log-probability and the dataset values
# NOTE(review): the test below relies on the truthiness of flow.constant. If it is a tensor
# whose value is exactly 0, the fallback estimate is used instead of the constant - confirm
# whether an `is not None` test is what is meant here
if not flow.constant:
	# Fallback: median difference between the dataset values and the flow log-probability,
	# evaluated on (at most) the first 10000 training points
	with torch.no_grad():
		train_ll_flow = flow.log_prob(torch.Tensor(train_data[:10000])).numpy()
	shift = np.nanmedian(train_ll[:len(train_ll_flow)]-train_ll_flow)
else:
	shift = flow.constant.detach().numpy()

if args.verbose:
	print('Log10 volume of the space (a.k.a. constant shift): ', shift/np.log(10))

	######
	#	Plotting
	######

if args.plot:

	# The loss curves exist only if the flow was trained in this run (history is None otherwise)
	if history:
		plot_loss_functions(history, args.run_dir)

	# At most 10000 validation points are evaluated, and 10000 samples are drawn from the flow
	with torch.no_grad():
		validation_ll_flow = flow.log_prob(torch.Tensor(validation_data[:10000])).numpy()
		plot_samples = flow.sample(10000).numpy()

	# Why are there nans in the sampling??
	# Until this is understood, the samples with at least a nan are discarded before plotting
	ids_nan = np.unique(np.where(np.isnan(plot_samples))[0])
	if len(ids_nan): plot_samples = np.delete(plot_samples, ids_nan, axis =0)

	# Dataset values of the validation points that were actually evaluated by the flow
	validation_ll_eval = validation_ll[:len(validation_ll_flow)]
	# Residuals (natural log) between the shifted flow log-probability and the dataset values
	log_M_diff = (validation_ll_flow - validation_ll_eval)+shift

	# Overall histogram of the residuals, in units of log10
	plt.figure()
	plt.hist(log_M_diff/np.log(10), histtype = 'step', bins = 100, label = 'flow', density = True)
	plt.legend()
	plt.xlabel(r"$\log_{10}(M_{flow}/M_{true})$")
	plt.savefig(args.run_dir+'flow_accuracy.png')

	# Same histogram, split in bins of the dataset value (log10 units, 2 units wide).
	# The two outermost bins are open ended, so that every point falls in one of them.
	plt.figure()
	m, M = int(np.min(validation_ll)/np.log(10)), int(np.max(validation_ll)/np.log(10))+1
	b_list = [-np.inf] + list(range(m, M, 2)) + [np.inf]
	log10_ll_eval = validation_ll_eval/np.log(10)

	for start, stop in zip(b_list[:-1], b_list[1:]):
		# Half-open bins [start, stop): a value lying exactly on an edge is assigned to one bin
		# (with strict inequalities on both sides it would not belong to any)
		ids_, = np.where(np.logical_and(log10_ll_eval >= start, log10_ll_eval < stop))
		if args.verbose: print('[{}, {}]: {}'.format(start, stop, len(ids_)))
		# Bins with too few points are not plotted
		if len(ids_)>10:
			plt.hist(log_M_diff[ids_]/np.log(10),
				histtype = 'step', bins = int(np.sqrt(len(ids_))), density = True,
				label = r"$\log_{10}M_{true} \in $"+'[{}, {}]'.format(start, stop))
	plt.xlabel(r"$\log_{10}(M_{pred}/M_{true})$")
	plt.legend(loc = 'upper left')
	plt.savefig(args.run_dir+'flow_accuracy_detailed.png')

	# The samples are plotted with the plotting routine for the templates: the files it is
	# expected to write (bank.png and hist.png) are renamed to tell they refer to the flow
	plot_tiles_templates(plot_samples, args.variable_format,
		tiling = None, dist_ellipse = None, save_folder = args.run_dir, show = False, title = 'Flow samples')
	os.rename(args.run_dir+'bank.png', args.run_dir+'flow_samples.png')
	os.rename(args.run_dir+'hist.png', args.run_dir+'flow_samples_hist.png')

	# Colormaps over the parameter space: dataset values first, absolute residuals then
	plot_colormap(validation_data, validation_ll/np.log(10), args.variable_format, statistics = 'mean', bins = 12,
		savefile = args.run_dir+'colormap_validation_data.png', values_label = 'log-likelihood', title = 'Validation data')

	plot_colormap(validation_data[:len(validation_ll_flow)], np.abs(log_M_diff)/np.log(10), args.variable_format, statistics = 'mean', bins = 30,
		savefile = args.run_dir+'colormap_residuals.png', values_label = r'$|\log_{10}\frac{p_{flow}}{p_{true}}|$', title = 'Residuals')

	if args.show: plt.show()
		
	
#### To access all the weights of the flow
#
#for transform in flow._transform._transforms:
#	for m in transform.modules():
#		for mm in m.parameters():
#			assert isinstance(mm, torch.nn.Parameter)

# For each mm, you can set requires_grad to False to disable training...

#[[isinstance(mm, torch.nn.Parameter) for mm in m.parameters()] for m in flow._transform._transforms[2].modules()]













