#!python

import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np 
import os
import sys
import argparse
import process_input_data as pid
import brnn_architecture
import train_network
import brnn_plot

"""
Driver script for the prot_brnn package.

Batch size should be 16-1024, optimally ~1-5% of the dataset.
"""

# Command-line interface
def _build_arg_parser():
	"""Create the argparse parser describing every option of this driver script."""
	cli = argparse.ArgumentParser(
		description='Train and test a bi-directional RNN using entire sequence.')

	# Positional arguments: input data and destination of the trained network
	cli.add_argument(
		'data_file',
		help='path to tsv file with format: <idx> <sequence> <data>')
	cli.add_argument(
		'output_network',
		help='path for the returned trained network')

	# Data layout, dataset splitting and stop condition
	cli.add_argument(
		'--datatype', metavar='dtype', default='sequence', type=str, required=True,
		help="Required. Format of the input data file, must be 'sequence' or 'residues'")
	cli.add_argument(
		'--split', default='', metavar='split_file', type=str,
		help="file indicating how to split datafile into training, validation, and testing sets")
	cli.add_argument(
		'--stop', default='iter', metavar='stop_condition', type=str,
		help="training stop condition: either 'auto' or 'iter' (default)")

	# Network shape
	cli.add_argument(
		'-nc', default=1, type=int, metavar='num_classes', required=True,
		help='Required. Number of output classes, for regression put 1')
	cli.add_argument(
		'-hs', default=5, type=int, metavar='hidden_size',
		help='hidden vector size (def=5)')
	cli.add_argument(
		'-nl', default=1, type=int, metavar='num_layers',
		help='number of layers per direction (def=1)')

	# Training hyper-parameters
	cli.add_argument('-b', default=32, type=int, metavar='batch_size', help='(def=32)')
	cli.add_argument('-lr', default=0.001, type=float, metavar='learning_rate', help='(def=0.001)')
	cli.add_argument(
		'-e', default=30, type=int, metavar='num_epochs',
		help='number of training epochs (def=30)')

	# Boolean switches and verbosity counter (each -v increments the level)
	cli.add_argument('--excludeSeqID', action='store_true')
	cli.add_argument('--encodeBiophysics', action='store_true')
	cli.add_argument('--verbose', '-v', action='count', default=0)

	return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# Copy the parsed command-line values into module-level settings
hidden_size, num_layers = args.hs, args.nl
batch_size, learning_rate, num_epochs = args.b, args.lr, args.e
dtype, num_classes = args.datatype, args.nc
split_file, stop_cond = args.split, args.stop
verbosity = args.verbose

excludeSeqID, encodeBiophysics = args.excludeSeqID, args.encodeBiophysics

# The encoding scheme determines the input size handed to the network
# (4 for 'biophysics', 20 for 'onehot')
encoding_scheme, input_size = ('biophysics', 4) if encodeBiophysics else ('onehot', 20)

###############################################################################
########################      Validate arguments:      ########################
def _exit_with_error(message):
	"""Print *message* and terminate the script with a non-zero exit status.

	A bare ``sys.exit()`` reports success (status 0) to the calling shell,
	which hides failures from pipelines and job schedulers.
	"""
	print(message)
	sys.exit(1)


# Ensure that provided data_file exists
data_file = os.path.abspath(args.data_file)
if not os.path.isfile(data_file):
	_exit_with_error('Error: Datafile does not exist.')

# Ensure that output network location is valid. os.path is used to split the
# path so that the check works with any platform's path separator.
saved_weights = os.path.abspath(args.output_network)
network_filename = os.path.basename(saved_weights)
if not os.path.isdir(os.path.dirname(saved_weights)):
	_exit_with_error('Error: Output network directory does not exist.')

# If provided, check that split_file exists; otherwise normalize it to None
if split_file != '':
	split_file = os.path.abspath(split_file)
	if not os.path.isfile(split_file):
		_exit_with_error('Error: Split-file does not exist.')
else:
	split_file = None

# More than one output class means classification, exactly one means regression
if num_classes > 1:
	problem_type = 'classification'
elif num_classes == 1:
	problem_type = 'regression'
else:
	_exit_with_error('Error: number of classes must be a positive integer.')

# Ensure that the data format is one of the two supported layouts
if dtype not in ('sequence', 'residues'):
	_exit_with_error("Error: Invalid datatype argument -- must be 'residues' or 'sequence'.")

# Ensure that learning rate is between 0 and 1 (exclusive)
if learning_rate >= 1 or learning_rate <= 0:
	_exit_with_error("Error: Learning rate must be between 0 and 1.")

# Ensure that stop condition is 'iter' or 'auto'
if stop_cond == 'auto':
	if num_epochs > 10:
		print("Warning: stop condition is set to 'auto' and num_epochs > 10." +
			" Network training may take a long time.")
elif stop_cond != 'iter':
	_exit_with_error("Error: Invalid stop condition. Must be 'auto' or 'iter'.")

# Ensure that hidden size, num layers, batch size, and num epochs are all positive ints
if hidden_size < 1:
	_exit_with_error('Error: hidden vector size must be a positive integer.')
if num_layers < 1:
	_exit_with_error('Error: number of layers must be a positive integer.')
if num_epochs < 1:
	_exit_with_error('Error: number of epochs must be a positive integer.')
if batch_size < 1:
	_exit_with_error('Error: batch size must be a positive integer.')

# Build the network only after every argument has been validated, so invalid
# sizes are reported with the messages above instead of reaching the constructor.
if dtype == 'sequence':
	# Use a many-to-one architecture
	brnn_network = brnn_architecture.BRNN_MtO(input_size, hidden_size,
									num_layers, num_classes, device).to(device)
	# Set collate function
	if problem_type == 'classification':
		collate_function = pid.seq_class_collate
	else:
		collate_function = pid.seq_regress_collate
else:
	# dtype == 'residues': use a many-to-many architecture
	brnn_network = brnn_architecture.BRNN_MtM(input_size, hidden_size,
									num_layers, num_classes, device).to(device)
	# Set collate function
	if problem_type == 'classification':
		collate_function = pid.res_class_collate
	else:
		collate_function = pid.res_regress_collate

###############################################################################

# Partition the data file into training, validation and test datasets
train, val, test = pid.split_data(
	data_file,
	datatype=dtype,
	problem_type=problem_type,
	excludeSeqID=excludeSeqID,
	split_file=split_file,
	encoding_scheme=encoding_scheme,
)

# Wrap each dataset in a DataLoader; only the training data is shuffled
train_loader = DataLoader(
	dataset=train,
	batch_size=batch_size,
	collate_fn=collate_function,
	shuffle=True,
)
val_loader = DataLoader(
	dataset=val,
	batch_size=batch_size,
	collate_fn=collate_function,
	shuffle=False,
)
# The test set is evaluated one sample at a time (batch size 1)
test_loader = DataLoader(
	dataset=test,
	batch_size=1,
	collate_fn=collate_function,
	shuffle=False,
)

# Train the network, then plot the returned training and validation losses
train_loss, val_loss = train_network.train(
	brnn_network,
	train_loader,
	val_loader,
	datatype=dtype,
	problem_type=problem_type,
	weights_file=saved_weights,
	stop_condition=stop_cond,
	device=device,
	learn_rate=learning_rate,
	n_epochs=num_epochs,
	verbosity=verbosity,
)
brnn_plot.training_loss(train_loss, val_loss)

# Evaluate the network on the labeled test set
test_loss = train_network.test_labeled_data(
	brnn_network,
	test_loader,
	datatype=dtype,
	problem_type=problem_type,
	weights_file=saved_weights,
	num_classes=num_classes,
	device=device,
)

# Report the test loss only when at least one -v flag was given
if verbosity >= 1:
	print('\nTest Loss: %.4f' % test_loss)
