#!python

"""
Driver script for automatic hyperparameter optimization.
"""

import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np 
import os
import sys
import argparse
import process_input_data as pid
import brnn_architecture
import train_network
import brnn_plot
import bayesian_optimization

# Parse the command line arguments.
# The description covers the whole pipeline this driver runs: a hyperparameter
# search followed by a final train/test pass with the selected values.
parser = argparse.ArgumentParser(
	description='Optimize hyperparameters via Bayesian optimization, then train and test '
				'a bi-directional RNN with the best values found.')
parser.add_argument('data_file', help='path to tsv file with format: <idx> <sequence> <data>')
parser.add_argument('output_network', help='path for the returned trained network')
# Required options carry no default: argparse never falls back to a default
# when required=True, so one would only mislead readers.
parser.add_argument('--datatype', metavar='dtype', type=str, required=True,
					help="Required. Format of the input data file, must be 'sequence' or 'residues'")
parser.add_argument('--split', default='', metavar='split_file', type=str,
					help="file indicating how to split datafile into training, validation, and testing sets")
parser.add_argument('-nc', type=int, metavar='num_classes', required=True,
					help='Required. Number of output classes, for regression put 1')
parser.add_argument('-b', default=32, type=int, metavar='batch_size',
					help='batch size (def=32)')
parser.add_argument('-e', default=200, type=int, metavar='num_epochs',
					help='number of training epochs (def=200)')
# NOTE(review): this flag is forwarded unchanged to pid.split_data_cv; its exact
# effect on parsing the data file is defined there -- confirm before documenting.
parser.add_argument('--excludeSeqID', action='store_true')
parser.add_argument('--encodeBiophysics', action='store_true',
					help="use the 'biophysics' encoding scheme instead of 'onehot'")
parser.add_argument('--verbose', '-v', action='count', default=0,
					help='increase output verbosity (may be repeated)')

args = parser.parse_args()

# Device configuration: use CUDA when torch reports it as available, else the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters and run options, copied out of the parsed arguments
batch_size, num_epochs = args.b, args.e
dtype, num_classes = args.datatype, args.nc
split_file, verbosity = args.split, args.verbose
excludeSeqID, encodeBiophysics = args.excludeSeqID, args.encodeBiophysics

# The encoding scheme name and the network input size are always chosen together
encoding_scheme, input_size = ('biophysics', 4) if encodeBiophysics else ('onehot', 20)

###############################################################################
########################      Validate arguments:      ########################
# Every failed check below leaves through sys.exit(<message>): the message is
# written to stderr and the process ends with a non-zero exit status, so calling
# shells and pipelines can tell that the run was rejected.

# Ensure that provided data_file exists
data_file = os.path.abspath(args.data_file)
if not os.path.isfile(data_file):
	sys.exit('Error: Datafile does not exist.')

# Ensure that output network location is valid.
# The path is taken apart with os.path (not by splitting on '/') so the check
# also works on platforms whose path separator is not '/'.
saved_weights = os.path.abspath(args.output_network)
network_filename = os.path.basename(saved_weights)
# An empty filename (e.g. the filesystem root) is rejected as well.
if not network_filename or not os.path.isdir(os.path.dirname(saved_weights)):
	sys.exit('Error: Output network directory does not exist.')

# If provided, check that split_file exists; otherwise record its absence as None
if split_file:
	split_file = os.path.abspath(split_file)
	if not os.path.isfile(split_file):
		sys.exit('Error: Split-file does not exist.')
else:
	split_file = None

# Decide between classification and regression from the number of classes
if num_classes < 1:
	sys.exit('Error: number of classes must be a positive integer.')
problem_type = 'classification' if num_classes > 1 else 'regression'

# Select the collate function matching the data format and the problem type
if dtype == 'sequence':
	if problem_type == 'classification':
		collate_function = pid.seq_class_collate
	else:
		collate_function = pid.seq_regress_collate
elif dtype == 'residues':
	if problem_type == 'classification':
		collate_function = pid.res_class_collate
	else:
		collate_function = pid.res_regress_collate
else:
	sys.exit("Error: Invalid datatype argument -- must be 'residues' or 'sequence'.")

# Ensure that batch size and num epochs are both positive ints
if num_epochs < 1:
	sys.exit('Error: number of epochs must be a positive integer.')
if batch_size < 1:
	sys.exit('Error: batch size must be a positive integer.')

###############################################################################


# Split data into the cross-validation dataset pairs and the train/val/test sets
cvs, train, val, test = pid.split_data_cv(
	data_file,
	datatype=dtype,
	problem_type=problem_type,
	excludeSeqID=excludeSeqID,
	split_file=split_file,
	encoding_scheme=encoding_scheme,
)

# Wrap every CV (train, val) dataset pair in dataloaders; only the training
# half of each pair is shuffled
cv_loaders = [
	(
		DataLoader(dataset=fold_train, batch_size=batch_size,
					collate_fn=collate_function, shuffle=True),
		DataLoader(dataset=fold_val, batch_size=batch_size,
					collate_fn=collate_function, shuffle=False),
	)
	for fold_train, fold_val in cvs
]

# Run the Bayesian hyperparameter search over the CV dataloaders
optimizer = bayesian_optimization.BayesianOptimizer(
	cv_loaders, input_size, num_epochs, num_classes,
	dtype, saved_weights, device, verbosity,
)
best_hyperparams = optimizer.optimize()

# Entry 0 is a base-10 exponent for the learning rate; entries 1 and 2 are
# converted to ints and later handed to the network constructor
# (presumably layer count and hidden size -- confirm against brnn_architecture)
lr = 10 ** best_hyperparams[0]
nl, hs = (int(best_hyperparams[position]) for position in (1, 2))

## Retrain from scratch with the best hyperparams on the entire train/val sets
# Build the dataloaders; only the training set is shuffled
train_loader = DataLoader(train, batch_size=batch_size,
							collate_fn=collate_function, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size,
						collate_fn=collate_function, shuffle=False)
# The test set is served one sample per batch
test_loader = DataLoader(test, batch_size=1,
							collate_fn=collate_function, shuffle=False)

# Initialize network: 'sequence' data uses BRNN_MtO, any other datatype
# (i.e. 'residues') uses BRNN_MtM; both take the same constructor arguments
network_class = (brnn_architecture.BRNN_MtO if dtype == 'sequence'
					else brnn_architecture.BRNN_MtM)
brnn_network = network_class(input_size, hs, nl, num_classes, device).to(device)

# Train network, announcing the run first when any verbosity was requested
if verbosity > 0:
	print('Training with optimal hyperparams:')

# The final fit is given twice num_epochs
train_loss, val_loss = train_network.train(
	brnn_network,
	train_loader,
	val_loader,
	datatype=dtype,
	problem_type=problem_type,
	weights_file=saved_weights,
	stop_condition='iter',
	device=device,
	learn_rate=lr,
	n_epochs=2 * num_epochs,
	verbosity=verbosity,
)
# Hand both loss histories to the plotting helper
brnn_plot.training_loss(train_loss, val_loss)

# Test network on the test dataloader and report the loss when verbose
test_loss = train_network.test_labeled_data(
	brnn_network,
	test_loader,
	datatype=dtype,
	problem_type=problem_type,
	weights_file=saved_weights,
	num_classes=num_classes,
	device=device,
)
if verbosity > 0:
	print('\nTest Loss: %.4f' % test_loss)

