#!/usr/bin/env python
# tomtom-lite command line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import os
import math
import numpy
import pandas
import pyfaidx
import argparse

from memelite.io import read_meme
from memelite.utils import characters
from memelite.utils import one_hot_encode

from memelite import tomtom


def _check_download_targets(targets):
	"""Return the target database path, downloading JASPAR if none was given.

	Parameters
	----------
	targets: str or None
		The filename of a MEME file supplied by the user, or None to fall back
		on the JASPAR2024 CORE non-redundant database.

	Returns
	-------
	targets: str
		The value passed in when it is not None. Otherwise, the path of the
		JASPAR MEME file stored in the same directory as this script, which is
		downloaded with `wget` the first time it is needed.

	Raises
	------
	RuntimeError
		If the default database has to be downloaded and the download fails.
	"""

	if targets is not None:
		return targets

	# Imported locally because this is the only place the module is needed.
	import subprocess

	url = "https://jaspar.elixir.no/download/data/2024/CORE/JASPAR2024_CORE_non-redundant_pfms_meme.txt"

	# Resolve the directory of this script explicitly. Stripping the script
	# name with a string replace would also mangle any other occurrence of
	# those characters in the path (e.g. a directory called "little").
	script_dir = os.path.dirname(os.path.abspath(__file__))
	targets = os.path.join(script_dir,
		"JASPAR2024_CORE_non-redundant_pfms_jaspar.meme")

	if not os.path.isfile(targets):
		print("Downloading {}...".format(url))

		# Download to a temporary name and rename on success so that a failed
		# download never leaves an empty file that later looks like a valid
		# cached database.
		partial = targets + ".part"

		try:
			# An argument list (no shell) keeps paths containing spaces or
			# shell metacharacters intact.
			subprocess.run(["wget", "-O", partial, url], check=True)
			os.replace(partial, targets)
		except (OSError, subprocess.CalledProcessError) as err:
			if os.path.isfile(partial):
				os.remove(partial)

			raise RuntimeError("Could not download the default JASPAR "
				"database from {}: {}".format(url, err)) from err

	return targets


def _render_alignment(seq, query_seq, n_query, offset, overlap, max_offset):
	"""Render one target sequence aligned against its query for display.

	The target is split into the part before the aligned span, the aligned span
	itself, and the part after it. The two flanks are lower-cased, the aligned
	span is wrapped in `.` characters and padded with `-` where the query hangs
	over the target, and characters in the span that differ from the query at
	the same position are lower-cased as well.
	"""

	seq = str(seq)

	if offset >= 0:
		# Left-pad so that the aligned spans of all rows start in one column.
		left = seq[:offset].rjust(max_offset)

		if overlap == n_query:
			middle = seq[offset:offset + overlap]
			right = seq[offset + overlap:]
		else:
			middle = seq[offset:] + '-' * (n_query - overlap)
			right = ''

	else:
		left = ' ' * max_offset
		middle = '-' * -offset + seq[:overlap]
		right = seq[overlap:]

	middle = ''.join(c if c == c0 else c.lower()
		for c, c0 in zip(middle, query_seq))

	return left.lower() + '.' + middle + '.' + right.lower()


def _run_tomtom(args):
	"""An internal function for running TOMTOM from the command-line.

	This function handles two of the more standard use-cases for Tomtom: where you
	have two MEME-formatted PWM files and want to compare them, and when you have
	a single query. In both cases, you are mapping a query against a target
	database and each row of the output should be a hit in the Tomtom formatted
	file. Hits are written to stdout as tab-separated rows, preceded by a header
	line, in order of increasing p-value.

	Parameters
	----------
	args: argparse.Namespace
		The parsed command-line options. Reads `targets`, `query`, `thresh`,
		`n_nearest`, `n_score_bins`, `n_median_bins`, `n_target_bins`,
		`n_cache`, `norc` and `n_jobs`, and overwrites `targets` with the
		resolved database path. `query` is treated as a MEME filename when it
		names an existing file and as a single motif string otherwise.
	"""

	args.targets = _check_download_targets(args.targets)

	targets = read_meme(args.targets)
	target_names = numpy.array(list(targets.keys()))
	target_pwms = list(targets.values())
	target_seqs = numpy.array([characters(x, force=True) for x in target_pwms])

	if os.path.isfile(args.query):
		queries = read_meme(args.query)
		query_names = list(queries.keys())
		query_pwms = list(queries.values())

	else:
		query_names = ['.']
		query_pwms = [one_hot_encode(args.query)]

	query_seqs = numpy.array([characters(x, force=True) for x in query_pwms])

	results = tomtom(query_pwms, target_pwms,
		n_nearest=args.n_nearest, n_score_bins=args.n_score_bins,
		n_median_bins=args.n_median_bins, n_target_bins=args.n_target_bins,
		n_cache=args.n_cache, reverse_complement=not args.norc,
		n_jobs=args.n_jobs)

	# tomtom is unpacked into six values elsewhere in this file when n_nearest
	# is given; the extra array holds the target index behind each returned
	# column. Accept both forms so that passing -n does not break the unpack.
	p, scores, offsets, overlaps, strands = results[:5]
	idxs = results[5] if len(results) > 5 else None

	# Each hit: (query name, query sequence, query length, target name,
	# target sequence, p-value, score, offset, overlap, strand).
	hits = []
	for qidx, col in zip(*numpy.where(p <= args.thresh)):
		tidx = col if idxs is None else int(idxs[qidx, col])

		hits.append((
			query_names[qidx],
			query_seqs[qidx],
			query_pwms[qidx].shape[-1],
			target_names[tidx],
			target_seqs[tidx],
			p[qidx, col],
			int(scores[qidx, col]),
			int(offsets[qidx, col]),
			int(overlaps[qidx, col]),
			'+-'[int(strands[qidx, col])]
		))

	print("Query Name\tQuery Sequence\tTarget Name\tTarget Sequence\tp-value"
		"\tScore\tOffset\tOverlap\tStrand")

	# Nothing passed the threshold: the header alone is the complete output,
	# and the column widths below cannot be computed from empty lists.
	if not hits:
		return

	max_q_name_len = max(len(hit[0]) for hit in hits)
	max_q_seq_len = max(len(hit[1]) for hit in hits) + 2

	max_t_name_len = max(len(hit[3]) for hit in hits)
	max_t_seq_len = max(len(hit[4]) for hit in hits) + 2

	# Clamp at zero so that all-negative offsets cannot produce a negative
	# (and therefore invalid) column width in the format string.
	max_offset = max(0, max(hit[7] for hit in hits))

	str_format = ("{:" + str(max_q_name_len) + "}\t{:" + str(max_q_seq_len)
		+ "}\t{:" + str(max_t_name_len) + "}\t{:" +
		str(max_t_seq_len + max_offset) + "}\t{:.8}\t{:5}\t{:6}\t{:7}\t{:6}")

	for i in numpy.argsort([hit[5] for hit in hits]):
		(q_name, q_seq, n_query, t_name, t_seq, p_value, score, offset,
			overlap, strand) = hits[i]

		# Use the length of the query that produced this hit, not the first
		# query, so that MEME files with motifs of mixed lengths align properly.
		seq = _render_alignment(t_seq, q_seq, n_query, offset, overlap,
			max_offset)

		print(str_format.format(q_name, q_seq, t_name, seq, p_value, score,
			offset, overlap, strand))


def _run_annotate(args):
	"""An internal function for using tomtom-lite to annotate coordinates.

	This function takes in a BED file, a FASTA file, and optionally a target
	database, extracts the discrete sequence from the FASTA file based on the
	coordinates in the BED file, and maps these discrete sequences to the PWM
	database using tomtomlite. The output is a BED file with additional columns:
	the name of the best-matching target and the negative natural log of its
	p-value (`inf` when the p-value is not positive).

	Parameters
	----------
	args: argparse.Namespace
		The parsed command-line options. Reads `bed`, `fasta`, `targets`,
		`n_score_bins`, `n_median_bins`, `n_target_bins`, `n_cache`, `norc` and
		`n_jobs`, and overwrites `targets` with the resolved database path.
	"""

	# Load BED file of coordinates. The chromosome column is forced to be text:
	# left to type inference, numeric chromosome names ("1", "2", ...) would be
	# parsed as integers and then no longer match the FASTA record names.
	names = 'chrom', 'start', 'end'
	df = pandas.read_csv(args.bed, sep="\t", usecols=(0, 1, 2), header=None,
		names=names, dtype={'chrom': str})

	# Load FASTA file of sequences
	fa = pyfaidx.Fasta(args.fasta)

	# Load target database
	args.targets = _check_download_targets(args.targets)

	targets = read_meme(args.targets)
	target_names = numpy.array(list(targets.keys()))
	target_pwms = list(targets.values())

	# Materialize the coordinates once as plain Python values so that the same
	# rows drive both the sequence extraction and the output.
	coords = [(chrom, int(start), int(end)) for chrom, start, end in
		zip(df['chrom'], df['start'], df['end'])]

	# Extract discrete sequences
	seqs = [one_hot_encode(fa[chrom][start:end].seq.upper())
		for chrom, start, end in coords]

	# Run tomtom-lite, keeping only the single nearest target for each region
	p, scores, offsets, overlaps, strands, idxs = tomtom(seqs, target_pwms,
		n_nearest=1, n_score_bins=args.n_score_bins,
		n_median_bins=args.n_median_bins, n_target_bins=args.n_target_bins,
		n_cache=args.n_cache, reverse_complement=not args.norc,
		n_jobs=args.n_jobs)

	# Column widths for aligned output
	max_chrom = max([len(key) for key in fa.keys()])
	max_start = len(str(int(df['start'].max())))
	max_end = len(str(int(df['end'].max())))
	max_motif = max([len(name) for name in target_names])

	fmt = ('{:' + str(max_chrom) +
		   '}\t{:' + str(max_start) +
		   '}\t{:' + str(max_end) +
		   '}\t{:' + str(max_motif) +
		   '}\t{:6.6}')

	# Save the scores
	for i, (chrom, start, end) in enumerate(coords):
		print(fmt.format(
			chrom,
			start,
			end,
			target_names[int(idxs[i, 0])],
			-math.log(p[i, 0]) if p[i, 0] > 0 else float("inf")
		))


###
# TOMTOM command-line options
###


# Description shown at the top of `--help`. argparse re-wraps this text, so the
# line breaks and indentation inside the literal do not appear in the output.
desc = """tomtom is a method for determining whether the similarity between
		two PWMs is statistically significant and tomtom-lite is a re-implementation
		of this algorithm to significantly speed it up. A common usage of Tomtom is 
		to map a discrete sequence to a motif database to determine which motif it
		most resembles. This tool allows users to easily annotate individual
		discrete sequences from the command-line, or to bulk-annotate coordinates."""


parser = argparse.ArgumentParser(description=desc)

# Core Arguments: The target motif database and the p-value threshold
parser.add_argument("-t", "--targets", type=str,
	help="""The filename of a MEME file. By default, will download and use the
	JASPAR2024 Core non-redundant MEME file.""")

parser.add_argument("-p", "--thresh", type=float, default=0.01,
	help="""The p-value threshold for returning matches.""")


# Option 1: Pass in a string query or a database and get output to stdout
parser.add_argument("-q", "--query", type=str,
	help="""Either the filename of a MEME file ending in `.meme` or the string 
	of a single motif.""")


# Option 2: Pass in a FASTA and BED file and get output to stdout
parser.add_argument("-f", "--fasta", type=str,
	help="A filename of the FASTA file corresponding to the BED file.")

parser.add_argument("-b", "--bed", type=str,
	help="A BED file of coordinates whose sequences should be extracted and annotated.")


# Additional optional arguments related to TOMTOM. These are forwarded as the
# keyword arguments of the same names when tomtom is called.
parser.add_argument("-n", "--n_nearest", type=int, default=None,
	help="""The number of nearest targets to return for each query.""")
parser.add_argument("-s", "--n_score_bins", type=int, default=100,
	help="""The number of query-target score bins to use. `t` in the paper.""")
parser.add_argument("-m", "--n_median_bins", type=int, default=1000,
	help="""The number of bins to use for approximate median calculations.""")
parser.add_argument("-a", "--n_target_bins", type=int, default=100,
	help="""The number of bins to use for approximate target hashing.""")
parser.add_argument("-c", "--n_cache", type=int, default=100,
	help="""The amount of cache to provide for calculations.""")
# Stored as `norc`; the runners pass `reverse_complement=not args.norc`.
parser.add_argument("-r", "--norc", action='store_true',
	help="""Whether to not score reverse complements.""")
parser.add_argument("-j", "--n_jobs", type=int, default=-1,
	help="""The number of threads to use for processing queries.""")


##############
# Run appropriate command
##############

args = parser.parse_args()

# A query takes precedence: when one is given, the FASTA/BED options are
# ignored and the query is compared against the target database.
if args.query is not None:
	_run_tomtom(args)

# Annotation mode needs both the coordinates and the sequences they refer to.
elif args.fasta is not None and args.bed is not None:
	_run_annotate(args)

# Report a usage problem the way argparse does (usage line plus message on
# stderr and a non-zero exit status) instead of a Python traceback.
else:
	parser.error("Must provide either a query or a BED and FASTA file.")
