#!python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import sys
import numpy
import torch
import pyfaidx
import argparse

from bpnetlite import BPNet
from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_loci

from bpnetlite.attributions import calculate_attributions
from bpnetlite.marginalize import marginalization_report

from bpnetlite.negatives import extract_matching_loci

import pandas
import pyBigWig
import json

# Top-level description shown by `--help`.
desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool will take in a fasta
	file for the sequence, a bed file for signal peak locations, and bigWig
	files for the signal to predict and the control signal, and train a
	BPNet model for you."""

# Read in the arguments. Every command gets its own sub-parser and the name
# of the chosen command is stored in `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help="Must be either 'negatives', 'fit', 'predict', 'interpret', or 'marginalize'.", required=True, dest='cmd')

# `negatives` is configured entirely through command-line flags.
negatives_parser = subparsers.add_parser("negatives", help="Sample GC-matched negatives.")
negatives_parser.add_argument("-i", "--peaks", required=True, help="Peak bed file.")
negatives_parser.add_argument("-f", "--fasta", help="Genome FASTA file.")
negatives_parser.add_argument("-b", "--bigwig", required=True, help="GC content bigwig.")
negatives_parser.add_argument("-o", "--output", required=True, help="Output bed file.")
negatives_parser.add_argument("-l", "--bin_width", type=float, default=0.02, help="GC bin width to match.")
negatives_parser.add_argument("-w", "--width", type=int, default=2114, help="Width for calculating GC content.")

# Verbose output has always been on by default, which made `-v` a no-op and
# left no way to silence the command. The default is kept for backward
# compatibility and `-q/--quiet` writes False into the same destination.
negatives_parser.add_argument("-v", "--verbose", default=True, action='store_true',
	help="Print progress information (enabled by default).")
negatives_parser.add_argument("-q", "--quiet", dest='verbose', action='store_false',
	help="Suppress progress information.")

# The remaining commands each take a single JSON file of parameters.
train_parser = subparsers.add_parser("fit", help="Fit a BPNet model.")
train_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for fitting the model.")

predict_parser = subparsers.add_parser("predict", help="Make predictions using a trained BPNet model.")
predict_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

interpret_parser = subparsers.add_parser("interpret", help="Make interpretations using a trained BPNet model.")
interpret_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

marginalize_parser = subparsers.add_parser("marginalize", help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")

###
# Default Parameters
###

def _chromosomes(*numbers):
	"""Return the 'chrN' names for the given chromosome numbers, in order."""
	return ['chr{}'.format(number) for number in numbers]


def _inference_defaults():
	"""Return a fresh dict of the settings that the predict, interpret and
	marginalize defaults all start with."""
	return {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': _chromosomes(1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 16,
			18, 19, 20, 22),
	}


# A default of None carries no usable value; merge_parameters decides what
# happens when such an entry is missing from the user's JSON file.
default_fit_parameters = dict(
	# Model settings
	n_filters=64,
	n_layers=8,
	n_outputs=2,
	n_control_tracks=2,
	profile_output_bias=True,
	count_output_bias=True,
	name=None,

	# Data generation and optimisation settings
	batch_size=64,
	in_window=2114,
	out_window=1000,
	max_jitter=128,
	reverse_complement=True,
	max_epochs=250,
	validation_iter=100,
	lr=0.001,
	alpha=1,
	verbose=False,

	# Bounds on the counts of the loci that are used
	min_counts=0,
	max_counts=99999999,

	# Chromosome split and input files
	training_chroms=_chromosomes(1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 16,
		18, 19, 20, 22),
	validation_chroms=_chromosomes(4, 15, 21),
	sequences=None,
	loci=None,
	signals=None,
	controls=None,
	random_state=None,
)

default_predict_parameters = {
	**_inference_defaults(),
	'sequences': None,
	'loci': None,
	'controls': None,
	'model': None,
	'profile_filename': 'y_profile.npz',
	'count_filename': 'y_count.npz',
}

default_interpret_parameters = {
	**_inference_defaults(),
	'sequences': None,
	'loci': None,
	'model': None,
	'output': 'count',
	'ohe_filename': 'ohe.npz',
	'shap_filename': 'shap.npz',
	'random_state': 0,
}

default_marginalize_parameters = {
	**_inference_defaults(),
	'sequences': None,
	'motifs': None,
	'loci': None,
	'n_loci': None,
	'shuffle': False,
	'model': None,
	'output_filename': 'marginalize/',
	'random_state': 0,
}

###
# Commands
###

def merge_parameters(parameters, default_parameters):
	"""Load user parameters from JSON and fill the gaps with defaults.

	Values present in the JSON file always win. A parameter missing from the
	file takes its default, unless that default is None, in which case the
	parameter is treated as required. The one exception is "controls", which
	may be missing and is then set to None.


	Parameters
	----------
	parameters: str
		Path of the JSON file holding the user-provided parameters.

	default_parameters: dict
		The default parameters for the operation.


	Returns
	-------
	params: dict
		The parsed JSON contents with the missing defaults added.


	Raises
	------
	ValueError
		If a parameter whose default is None (other than "controls") is not
		in the JSON file.
	"""

	with open(parameters, "r") as handle:
		merged = json.load(handle)

	for key, fallback in default_parameters.items():
		# Anything the user supplied is left untouched.
		if key in merged:
			continue

		required = fallback is None and key != "controls"
		if required:
			raise ValueError("Must provide value for '{}'".format(key))

		merged[key] = fallback

	return merged


# Pull the arguments
args = parser.parse_args()


# Calculate GC-matched negatives
if args.cmd == 'negatives':
	if args.fasta is not None:
		# Imported here because only this branch needs it. Without the import
		# the call below failed with a NameError whenever --fasta was given.
		from bpnetlite.negatives import calculate_gc_genomewide

		# Autosomes chr1-chr22 plus chrX.
		chroms = ['chr{}'.format(i) for i in range(1, 23)] + ['chrX']
		calculate_gc_genomewide(args.fasta, args.bigwig, args.width, chroms, 
			args.verbose)

	# Fix the global numpy seed before extracting the matched loci
	# (presumably so that repeated runs give the same result).
	numpy.random.seed(0)

	# Extract regions that match the GC content of the peaks
	matched_loci = extract_matching_loci(args.peaks, args.bigwig, args.width, 
		args.bin_width, args.verbose)

	# Written as a headerless, tab-separated file without the index column.
	matched_loci.to_csv(args.output, header=False, sep='\t', index=False)

# Fit a BPNet model to data
if args.cmd == "fit":
	parameters = merge_parameters(args.parameters, default_fit_parameters)

	# Training batches are drawn from the training chromosomes.
	training_data = PeakGenerator(
		loci=parameters['loci'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# The validation set is extracted once, without jitter, from the
	# validation chromosomes.
	valid_data = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# extract_loci only returns a controls element when controls were given.
	if parameters['controls'] is not None:
		valid_sequences, valid_signals, valid_controls = valid_data
	else:
		valid_sequences, valid_signals = valid_data
		valid_controls = None

	# Half of the difference between the input and output window widths.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'],
		n_outputs=parameters['n_outputs'],
		n_control_tracks=parameters['n_control_tracks'],
		profile_output_bias=parameters['profile_output_bias'],
		count_output_bias=parameters['count_output_bias'],
		alpha=parameters['alpha'],
		trimming=trimming,
		name=parameters['name'],
		verbose=parameters['verbose']).cuda()

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	model.fit(training_data, optimizer, X_valid=valid_sequences, 
		X_ctl_valid=valid_controls, y_valid=valid_signals, 
		max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size'])

# Make predictions with a trained model
elif args.cmd == 'predict':
	parameters = merge_parameters(args.parameters, default_predict_parameters)

	# NOTE: torch.load unpickles the file, so only load models that come from
	# a trusted source.
	model = torch.load(parameters['model']).cuda()

	# The configured window sizes are passed on explicitly, as in `fit`;
	# previously they were accepted in the JSON file but silently ignored.
	examples = extract_loci(
		sequences=parameters['sequences'],
		controls=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if parameters['controls'] is None:
		X = examples
		# A model with control tracks is given all-zero controls when no
		# control files were provided.
		if model.n_control_tracks > 0:
			X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
		else:
			X_ctl = None
	else:
		X, X_ctl = examples

	y_profiles, y_counts = model.predict(X, X_ctl=X_ctl, 
		batch_size=parameters['batch_size'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['count_filename'], y_counts)

# Calculate attributions with a trained model
elif args.cmd == 'interpret':
	parameters = merge_parameters(args.parameters, default_interpret_parameters)

	# NOTE: torch.load unpickles the file, so only load models that come from
	# a trusted source.
	model = torch.load(parameters['model']).cuda()

	# The configured window sizes are passed on explicitly, as in `fit`;
	# previously they were accepted in the JSON file but silently ignored.
	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# All-zero controls, wrapped in a 1-tuple because they are handed to
	# calculate_attributions through its `args` parameter.
	if model.n_control_tracks > 0:
		X_ctl = (torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1]),)
	else:
		X_ctl = None

	X_attr = calculate_attributions(model, X, args=X_ctl,
		model_output=parameters['output'], hypothetical=True, 
		random_state=parameters['random_state'],
		verbose=parameters['verbose'])

	numpy.savez_compressed(parameters['ohe_filename'], X)
	numpy.savez_compressed(parameters['shap_filename'], X_attr)

# Marginalize motifs
elif args.cmd == 'marginalize':
	parameters = merge_parameters(args.parameters, 
		default_marginalize_parameters)

	# NOTE: torch.load unpickles the file, so only load models that come from
	# a trusted source.
	model = torch.load(parameters['model']).cuda()

	# The configured window sizes are passed on explicitly, as in `fit`;
	# previously they were accepted in the JSON file but silently ignored.
	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		n_loci=parameters['n_loci'],
		verbose=parameters['verbose']
	)

	if parameters['shuffle']:
		# Seed the shuffle from the declared `random_state` parameter, which
		# this command previously accepted but never used.
		rng = numpy.random.RandomState(parameters['random_state'])
		idxs = numpy.arange(X.shape[0])
		rng.shuffle(idxs)
		X = X[idxs]

	if parameters['n_loci'] is not None:
		X = X[:parameters['n_loci']]

	marginalization_report(model, parameters['motifs'], X, 
		parameters['output_filename'])
