#!python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import sys
import numpy
import torch
import argparse

from bpnetlite import BPNet
from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_peaks
from bpnetlite.attributions import calculate_attributions

import json

# Description printed at the top of `--help`; it covers all three subcommands.
desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool can train a BPNet model
	from a fasta file of sequences, a bed file of signal peak locations, and
	bigWig files for the signal to predict and, optionally, the control
	signal. It can also make predictions and calculate attributions using a
	trained model. Each command reads its settings from a JSON file."""

# Build the command-line interface: one subcommand per task, each taking a
# single JSON file of parameters through -p/--parameters. The name of the
# chosen subcommand is stored on the parsed namespace as `cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(
	dest='cmd',
	required=True,
	help="Must be either 'train', 'predict', or 'interpret'.",
)

train_parser, predict_parser, interpret_parser = [
	subparsers.add_parser(command, help=command_help)
	for command, command_help in (
		("train", "Train a BPNet model."),
		("predict", "Make predictions using a trained BPNet model."),
		("interpret", "Make interpretations using a trained BPNet model."),
	)
]

# Every subcommand accepts the same required -p/--parameters option; only
# its help text differs.
for sub_parser, parameters_help in (
	(train_parser,
		"A JSON file containing the parameters for training the model."),
	(predict_parser,
		"A JSON file containing the parameters for making predictions."),
	(interpret_parser,
		"A JSON file containing the parameters for calculating attributions."),
):
	sub_parser.add_argument("-p", "--parameters", type=str, required=True,
		help=parameters_help)

# Parse sys.argv; argparse reports invalid input and exits on its own.
args = parser.parse_args()

# Each subcommand below reads a JSON parameter file, fills in defaults, and
# then trains a model, writes predictions, or writes attributions.

def _load_parameters(filename, default_parameters, optional=("controls",)):
	"""Load a JSON parameter file and fill in defaults for missing keys.

	Parameters
	----------
	filename: str
		Path to a JSON file holding an object of parameter values.

	default_parameters: dict
		Mapping from parameter name to default value. A default of None
		marks the parameter as required, unless its name is in `optional`.

	optional: tuple of str, optional
		Names whose None default is a valid value rather than a marker for
		a required parameter. Default is ("controls",).

	Returns
	-------
	parameters: dict
		The loaded parameters with every key of `default_parameters`
		present. Keys given explicitly in the file are left untouched.

	Raises
	------
	ValueError
		If a required parameter is absent from the file.
	"""

	with open(filename, "r") as infile:
		parameters = json.load(infile)

	for parameter, value in default_parameters.items():
		if parameter not in parameters:
			if value is None and parameter not in optional:
				raise ValueError("Must provide value for '{}'".format(parameter))

			parameters[parameter] = value

	return parameters


def _zero_controls(model, X):
	"""Return all-zero control tracks shaped to match `X`, or None.

	Used when no control signal is supplied. The result has shape
	(X.shape[0], model.n_control_tracks, X.shape[-1]). A model with
	n_control_tracks == 0 gets None rather than a zero-channel tensor, as
	the original predict command already did for such models.
	"""

	if model.n_control_tracks > 0:
		return torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
	return None


if args.cmd == "train":
	default_parameters = {
		'n_filters': 64,
		'n_layers': 8,
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'max_jitter': 128,
		'reverse_complement': True,
		'max_epochs': 250,
		'validation_iter': 100,
		'lr': 0.001,
		'alpha': 1,
		'verbose': False,

		'min_counts': 0,
		'max_counts': 99999999,

		'training_chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'validation_chroms': ['chr4', 'chr15', 'chr21'],
		'sequences': None,
		'peaks': None,
		'signals': None,
		'controls': None,
		'random_state': None
	}

	parameters = _load_parameters(args.parameters, default_parameters)

	###

	# Training examples are drawn from the training chromosomes with random
	# jitter and optional reverse-complement augmentation.
	training_data = PeakGenerator(
		peaks=parameters['peaks'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# Validation examples are extracted once, unjittered, from the held-out
	# chromosomes.
	valid_sequences, valid_signals, valid_controls = extract_peaks(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		peaks=parameters['peaks'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# Number of positions trimmed from each side of the input window so the
	# output covers the central `out_window` positions.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'],
		alpha=parameters['alpha'],
		trimming=trimming).cuda()

	optimizer = torch.optim.Adam(model.parameters(), lr=parameters['lr'])

	model.fit_generator(training_data, optimizer, X_valid=valid_sequences, 
		X_ctl_valid=valid_controls, y_valid=valid_signals, 
		max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size'])

elif args.cmd == 'predict':
	default_parameters = {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'sequences': None,
		'peaks': None,
		'controls': None,
		'model': None,
		'profile_filename': 'y_profile.npz',
		'count_filename': 'y_count.npz'
	}

	parameters = _load_parameters(args.parameters, default_parameters)

	# NOTE(review): torch.load unpickles a whole model object, so only load
	# model files from trusted sources. Newer PyTorch releases default to
	# weights_only=True, which may reject a pickled module; confirm against
	# the installed version.
	model = torch.load(parameters['model'])

	# Pass the configured windows through, as the train command does;
	# previously they were silently ignored here.
	examples = extract_peaks(
		sequences=parameters['sequences'],
		controls=parameters['controls'],
		peaks=parameters['peaks'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# Without signals, extract_peaks yields only the sequences, or a
	# (sequences, controls) pair when controls were requested.
	if parameters['controls'] is None:
		X = examples
		X_ctl = _zero_controls(model, X)
	else:
		X, X_ctl = examples

	y_profiles, y_counts = model.predict(X, X_ctl=X_ctl, 
		batch_size=parameters['batch_size'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['count_filename'], y_counts)

elif args.cmd == 'interpret':
	default_parameters = {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'sequences': None,
		'peaks': None,
		'model': None,
		'output': 'count',
		'ohe_filename': 'ohe.npz',
		'shap_filename': 'shap.npz',
		'random_state':0,
	}

	parameters = _load_parameters(args.parameters, default_parameters)

	# Same trust/compatibility caveat as the torch.load in the predict branch.
	model = torch.load(parameters['model'])

	X = extract_peaks(
		sequences=parameters['sequences'],
		peaks=parameters['peaks'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# No control signal is read for interpretation. Use all-zero controls
	# when the model expects control tracks, and pass no extra arguments
	# otherwise. This assumes calculate_attributions treats args=None as "no
	# extra forward arguments"; TODO confirm against bpnetlite.attributions.
	X_ctl = _zero_controls(model, X)
	ctl_args = None if X_ctl is None else (X_ctl,)

	X_attr = calculate_attributions(model, X, args=ctl_args,
		model_output=parameters['output'], hypothetical=True, 
		random_state=parameters['random_state'],
		verbose=parameters['verbose'])

	numpy.savez_compressed(parameters['ohe_filename'], X)
	numpy.savez_compressed(parameters['shap_filename'], X_attr)