#!/usr/bin/python3
# -*- coding: utf-8 -*-

# ------------------------------------------------------------------------------
# 
# Author: Gabriele Girelli
# Email: gigi.ga90@gmail.com
# Date: 2018-12-01
# Description: Design FISH probe set using ODN databases.
# 
# ------------------------------------------------------------------------------



# ==============================================================================

import argparse
from ggc.args import check_threads
import ifpd as fp
from joblib import Parallel, delayed
import logging
import numpy as np
import os
import pandas as pd
import re
import shutil
import sys
from tqdm import tqdm

# Command-line interface -------------------------------------------------------
# Declares every positional and optional argument of the script and parses
# sys.argv into `args`. Value checks on the parsed options are performed later
# in the script, after the output folder and the logger are available.

parser = argparse.ArgumentParser(description = '''
Design a FISH probe set in a genomic region, with the aim of having the probes
as homogeneously spaced as possible. Concisely, the script does the following:
- Identify all probe candidates
- Calculate centrality, size, and homogeneity for each candidate
- Build a set of consecutive (nProbes+1) windows of the same size
	+ Shift the windows set to build additional windows sets. Each windows set
	  will produce one probe set candidate.
- For each window set:
	+ Find the best probe in each window.
	+ Aggregate each window's best probe into a candidate probe set.
- Rank candidate probe sets based on probe homogeneity.
- Return the top N candidates (maxSets), with plots, tables, fasta and bed.
''', formatter_class = argparse.RawDescriptionHelpFormatter)

# Mandatory positional arguments.
parser.add_argument('region', metavar = 'region', type = str,
    help = '''Region of interest in chrN:XXX,YYY format.
    Use chrN only to run chromosome-wide.''')
parser.add_argument('nProbes', metavar = 'nProbes', type = int,
    help = 'Number of probes to design.')
parser.add_argument('database', metavar = 'database', type = str,
    help = 'Path to database folder.')
parser.add_argument('outdir', metavar = 'outputDirectory', type = str,
    help = 'Path to query output directory. Stops if it exists already.')

# Optional parameters. `featureList` is also used further down in the script to
# validate the values passed to --order.
featureList = ['size', 'homogeneity', 'centrality']
parser.add_argument('--order', metavar = 'featOrder', type = str,
    default = featureList, nargs = '+',
    help = '''Space-separated features, used as explained in script description.
    The available features are: centrality, size, and homogeneity. At least 2
    features must be listed. Default: "size homogeneity centrality".''')
parser.add_argument('--filter-thr', metavar = 'filterThr', type = float,
    default = .1, help = '''Threshold of first feature filter, used to identify
    a range around the best value (percentage range around it). Accepts values
    from 0 to 1. Default: 0.1''')
# The help text used to advertise "Default: 1" while the actual default is 0.
parser.add_argument('--min-d', metavar = 'minD', type = int, default = 0,
    help = '*DEPRECATED* Minimum distance between consecutive oligos. Default: 0')
parser.add_argument('--n-oligo', metavar = 'nOligo', type = int,
    default = 48, help = 'Number of oligos per probe. Default: 48')
parser.add_argument('--max-sets', metavar = 'maxProbes', type = int,
    default = -1, help = '''Maximum number of probe set candidates to output.
        Set to -1 to retrieve all candidates. Default: -1''')
parser.add_argument('--window-shift', metavar = 'winShift', type = float,
    default = .1, help = '''Window fraction for windows shifting.
    Default: 0.1''')
parser.add_argument('-t', '--threads', metavar = "nthreads", type = int,
    help = """Number of threads for parallelization. Default: 1""",
    default = 1)

# Flags. `store_true`/`store_false` give the same defaults and stored values as
# the previous explicit `store_const` declarations.
parser.add_argument('-f', action = 'store_true', dest = 'forceRun',
    help = '''Force overwriting of the query if already run.
        This is potentially dangerous.''')
parser.add_argument('--no-net', action = 'store_false', dest = 'hasNetwork',
        help = '''Use when internet is not available. Then, chromosome size
            defaults to the end of the last oligo.''')

# Version flag.
version = "0.0.1"
parser.add_argument('--version', action = 'version',
    version = f'{sys.argv[0]} v{version}')

args = parser.parse_args()

# Output folder ----------------------------------------------------------------
# The output path must not be a file. An existing folder is only tolerated with
# -f, in which case it is deleted and re-created from scratch.

assert_msg = f'output folder expected, file found: {args.outdir}'
assert not os.path.isfile(args.outdir), assert_msg

# Record whether a previous run exists *before* touching the file system: once
# the folder is re-created below, its presence no longer says anything about a
# previous run.
overwriting = os.path.isdir(args.outdir)
if args.forceRun:
    if overwriting:
        shutil.rmtree(args.outdir)
else:
    assert_msg = f'output folder already exists: {args.outdir}'
    assert not overwriting, assert_msg
os.mkdir(args.outdir)

# Logging ----------------------------------------------------------------------
# Messages go both to '<outdir>/log' and to a console stream handler, using the
# same format. The log file lives in the output folder, which is why logging
# can only be configured after the folder has been created.

formatString = '%(asctime)-25s %(name)s %(levelname)-8s %(message)s'
formatter = logging.Formatter(formatString)
logging.basicConfig(
    filename = os.path.join(args.outdir, 'log'),
    level = logging.INFO,
    format = formatString)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

# Only report an overwrite when a previous output folder was actually removed.
if args.forceRun and overwriting:
    logging.info('Overwriting previously run query.')

# Region of interest -----------------------------------------------------------
# Validates the region string, opens the oligo database and resolves the region
# into the (chrom, chromStart, chromEnd) triple stored in `queried_region`.

roi_regexp = r'^chr[0-9A-Za-z]+(:[0-9]+,[0-9]+)?$'
assert_msg = 'the provided region does not match the expected pattern:'
assert_msg += f' "{args.region}" [chrN:XXX,YYY|chrN]'
assert re.match(roi_regexp, args.region) is not None, assert_msg

logging.info("Reading database...")
oligoDB = fp.query.OligoDatabase(args.database, False)

if ":" in args.region:
    # chrN:XXX,YYY -> explicit boundaries. The pattern above guarantees exactly
    # one ':' and one ',' at this point.
    chrom, coordinates = args.region.split(":")
    chromStart, chromEnd = [int(x) for x in coordinates.split(',')]
    assert_msg = 'region start must be lower than region end: '
    assert_msg += f'{chromStart} >= {chromEnd}'
    assert chromStart < chromEnd, assert_msg
else:
    # chrN -> chromosome-wide query; the end coordinate has to be looked up.
    chrom = args.region
    chromStart = 0
    chromEnd = None
    if args.hasNetwork:
        # NOTE(review): `internet_on` looks like a connectivity-check function.
        # It used to be tested without being called, which is always truthy for
        # a function object. Call it when it is callable; otherwise use its
        # value as-is. Confirm against the ifpd.web API.
        hasConnection = fp.web.internet_on
        if callable(hasConnection):
            hasConnection = hasConnection()
        if hasConnection:
            chromEnd = fp.web.get_chromosome_size(
                chrom, oligoDB.get_reference_genome())
    if chromEnd is None:
        # Offline fallback: use the largest value found in the second column of
        # the chromosome table as chromosome end.
        assert_msg = f'chromosome "{chrom}" not in the database.'
        assert oligoDB.has_chromosome(chrom), assert_msg

        oligoDB.read_chromosome(chrom)
        chromEnd = oligoDB.chromData[chrom].iloc[:, 1].max()
queried_region = (chrom, chromStart, chromEnd)

# Option validation ------------------------------------------------------------
# Checks the database and the parsed options, stopping with an AssertionError
# carrying an explanatory message at the first violation.

assert_msg = 'databases with overlapping oligos are not supported yet.'
assert not oligoDB.has_overlaps(), assert_msg

# --order: at least two features, all of them known.
assert_msg = f'at least 2 features needed, only {len(args.order)} found.'
assert 2 <= len(args.order), assert_msg
for o in args.order:
    assert_msg = f'unrecognized feature "{o}". Should be one of {featureList}.'
    assert o in featureList, assert_msg

# --filter-thr lies in [0, 1]; --window-shift lies in (0, 1].
assert_msg = f'first filter threshold must be a fraction: {args.filter_thr}'
assert 0 <= args.filter_thr <= 1, assert_msg
assert_msg = f'window shift must be a fraction: {args.window_shift}'
assert 0 < args.window_shift <= 1, assert_msg

assert_msg = 'negative minimum distance between consecutive oligos: '
assert_msg += f'{args.min_d}'
assert args.min_d >= 0, assert_msg

assert args.n_oligo >= 1, f'a probe must have oligos: {args.n_oligo}'

# --max-sets: -1 means "no limit" and is stored as infinity. Any other value
# must be at least 1, as the message states (0 used to slip through).
if args.max_sets == -1:
    args.max_sets = np.inf
assert_msg = f'at least 1 probe set in output: {args.max_sets}'
assert args.max_sets >= 1, assert_msg

# Oligo selection --------------------------------------------------------------
# Loads the chromosome table if the database has not read it yet, then keeps
# the rows whose first column is >= chromStart and whose second column is
# <= chromEnd (presumably oligo start and end coordinates - confirm against the
# database format).

if chrom not in oligoDB.chromData.keys():
    oligoDB.read_chromosome(chrom)
chromData = oligoDB.chromData[chrom]

startsAfterRegionStart = chromData.iloc[:, 0] >= chromStart
endsBeforeRegionEnd = chromData.iloc[:, 1] <= chromEnd
selectCondition = np.logical_and(startsAfterRegionStart, endsBeforeRegionEnd)
selectedOligos = chromData.loc[selectCondition, :]

# A probe needs args.n_oligo oligos: stop if the region holds fewer than that.
nSelected = selectCondition.sum()
assert_msg = ('there are not enough oligos in the database.'
    f' Asked for {args.n_oligo}, {nSelected} found.')
assert args.n_oligo <= nSelected, assert_msg
logging.info(f'Found {nSelected} oligos in {args.region}')

# Sequence retrieval -----------------------------------------------------------
# When the oligo table has fewer than 3 columns, download the sequence of every
# selected oligo from UCSC and store it in a new 'name' column.
#
# Fixes over the previous version:
# - the reference genome came from an undefined `config` object (NameError); it
#   is now requested from the database, as done for the chromosome size lookup;
# - rows were addressed with `.iloc` using index *labels*; they are now visited
#   by position;
# - `DataFrame.assign` returns a new table, which used to be discarded.
if 3 > selectedOligos.shape[1]:
    logging.info("Retrieving sequences from UCSC...")
    refGenome = oligoDB.get_reference_genome()
    sequences = []
    for oligoPosition in tqdm(range(selectedOligos.shape[0])):
        # The first two columns of the row are forwarded as region boundaries.
        region = selectedOligos.iloc[oligoPosition, :2].tolist()
        regionSequence = fp.web.get_sequence_from_UCSC(
            refGenome, chrom, *region)
        sequences.append(regionSequence)
    selectedOligos = selectedOligos.assign(name = sequences)

# Probe candidates -------------------------------------------------------------
# Every run of args.n_oligo consecutive selected oligos becomes one probe
# candidate (a sliding window with step 1 over the oligo table).

logging.info("Building probe candidates...")

args.threads = check_threads(args.threads)

# N rows contain N - n_oligo + 1 windows of n_oligo rows. The "+ 1" was missing,
# so the last window was never built and no candidate at all was produced when
# the region held exactly n_oligo oligos.
nCandidates = selectedOligos.shape[0] - args.n_oligo + 1

if 1 != args.threads:
    candidateList = Parallel(n_jobs = args.threads,
        backend = 'threading', verbose = 1)(
        delayed(fp.query.OligoProbe)(queried_region[0],
            selectedOligos.iloc[ii:(ii+args.n_oligo), :], oligoDB)
        for ii in range(nCandidates))
else:
    candidateList = [
        fp.query.OligoProbe(queried_region[0],
            selectedOligos.iloc[i:(i+args.n_oligo), :], oligoDB)
        for i in tqdm(range(nCandidates))]
logging.info(f"Found {len(candidateList)} probe candidates.")

# Candidate description --------------------------------------------------------
# Builds the feature table of the probe candidates and writes it, tab-separated,
# to '<outdir>/probe_candidates.tsv'.

logging.info("Describing candidates...")
probeFeatureTable = fp.query.ProbeFeatureTable(
    candidateList, queried_region, True, args.threads)

logging.info("Writing description table...")
# `sep` is passed by keyword: recent pandas releases deprecate (and eventually
# reject) positional arguments other than the output path in `to_csv`.
probeFeatureTable.data.to_csv(
    os.path.join(args.outdir, "probe_candidates.tsv"),
    sep = "\t", index = False)

# Stop unless there are more candidates than requested probes.
assert_msg  =  'not enough probes in the region of interest: '
assert_msg += f'{probeFeatureTable.data.shape[0]}/{args.nProbes}'
assert args.nProbes < probeFeatureTable.data.shape[0], assert_msg

# Window sets ------------------------------------------------------------------
# Tiles the queried region with equally sized windows, then derives additional
# window sets by shifting the first one in steps of `window_shift`.

logging.info("Building window sets...")
window_set = fp.query.GenomicWindowList(None)

# The region is divided by (nProbes + 1) and the result truncated to an integer.
regionSize = chromEnd - chromStart
window_size = int(regionSize / (args.nProbes + 1))

# The trailing start position is always dropped; one more is dropped when the
# region size is not a multiple of the window size.
skip = 2 if 0 != regionSize % window_size else 1
windowStarts = range(chromStart, chromEnd, window_size)[:-skip]
for windowStart in windowStarts:
    window_set.add(chrom, windowStart, window_size)

# Shift step: the requested window fraction, but never less than the minimum
# oligo distance plus the first value of the database oligo length range.
fractionShift = int(args.window_shift * window_size)
minimumShift = args.min_d + oligoDB.get_oligo_length_range()[0]
window_shift = max(fractionShift, minimumShift)

window_setList = [window_set.shift(shift)
    for shift in range(0, window_size, window_shift)]
logging.info(f' Built {len(window_setList)} window sets.')

# Best probe per window --------------------------------------------------------
# For every window of every window set:
#   1. reset the feature table and flag the candidates whose chromStart/chromEnd
#      fall inside the window;
#   2. no candidate: the window gets no probe;
#   3. otherwise restrict the table to the flagged candidates and, when more
#      than one is left, filter on the first feature of --order and rank on the
#      second;
#   4. assign the candidate in the first row of the table to the window.
#
# The table used to be restricted only when several candidates were found, so a
# window with exactly one candidate received the first row of the *whole* table
# instead of its own candidate.
for wsi in range(len(window_setList)):
    window_set = window_setList[wsi]
    for wi in range(len(window_set)):
        window = window_set[wi]
        probeFeatureTable.reset()

        selectCondition = np.logical_and(
            probeFeatureTable.data.loc[:, 'chromStart'] >= window.chromStart,
            probeFeatureTable.data.loc[:, 'chromEnd'] <= window.chromEnd)
        nCandidatesInWindow = selectCondition.sum()
        logging.info(f'Found {nCandidatesInWindow} probe candidates' +
            f' in window #{wi} of set #{wsi}.')

        if 0 == nCandidatesInWindow:
            window_setList[wsi][wi].probe = None
            continue

        # Always narrow the table down to the window's own candidates.
        probeFeatureTable.keep(selectCondition, cumulative = True)

        if 1 != nCandidatesInWindow:
            feature_range, feature = probeFeatureTable.filter(
                args.order[0], args.filter_thr, cumulative = True)
            feature_range = np.round(feature_range, 6)
            logging.info(f" Selected {probeFeatureTable.data.shape[0]}" +
                f" candidates in the range {feature_range} of '{feature}'.")
            logging.info(f" Ranking probe candidates based on" +
                f" '{args.order[1]}'.")
            probeFeatureTable.rank(args.order[1])

        # The index of the table is used as position in `candidateList`.
        window_setList[wsi][wi].probe = candidateList[
            probeFeatureTable.data.index[0]]

# Probe set ranking ------------------------------------------------------------
# Scores every window set (one probe set candidate each), ranks the candidates
# by number of probes and then by homogeneity (both descending), writes the
# ranking to '<outdir>/set_candidates.tsv' and reorders `window_setList`
# accordingly.

logging.info("Comparing probe set candidates...")
probeSetSpread = np.array([ws.calc_probe_size_and_homogeneity()
    for ws in window_setList])
probeCount = np.array([ws.count_probes() for ws in window_setList])
logging.info(f" {sum(probeCount == args.nProbes)}/{len(window_setList)}" +
    f" probe set candidates have {args.nProbes} probes.")

logging.info("Ranking based on #probes and homogeneity (of probes and size)...")
# Row i describes window set i. The score columns used to be pre-sorted by
# probe count while 'id' was left as 0..N-1, so ids no longer matched their
# scores and the window sets were reordered wrongly. `sort_values` below does
# all the ordering while keeping each id attached to its scores.
probeSetData = pd.DataFrame.from_dict({
    'id' : range(len(window_setList)),
    'homogeneity' : probeSetSpread,
    'nProbes' : probeCount
})
probeSetData.sort_values(by = ['nProbes', 'homogeneity'],
    ascending = [False, False], inplace = True)
# `sep` is passed by keyword: recent pandas releases deprecate (and eventually
# reject) positional arguments other than the output path in `to_csv`.
probeSetData.drop("id", axis = 1).to_csv(
    os.path.join(args.outdir, "set_candidates.tsv"),
    sep = "\t", index = False)

logging.info("Exporting probe set candidates...")
# Apply --max-sets, which used to be validated but never enforced. The option
# is stored as infinity when no limit was requested. The table written above
# still lists every candidate; only the sets to export are limited.
rankedIds = probeSetData['id'].values
if np.isfinite(args.max_sets):
    rankedIds = rankedIds[:int(args.max_sets)]
window_setList = [window_setList[i] for i in rankedIds]

def export_window_set(window_setList, wsi):
    """Export one window set to its own folder inside the output directory.

    Creates '<outdir>/probe_set_<wsi>', calls the window set's `export` method
    on that folder and writes the two strings it returns to
    'probe_set_<wsi>.fa' and 'probe_set_<wsi>.bed' inside it.

    Reads the module-level `args` (for `outdir`) and `queried_region`.

    Args:
        window_setList: indexable collection of window sets.
        wsi: position of the window set to export; also used to name the
            folder and the files.

    Raises:
        AssertionError: if the target folder path already exists, either as a
            file or as a folder.
    """
    window_set = window_setList[wsi]

    set_label = f"probe_set_{wsi}"
    set_folder = os.path.join(args.outdir, set_label)
    assert not os.path.isfile(set_folder)
    assert not os.path.isdir(set_folder)
    os.mkdir(set_folder)

    fasta, bed = window_set.export(set_folder, queried_region)

    # Same base name for both files, only the extension changes.
    for extension, content in (('fa', fasta), ('bed', bed)):
        file_path = os.path.join(set_folder, f'{set_label}.{extension}')
        with open(file_path, 'w+') as OH:
            OH.write(content)

# Export every ranked window set: a plain loop with a progress bar when a
# single thread is requested, joblib workers otherwise.
if 1 == args.threads:
    for wsi in tqdm(range(len(window_setList))):
        export_window_set(window_setList, wsi)
else:
    exportJobs = (delayed(export_window_set)(window_setList, wsi)
        for wsi in range(len(window_setList)))
    tmp = Parallel(n_jobs = args.threads, verbose = 1)(exportJobs)

# END ==========================================================================

################################################################################
