#!python
"""
Usage: $ brnn_optimize data_file output_network <flags>
  
Driver script for finding optimal hyperparameters for a bidirectional recurrent
neural network on a given dataset, then training a network with those parameters.
For more information on usage, use the '-h' flag.

.............................................................................
prot_brnn was developed by the Holehouse lab
     Original release ---- 2020

Questions/comments/concerns? Raise an issue on GitHub:
https://github.com/holehouse-lab/prot-brnn

Licensed under the MIT license. 
"""

import os
import sys

import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np 
import argparse

from prot_brnn import process_input_data as pid
from prot_brnn import brnn_architecture
from prot_brnn import train_network
from prot_brnn import brnn_plot
from prot_brnn import bayesian_optimization
from prot_brnn import encode_sequence

# Command-line interface: two positional paths, then the flags.
# Arguments are registered in the order they appear in the '-h' output.
parser = argparse.ArgumentParser(
    description='Train and test a bi-directional RNN using entire sequence.'
)

# Positional arguments: input data and destination for the trained network.
parser.add_argument(
    'data_file',
    help='path to tsv file with format: <idx> <sequence> <data>',
)
parser.add_argument(
    'output_network',
    help='path for the returned trained network',
)

# Required flags: both are validated again further down in the script.
parser.add_argument(
    '--datatype',
    type=str,
    required=True,
    default='sequence',
    metavar='dtype',
    help="Required. Format of the input data file, must be 'sequence' or 'residues'",
)

# Optional file describing the train/validation/test partition ('' means "none given").
parser.add_argument(
    '--split',
    type=str,
    default='',
    metavar='split_file',
    help="file indicating how to split datafile into training, validation, and testing sets",
)

parser.add_argument(
    '-nc',
    type=int,
    required=True,
    default=1,
    metavar='num_classes',
    help='Required. Number of output classes, for regression put 1',
)

# Training hyper-parameters that are fixed by the user rather than optimized.
parser.add_argument(
    '-b',
    type=int,
    default=32,
    metavar='batch_size',
    help='(def=32)',
)
parser.add_argument(
    '-e',
    type=int,
    default=200,
    metavar='num_epochs',
    help='number of training epochs (def=200)',
)

# Three floats: train, validation and test proportions.
parser.add_argument(
    '--setFractions',
    type=float,
    nargs=3,
    default=[0.7, 0.15, 0.15],
    metavar=('train', 'val', 'test'),
    help='Proportion of dataset that should be divided into training, validation, and test sets',
)

parser.add_argument(
    '--encode',
    type=str,
    default='onehot',
    metavar='encoding_scheme',
    help="'onehot' (default), 'biophysics', or specify a path to a user-created scheme",
)

# Boolean switches and a repeatable verbosity counter (-v, -vv, ...).
parser.add_argument('--excludeSeqID', action='store_true')
parser.add_argument('--forceCPU', action='store_true')
parser.add_argument('--verbose', '-v', action='count', default=0)

args = parser.parse_args()

# Hyper-parameters supplied directly by the user (not part of the search)
batch_size = args.b
num_epochs = args.e

# Other flags, unpacked from the parsed arguments under shorter names
dtype = args.datatype
num_classes = args.nc
split_file = args.split
encode = args.encode
verbosity = args.verbose
forceCPU = args.forceCPU
setFractions = args.setFractions
excludeSeqID = args.excludeSeqID

# Device configuration.
# `device` is always a torch.device, so downstream code sees the same type
# whether the CPU was forced or simply chosen because CUDA is unavailable.
# CUDA availability is only queried when the CPU has not been forced.
if forceCPU:
    device = torch.device('cpu')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
########################      Validate arguments:      ########################
# Every failure below exits through sys.exit(<message>): the message is written
# to stderr and the process returns a non-zero status, so calling scripts can
# detect that validation failed.

# Ensure that provided data_file exists
data_file = os.path.abspath(args.data_file)
if not os.path.isfile(data_file):
    sys.exit('Error: Datafile does not exist.')

# Ensure that output network location is valid.
# os.path is used to separate directory and filename so that this works with
# whatever path separator the platform uses.
saved_weights = os.path.abspath(args.output_network)
network_filename = os.path.basename(saved_weights)
if not os.path.isdir(os.path.dirname(saved_weights)):
    sys.exit('Error: Output network directory does not exist.')

# If provided, check that split_file exists; an empty string means no split
# file was given and is normalized to None.
if split_file != '':
    split_file = os.path.abspath(split_file)
    if not os.path.isfile(split_file):
        sys.exit('Error: Split-file does not exist.')
else:
    split_file = None

# Resolve the sequence encoding.
# The two built-in schemes have a fixed input size and need no encoder object;
# any other value is treated as the path to a user-defined scheme, whose
# encoder determines the input size.
if encode in ('onehot', 'biophysics'):
    encoding_scheme = encode
    input_size = 20 if encode == 'onehot' else 9
    encoder = None
else:
    encoding_scheme = 'user'
    encode_file = encode
    encoder = encode_sequence.UserEncoder(encode_file)
    input_size = len(encoder)
# Decide between classification and regression from the number of classes.
# Invalid values abort with a message on stderr and a non-zero exit status.
if num_classes > 1:
    problem_type = 'classification'
elif num_classes == 1:
    problem_type = 'regression'
else:
    sys.exit('Error: number of classes must be a positive integer.')

# Pick the collate function matching the data format and the problem type
if dtype == 'sequence':
    if problem_type == 'classification':
        collate_function = pid.seq_class_collate
    else:
        collate_function = pid.seq_regress_collate
elif dtype == 'residues':
    if problem_type == 'classification':
        collate_function = pid.res_class_collate
    else:
        collate_function = pid.res_regress_collate
else:
    sys.exit("Error: Invalid datatype argument -- must be 'residues' or 'sequence'.")

# Ensure that batch size and num epochs are both positive ints.
# Failures exit with a message on stderr and a non-zero status.
if num_epochs < 1:
    sys.exit('Error: number of epochs must be a positive integer.')
if batch_size < 1:
    sys.exit('Error: batch size must be a positive integer.')

# Each set fraction must lie strictly between 0 and 1 ...
if not all(0 < frac < 1 for frac in setFractions):
    sys.exit('Error: all set fractions must be between 0 and 1.')

# ... and together they must sum to 1. The comparison uses a small tolerance
# because exact float equality can reject valid inputs through rounding
# (e.g. 0.6 + 0.3 + 0.1 does not evaluate to exactly 1.0).
if abs(sum(setFractions) - 1.0) > 1e-9:
    sys.exit('Error: set fractions must sum to 1')

###############################################################################


# Partition the data file: `cvs` is iterated below as (train, validation)
# dataset pairs; `train`, `val` and `test` are used for the final training run.
cvs, train, val, test = pid.split_data_cv(
    data_file,
    datatype=dtype,
    problem_type=problem_type,
    excludeSeqID=excludeSeqID,
    split_file=split_file,
    encoding_scheme=encoding_scheme,
    encoder=encoder,
    percent_val=setFractions[1],
    percent_test=setFractions[2],
)

# Wrap every cross-validation pair in dataloaders.
# Only the training loader of each pair is shuffled.
cv_loaders = [
    (
        torch.utils.data.DataLoader(dataset=fold_train, batch_size=batch_size,
                                    collate_fn=collate_function, shuffle=True),
        torch.utils.data.DataLoader(dataset=fold_val, batch_size=batch_size,
                                    collate_fn=collate_function, shuffle=False),
    )
    for fold_train, fold_val in cvs
]

# Search for the best hyperparameters using the cross-validation dataloaders
optimizer = bayesian_optimization.BayesianOptimizer(cv_loaders, input_size, num_epochs, num_classes,
                                                    dtype, saved_weights, device, verbosity)

best_hyperparams = optimizer.optimize()
# Entry 0 is used as the base-10 exponent of the learning rate; entries 1 and 2
# are truncated to ints for the number of layers and the hidden vector size.
lr = 10**best_hyperparams[0]
nl = int(best_hyperparams[1])
hs = int(best_hyperparams[2])

# Save these hyperparameters to a file, next to the output network, so that
# the user has a record.
params_file = os.path.join(os.path.dirname(saved_weights), 'optimal_hyperparams.txt')
with open(params_file, 'w') as f:
    # Scientific notation: a fixed '%.5f' would round small learning rates
    # (below 5e-6) down to 0.00000 and lose the recorded value.
    f.write('Learning rate: %.5e\n' % lr)
    f.write('Num Layers: %d\n' % nl)
    f.write('Hidden vector size: %d\n' % hs)

## Final run: train a fresh network with the selected hyperparameters on the
## full training and validation sets.
# Dataloaders for the three sets; only the training data is shuffled.
train_loader = torch.utils.data.DataLoader(
    train, batch_size=batch_size, shuffle=True, collate_fn=collate_function)
val_loader = torch.utils.data.DataLoader(
    val, batch_size=batch_size, shuffle=False, collate_fn=collate_function)
# The test set is fed one sample per batch
test_loader = torch.utils.data.DataLoader(
    test, batch_size=1, shuffle=False, collate_fn=collate_function)

# Choose the architecture class from the data format: BRNN_MtO for 'sequence',
# BRNN_MtM otherwise (dtype was already validated, so otherwise == 'residues').
network_cls = (brnn_architecture.BRNN_MtO if dtype == 'sequence'
               else brnn_architecture.BRNN_MtM)
brnn_network = network_cls(input_size, hs, nl, num_classes, device).to(device)

# Train the final network. It runs for twice the number of epochs that was
# handed to the hyperparameter search.
if verbosity >= 1:
    print('Training with optimal hyperparams:')
train_loss, val_loss = train_network.train(
    brnn_network,
    train_loader,
    val_loader,
    datatype=dtype,
    problem_type=problem_type,
    weights_file=saved_weights,
    stop_condition='iter',
    device=device,
    learn_rate=lr,
    n_epochs=2 * num_epochs,
    verbosity=verbosity,
)
# Plot the per-epoch training and validation losses returned by training
brnn_plot.training_loss(train_loss, val_loss)

# Evaluate on the held-out test set; the loss is only reported when verbose
test_loss = train_network.test_labeled_data(
    brnn_network,
    test_loader,
    datatype=dtype,
    problem_type=problem_type,
    weights_file=saved_weights,
    num_classes=num_classes,
    device=device,
)
if verbosity >= 1:
    print('\nTest Loss: %.4f' % test_loss)
