#!/usr/bin/env python
"""
Launch an instance of keras tuner to find the best network to fit a PCA dataset

Typical usage:

	python tune_NN.py --pca-dataset pca_datasets/IMRPhenomTPHM/22/ --working-dir bayesian_tuning_22/ --project-name tuning_amp_22_1234 --quantity amp --components 4 --units 1 2 3 5 10 20 35 50 75 100 200 --n-layers 1 2 4 6 8 10 --polynomial-order 1 2 3 4 5 --max-epochs 10000


"""
from mlgw.NN_model import tune_model, analyse_tuner_results
import argparse
from pathlib import Path

# Command-line interface of the script.
# The module docstring is used as the *description* of the parser: the first
# positional parameter of ArgumentParser is `prog`, so passing the docstring
# positionally would turn it into the program name shown in the usage line.
# RawDescriptionHelpFormatter keeps the line breaks of the docstring in --help.
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)

# Only needed when a tuning run is launched (not with --analyse), hence not required here
parser.add_argument(
	"--pca-dataset", type = str, required = False,
	help="Folder for the PCA dataset")

parser.add_argument(
	"--working-dir", type = str, required = True,
	help="Folder where all the tuning products will be stored")

parser.add_argument(
	"--quantity", type = str, required = False, choices = ['amp', 'ph'], default = 'ph',
	help="Wheter to create the dataset for amplitude of phase")

# NOTE(review): with nargs='+' a user-supplied value is a list of int, while the
# default is a bare int -- code reading args.components must handle both.
parser.add_argument(
	"--components", type = int, required = False, nargs = '+', default = 2,
	help="PCA components to be included in the model")

# When omitted, the script derives a name from --quantity and --components
parser.add_argument(
	"--project-name", type = str, required = False,
	help="A name for the tuner project")

parser.add_argument(
	"--max-epochs", type = int, required = False, default = 10000,
	help="Maximum number of epochs for each training istance")

parser.add_argument(
	"--units", type = int, required = False, nargs = '+', default = [1, 2, 3, 5, 10, 50, 100],
	help="List of units per layers for the optimizer to try")

parser.add_argument(
	"--n-layers", type = int, required = False, nargs = '+', default = [1, 2, 4, 6, 8, 10],
	help="Number of layers for the optimizer to try")

parser.add_argument(
	"--polynomial-order", type = int, required = False, nargs = '+', default = [1, 2, 3, 4, 5],
	help="Data augmentation features for the optimizer to try")

# NOTE(review): same list-vs-scalar asymmetry as --components: the default is a
# plain string, a user-supplied value is a list of strings.
parser.add_argument(
	"--features", type = str, required = False, nargs = '+', default = 'mc_chieff',
	help="Features to use for data augmenation purposes")

# Flag: if given, the script only analyses the results of an existing tuner project
parser.add_argument(
	"--analyse", action='store_true',
	help="Whether to analyze the validation results")

args = parser.parse_args()

# Set the default project name *before* the --analyse branch: that branch needs
# it too, to build the path of the tuner project inside the working directory.
if not args.project_name:
	# --components is a list when given on the command line (nargs = '+') but a
	# bare number when the parser default is used: both cases are handled.
	if isinstance(args.components, (int, float)):
		comp = str(args.components)
	else:
		comp = ''.join(str(s) for s in args.components)
	args.project_name = 'tuning_{}_{}'.format(args.quantity, comp)

if args.analyse:
	# Analysis only: no tuning is launched and the PCA dataset is not needed
	analyse_tuner_results(Path(args.working_dir)/args.project_name, save_loc=None)
	# Explicit SystemExit rather than quit(), which is a helper injected by the
	# `site` module and is not guaranteed to exist
	raise SystemExit(0)

# Explicit check rather than an assert, which is stripped when running with -O
if args.pca_dataset is None:
	parser.error("--pca-dataset must be given unless --analyse is set")

# Search space handed over to tune_model
hyperparameters = {
		#list --> keeps a choice in build_model
	"units" : args.units, #units per hidden layer
	"layers" : args.n_layers, #num of hidden layers
	"activation" : "sigmoid", #a plain string (not a tuple): a single fixed value
	"learning_rate" : (0.00001, 0.00003, 0.0001,0.0003,0.001,0.005),
	"feature_order" : args.polynomial_order,
	"features": args.features
}

# Launch the tuner with a hard-coded number of trials and initial trials
tune_model(args.working_dir, args.project_name, args.quantity, args.pca_dataset, args.components, hyperparameters,
		max_epochs = args.max_epochs, trials=100, init_trials=25)

