#!/usr/bin/env python
"""
Creates a condor dag workflow to train several NN given a PCA dataset
"""

import argparse
import glob
import os
import shutil
import subprocess
import sys

import yaml

# Locate the training executable on the PATH; its absolute path is written into the submit file.
# shutil.which performs the lookup in-process, so no shell has to be spawned.
executable = shutil.which('mlgw_fit_NN')
if executable is None:
	raise OSError("Unable to find a valid installation of mlgw_fit_NN, unable to continue. Is mlgw installed properly?")

# Content of the condor submit file shared by all the training jobs.
# $(pycmd) and $(nodename) are set, for each node, by the VARS lines of the dag file.
sub_file = """
Universe   = vanilla
Executable = {}
arguments = $(pycmd)
getenv = true
Log = logs/$(nodename).log
Error = logs/$(nodename).err
Output = logs/$(nodename).out
request_memory = 5GB
request_cpus = 2
request_disk = 4GB

queue
""".format(executable)

def get_dag_str(mode, vals, pca_dataset, model_directory):
	"""
	Builds the dag entry (JOB, VARS and RETRY lines) for the training of a single network.

	The node is called train_{quantity}_{mode}_{components}, with the spaces removed from
	the components. If `vals` has a 'residual' key:
	  - the suffix '_residual' is added to the node name
	  - the option --base-residual-model-file {model_directory}/ph_weights_{components}.keras
	    is appended to the command line
	  - a PARENT/CHILD line is added, so that the node runs after the node with the same
	    name without the suffix

	Parameters
	----------
	mode: str
		Label of the mode, used in the node name
	vals: dict
		Options of the network. The keys 'quantity' and 'components' (both strings) are required.
		Each item is translated into a command line option ``--key value``;
		if the value evaluates to False, only ``--key`` is written.
	pca_dataset: str
		Value given to --pca-dataset
	model_directory: str
		Value given to --model-directory

	Returns
	-------
	str
		Text to append to the dag file
	"""
	comps = vals['components'].replace(' ', '')
	base_node = 'train_{}_{}_{}'.format(vals['quantity'], mode, comps)
	# The options of the network are in `vals`: do not rely on a global variable of the caller
	is_residual = 'residual' in vals
	node = (base_node + '_residual') if is_residual else base_node

	args_str = ' '.join(['--{} {}'.format(k, v if v else '') for k, v in vals.items()])
	if is_residual:
		# The leading space keeps the new option separated from the value of the previous one
		args_str += ' --base-residual-model-file {}/ph_weights_{}.keras'.format(model_directory, comps)

	#order in format:
	#	node name | pca dataset | model dir | the rest of the command line
	dag_str = (
		'\n'
		'JOB {0} launch_train_NN.sub\n'
		'VARS {0} nodename="{0}" pycmd="--pca-dataset {1} --model-directory {2} {3}"\n'
		'RETRY {0} 2\n'
	).format(node, pca_dataset, model_directory, args_str)

	if is_residual:
		dag_str += 'PARENT {} CHILD {}\n'.format(base_node, node)
	return dag_str

#######################################

# The module docstring is the description of the program
# (the first positional argument of ArgumentParser is `prog`, which only replaces the program name in the usage line)
parser = argparse.ArgumentParser(description = __doc__)

parser.add_argument(
	"--dag-filename", type = str, required = False, default = 'train_NN.dag',
	help="Name of the output dag file")
parser.add_argument(
	"--sub-filename", type = str, required = False, default = 'launch_train_NN.sub',
	help="Name of the condor submission file")
parser.add_argument(
	'config', type = str,
	help='Config file to load the information from')
parser.add_argument(
	'--clean', action = 'store_true', default = False,
	help='Whether to clean all the dag files and products')

args = parser.parse_args()
if args.clean:
	# Removes the logs, the copy script, the submit file and every file whose name starts with the dag filename.
	# The command is given as a list (no shell), so that filenames with spaces or special characters are safe
	to_remove = ['logs/', 'cp_pca_model.sh']
	if args.dag_filename:
		# An empty dag filename would turn the pattern into '*', matching the whole working directory
		to_remove.extend(glob.glob(glob.escape(args.dag_filename) + '*'))
	if args.sub_filename:
		to_remove.append(args.sub_filename)
	# '--' prevents a filename starting with '-' from being read as an option of rm
	subprocess.run(['rm', '-rf', '--'] + to_remove)
	sys.exit(0)

# Loading the configuration of the run
with open(args.config, "r") as stream:
	config = yaml.safe_load(stream)
# A single mode given as a bare number is loaded by yaml as an integer
if isinstance(config['modes'], int): config['modes'] = str(config['modes'])

print(config)
if not config['pca-dataset'].endswith('/'):  config['pca-dataset']+='/'
if not config['model-dir'].endswith('/'):  config['model-dir']+='/'

# Each residual network requires its base network to be among the networks to train.
# An exception is raised explicitly, since an assert is skipped when python runs with -O
for k in config['networks']:
	if '_residual' in k:
		base_k = k.replace('_residual', '')
		if base_k not in config['networks']:
			raise ValueError("If the residual model '{}' has to be created, the base model '{}' must be trained".format(k, base_k))

# The JOB lines written by get_dag_str always refer to 'launch_train_NN.sub'
if args.sub_filename != 'launch_train_NN.sub':
	sys.stderr.write("Warning: the dag nodes refer to the submit file 'launch_train_NN.sub', but the submit file is saved to '{}'\n".format(args.sub_filename))

dag_file = ''
cp_str = '' #to gather the PCA dataset
# split() without arguments ignores repeated or trailing spaces, which would otherwise create empty modes
for mode in config['modes'].split():
	pca_dataset = config['pca-dataset']+mode
	model_directory = config['model-dir']+mode
	cp_str += 'cp {0}/amp_PCA_model.dat {0}/ph_PCA_model.dat {0}/times.dat {1}\n'.format(pca_dataset, model_directory)
	for v in config['networks'].values():
		dag_file += get_dag_str(mode, v, pca_dataset, model_directory)


with open(args.dag_filename, "w") as text_file:
	text_file.write(dag_file)
with open(args.sub_filename, "w") as text_file:
	text_file.write(sub_file)
with open('cp_pca_model.sh', "w") as text_file:
	text_file.write(cp_str)

# The submit file sends the logs of each node to logs/
os.makedirs('logs', exist_ok = True)

print("DAG file saved to {0}\nRun your dag with\n\tcondor_submit_dag {0}".format(args.dag_filename))
print("Once the dag is done, you can add the PCA model weights with:\n\tbash cp_pca_model.sh")
		
		
		









