#!/usr/bin/env python3

##############################################################################
# deldenoiser is a command line tool to denoise read counts of a
# DNA-encoded experiment.
#
# Copyright (C) 2020  Totient, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License v3.0
# along with this program.
# If not, see <https://www.gnu.org/licenses/gpl-3.0.en.html>.
#
# Developer:
# Peter Komar (peter.komar@totient.bio)
##############################################################################

"""deldenoiser

Command line tool to denoise read counts of a DNA-encoded experiment.

Usage:

    deldenoiser --design <DEL_design.tsv[.gz]>  \
                --yields <yields.tsv[.gz]>  \
                --preselection_readcount <readcounts_pre.tsv[.gz]>  \
                --postselection_readcount <readcounts_post.tsv[.gz]>  \
                --output_prefix <prefix> \
                --regularization_strength <alpha> \
                --dispersion <gamma> \
                --maxiter <maxiter> \
                --inner_maxiter <inner_maxiter> \
                --tolerance <tol> \
                --parallel_processes <processes> \
                --minyield <minyield> \
                --maxyield <maxyield> \
                --F_init <F_init> \
                --max_downsteps <max_downsteps>

Parameters:

    <DEL_design.tsv[.gz]>: tab-separated values that encode the number of synthesis
        cycles and the number of lanes in each cycle, with two columns:
            * "cycle": cycle index (1,2,... cmax)
            * "lanes": number of lanes in the corresponding cycle (>= 1)

    <yields.tsv[.gz]>: tab-separated values that encode the yields of the reactions
        during synthesis, with three columns:
            * "cycle": cycle index (1,2,... cmax)
            * "lane": lane index (1,2, ... lmax[c])
            * "yield": yield of reaction in the corresponding lane
            (real number between 0.0 and 1.0)

    <readcounts_pre.tsv[.gz]>: tab-separated values that encode the read counts
        obtained from sequencing done before the DEL selection steps,
        with cmax + 1 columns:
            * "cycle_1_lane": lane index of cycle 1
            * "cycle_2_lane": lane index of cycle 2
            * ...
            * "cycle_<cmax>_lane": lane index of cycle cmax
            * "readcount": number of reads of the DNA tag that identifies
            the corresponding lane index combination (non-negative integers)

    <readcounts_post.tsv[.gz]>, same structure as <readcounts_pre.tsv>,
        but for reads obtained from sequencing done after
        the DEL selection step.

    <prefix>: string (that can include the path) to name the output files.

    <alpha>: strength of L1 regularization, high alpha results in stronger
        background subtraction. (default = 1.0)

    <gamma>: dispersion parameter of the dispersed Poisson noise,
        ratio of variance and expected value. (default = 1.0)

    <maxiter>: maximum number of coordinate descent iterations while fitting
        the truncate fitness coefficients (default = 20)

    <inner_maxiter>: maximum number of iterations used for each coordinate
        descent step (these are usually Newton-Raphson steps,
        but if those fail, they become Anderson-Bjoerck steps).
        (default = 10)

    <tol>: tolerance, if the intensity due to truncates changes less than this
        between consecutive iterations of coordinate descent, the fitting
        is stopped before reaching maxiter number of iterations
        (default = 0.1)

    <processes>: max number of parallel processes to start
        during fitting truncates (default = number of system CPUs)

    <minyield>: lowest allowed input yield value, yields lower than this
        get censored to this level during preprocessing (default = 1e-10)

    <maxyield>: highest allowed input yield value, yields higher than this
        get censored to this level during preprocessing (default = 0.95)

    <F_init>: initial value for truncate fitness
        (default: internal guess is used)

    <max_downsteps>: max number of allowed iterations when logL is decreasing
        If it is reached, the optimization terminates. (default = 5)


Output files:

    1. <prefix>_fullcycleproducts.tsv.gz: tab-separated values containing the
        results about full-cycle products, each identified by their
        lane index combination. The cmax + 2 columns contain
            * `cycle_<cid>_lane`: lane index of cycle cid = 1,2,... cmax
            * `fitness`: fitness coefficients
            * `clean_reads`: posterior mode of clean reads
        Note: Only records corresponding to non-zero input read counts are
            printed in this file. Compounds with zero observed reads are
            implicitly assumed to have zero fitness, and zero clean reads.

    2. <prefix>_truncates.tsv.gz: tab-separated values encoding the fitness
        coefficients of the truncates, each identified by their extended lane
        index combination. The cmax + 1 columns contain
            * `cycle_<cid>_lane`: extended lane index (which can take 0
            as well, as an indication that the synthesis cycle failed)
            of cycle cid = 1,2,... cmax
            * `fitness`: fitness coefficients of the truncated compounds
        Note: Only records corresponding to truncates that are estimated to
            have non-zero fitness are printed in this file. The truncates
            missing from here should be understood to have zero fitness.

    3. <prefix>_tag_imbalance_factors.tsv.gz: tab-separated values containing
        the estimated tag imbalance factors (bhat) for each cycle and lane.
        It has 3 columns (the same shape as the optional <yields.tsv[.gz]>
        input file):
            * "cycle": cycle index (1,2,... cmax)
            * "lane": lane index (1,2, ... lmax[c])
            * "imbalance_factor": imbalance factor (bhat[cycle][lane])
"""

import os
import sys
import datetime
import argparse
import numpy as np
import pandas as pd
import warnings
from itertools import product
from multiprocessing import cpu_count

from deldenoiser.nullblockmodel import NullBlockModel

# Reaction yield used for every lane when no --yields file is supplied; it is
# also passed to NullBlockModel as its default_yield.
DEFAULT_YIELDS = 0.5
# When True, progress messages are written to stderr at each pipeline stage
# and debug=True is forwarded to NullBlockModel.fit_truncates.
DEBUG = True


def debugging_message(msg):
    """Write ``msg`` to stderr, prefixed with the current local timestamp.

    The line has the form ``YYYY-MM-DD HH:MM:SS : <msg>``.
    """
    now = datetime.datetime.now()
    stamp = now.isoformat(sep=' ', timespec='seconds')
    line = stamp + ' : ' + msg
    print(line, file=sys.stderr)


def preprocess_design_file(design_file):
    """Read and validate the DEL design table.

    Parameters
    ----------
    design_file : str
        Path to a tab-separated file with exactly the columns ``cycle`` and
        ``lanes`` (in this order). The file is read as gzip if its name ends
        in ``.gz``.

    Returns
    -------
    numpy.ndarray
        Integer array with the number of lanes of each cycle, in file order.

    Raises
    ------
    ValueError
        If the columns are not ``["cycle", "lanes"]``, if ``cycle`` is not
        1, 2, ..., n, or if ``lanes`` holds anything but positive integers.
    """
    compression = 'gzip' if design_file.endswith('.gz') else None
    df_design = pd.read_csv(design_file, sep='\t', compression=compression)
    if list(df_design.columns) != ['cycle', 'lanes']:
        raise ValueError('design file must have columns ["cycle", "lanes"]')

    expected_cycles = np.arange(1, len(df_design) + 1, 1)
    if (df_design['cycle'].values != expected_cycles).any():
        raise ValueError('"cycle" column of design file must contain '
                         'consecutive positive integers starting from 1')

    error_message = '"lanes" column must contain positive integers'
    try:
        design = df_design['lanes'].values.astype(int)
    except (TypeError, ValueError) as err:
        # Only conversion failures are translated; a bare ``except`` would
        # also swallow KeyboardInterrupt / SystemExit.
        raise ValueError(error_message) from err
    if (df_design['lanes'] <= 0).any():
        raise ValueError(error_message)
    # A mismatch after the integer cast means the value was fractional or NaN.
    if (design != df_design['lanes'].values).any():
        raise ValueError(error_message)

    return design


def preprocess_output_prefix(prefix):
    """Make sure output files can be created under ``prefix``.

    When ``prefix`` contains a directory part, the directory is created if
    it does not exist yet (a notice is printed to stderr), and a temporary
    ``<prefix>_test.file`` is written and removed to prove writability.
    A prefix without a directory part is returned without any check.

    Parameters
    ----------
    prefix : str
        Output prefix, optionally including a path.

    Returns
    -------
    str
        ``prefix``, unchanged.

    Raises
    ------
    ValueError
        If the directory cannot be created or the probe file cannot be
        written or removed.
    """
    # os.path.dirname recognises every separator the platform accepts.
    directory = os.path.dirname(prefix)
    if directory:
        probe = prefix + '_test.file'
        try:
            if not os.path.isdir(directory):
                print(f'directory {directory} does not exist, it is created',
                      file=sys.stderr)
                os.makedirs(directory)
            with open(probe, 'w'):
                pass
            os.remove(probe)
        except (OSError, ValueError) as err:
            # OSError: filesystem failures; ValueError: malformed paths
            # (e.g. embedded NUL). Anything else is a programming error and
            # should not be masked.
            raise ValueError('cannot write to path specified by prefix') \
                from err

    return prefix


def preprocess_regularization_strength(alpha):
    """Validate the regularization strength and return it unchanged.

    Raises
    ------
    ValueError
        If ``alpha`` is zero, negative or NaN.
    """
    # ``not alpha > 0.0`` rather than ``alpha <= 0.0``: every comparison with
    # NaN is False, so the latter would let NaN through.
    if not alpha > 0.0:
        raise ValueError('regularization strength (alpha) must be positive')
    return alpha


def preprocess_dispersion(gamma):
    """Validate the dispersion parameter and return it unchanged.

    Raises
    ------
    ValueError
        If ``gamma`` is zero, negative or NaN.
    """
    # ``not gamma > 0.0`` rather than ``gamma <= 0.0`` so that NaN (for which
    # every comparison is False) is rejected too.
    if not gamma > 0.0:
        raise ValueError('dispersion (gamma) must be positive')
    return gamma


def preprocess_maxiter(maxiter):
    """Return ``maxiter`` unchanged, raising ``ValueError`` if it is below 1."""
    below_minimum = maxiter < 1
    if below_minimum:
        raise ValueError('maxiter must be positive integer')
    return maxiter

def preprocess_inner_maxiter(inner_maxiter):
    """Return ``inner_maxiter`` unchanged, raising ``ValueError`` if below 1."""
    below_minimum = inner_maxiter < 1
    if below_minimum:
        raise ValueError('inner_maxiter must be positive integer')
    return inner_maxiter

def preprocess_tolerance(tol):
    """Validate the convergence tolerance and return it unchanged.

    Raises
    ------
    ValueError
        If ``tol`` is zero, negative or NaN.
    """
    # ``not tol > 0.0`` rather than ``tol <= 0.0`` so that NaN (for which
    # every comparison is False) is rejected too.
    if not tol > 0.0:
        raise ValueError('tolerance (tol) must be positive')
    return tol


def preprocess_processes(processes):
    """Validate the requested number of parallel processes.

    Values below 1 raise ``ValueError``. A request above the CPU count
    reported by ``cpu_count()`` is accepted but triggers a warning.
    ``processes`` is returned unchanged.
    """
    if processes < 1:
        raise ValueError('processes must be positive integer')

    available = cpu_count()
    oversubscribed = processes > available
    if oversubscribed:
        message = (
            f'parallel_processes ({processes}) is higher '
            f'than the number of system CPUs found ({available}). '
            f'This may not be optimal.'
        )
        warnings.warn(message)
    return processes


def main(args):
    """Run the denoising pipeline and write the three output tables.

    The steps are: validate the arguments, build a ``NullBlockModel``, set
    tag imbalance factors (fitted from pre-selection reads, or uniform), set
    yields (loaded from file, or ``DEFAULT_YIELDS`` everywhere), load the
    post-selection read counts, guess and fit the truncate fitness values,
    fit the full-cycle fitness values, compute clean read counts, and write
    ``<prefix>_fullcycleproducts.tsv.gz``, ``<prefix>_truncates.tsv.gz`` and
    ``<prefix>_tag_imbalance_factors.tsv.gz`` (gzip-compressed TSV, floats
    formatted with 4 decimals).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments. Attributes read: ``output_prefix``,
        ``regularization_strength``, ``dispersion``, ``design``, ``maxiter``,
        ``inner_maxiter``, ``tolerance``, ``parallel_processes``,
        ``preselection_readcount``, ``yields``, ``minyield``, ``maxyield``,
        ``postselection_readcount``, ``max_downsteps`` and ``F_init``.

    Raises
    ------
    ValueError
        If one of the ``preprocess_*`` validators rejects its argument.
    """

    ##########################################################################
    # Step 0: Pre-process data
    ##########################################################################
    if DEBUG:
        debugging_message('Pre-processing input data')
    prefix = preprocess_output_prefix(args.output_prefix)
    alpha = preprocess_regularization_strength(args.regularization_strength)
    gamma = preprocess_dispersion(args.dispersion)
    design = preprocess_design_file(args.design)
    maxiter = preprocess_maxiter(args.maxiter)
    # Use the dedicated validator so that a rejected value is reported under
    # the name of the option the user actually passed.
    inner_maxiter = preprocess_inner_maxiter(args.inner_maxiter)
    tol = preprocess_tolerance(args.tolerance)
    processes = preprocess_processes(args.parallel_processes)

    nbm = NullBlockModel(design, alpha, gamma, default_yield=DEFAULT_YIELDS)

    ##########################################################################
    # Step 1: Fit sequencing bias from pre-selection read counts
    ##########################################################################
    if DEBUG:
        debugging_message('Fitting sequencing bias')
    if args.preselection_readcount is not None:
        nbm.fit_sequence_imbalance(args.preselection_readcount)
    else:
        debugging_message('Assuming uniform imbalance factors')
        # Without pre-selection data every lane of a cycle gets equal weight.
        nbm.bhat = [np.ones(lmax) / lmax for lmax in nbm.design]
        nbm.x = nbm._compute_x()

    ##########################################################################
    # Step 2: Load yields
    ##########################################################################
    if DEBUG:
        debugging_message('Loading yields')
    if args.yields is not None:
        nbm.load_yields(args.yields,
                        float(args.minyield), float(args.maxyield))
    else:
        debugging_message('Assuming uniform yields')
        nbm.yields = [DEFAULT_YIELDS * np.ones(lmax) for lmax in nbm.design]
        nbm.x = nbm._compute_x()

    ##########################################################################
    # Step 3: Load post-selection read count data
    ##########################################################################
    if DEBUG:
        debugging_message('Loading post-selection read count data')
    nbm.load_postselection_readcount(args.postselection_readcount)

    ##########################################################################
    # Step 4 & 5: Guess and fit fitness of truncated compounds
    ##########################################################################
    if DEBUG:
        debugging_message('Guessing initial fitness values of truncates')
    nbm.guess_truncates(processes=processes)

    if DEBUG:
        debugging_message('Fitting fitness values of truncates')
    nbm.fit_truncates(maxiter, inner_maxiter, tol=tol,
                      processes=processes, debug=DEBUG,
                      max_downsteps=args.max_downsteps,
                      F_init=args.F_init)

    ##########################################################################
    # Step 6: Fit fitness of full-cycle compounds
    ##########################################################################
    if DEBUG:
        debugging_message('Fitting fitness of full-cycle compounds')
    F_fullcycle = nbm.fit_fullcycles()

    ##########################################################################
    # Step 7: Compute breakdown of read counts
    ##########################################################################
    if DEBUG:
        debugging_message('Computing read count breakdown')
    N_clean = nbm.compute_clean_readcounts(F_fullcycle)

    ##########################################################################
    # Step 8: Write results to files
    ##########################################################################
    if DEBUG:
        debugging_message('Compiling output tables')
    cmax = len(nbm.design)
    output_kwargs = {
        'compression': 'gzip',
        'sep': '\t',
        'index': False,
        'float_format': '%.4f'
    }

    # Step 8.1
    # Write table <prefix>_fullcycleproducts.tsv.gz
    # Columns:
    # * `cycle_<cid>_lane`: lane index of cycle cid = 1,2,... cmax
    # * `fitness`: fitness coefficients
    # * `clean_reads`: posterior mode of clean reads
    df_fullcycleproducts = pd.DataFrame()
    for c in range(cmax):
        df_fullcycleproducts[f'cycle_{c+1}_lane'] = nbm.r_arr[:, c]
    df_fullcycleproducts['fitness'] = F_fullcycle
    df_fullcycleproducts['clean_reads'] = N_clean
    df_fullcycleproducts.to_csv(f'{prefix}_fullcycleproducts.tsv.gz',
                                **output_kwargs)

    # Step 8.2
    # Write table <prefix>_truncates.tsv.gz
    # Columns:
    # * `cycle_<cid>_lane`: extended lane index
    # * `fitness`: fitness coefficient truncated compounds
    df_truncates = pd.DataFrame()
    for c in range(cmax):
        df_truncates[f'cycle_{c+1}_lane'] = nbm.q_arr[:, c]
    df_truncates['fitness'] = nbm.F
    df_truncates.to_csv(f'{prefix}_truncates.tsv.gz', **output_kwargs)

    # Step 8.3
    # Write table <prefix>_tag_imbalance_factors.tsv.gz
    # Columns:
    # * "cycle": cycle index (1,2,... cmax)
    # * "lane": lane index (1,2, ... lmax[c])
    # * "imbalance_factor": imbalance factor (bhat[cycle][lane])
    df_tagimbalance = pd.DataFrame()
    df_tagimbalance['cycle'] = np.concatenate(
        [[c+1] * nbm.design[c] for c in range(cmax)])
    df_tagimbalance['lane'] = np.concatenate(
        [list(range(1, nbm.design[c]+1)) for c in range(cmax)]
    )
    df_tagimbalance['imbalance_factor'] = np.concatenate(nbm.bhat)
    df_tagimbalance.to_csv(f'{prefix}_tag_imbalance_factors.tsv.gz',
                           **output_kwargs)


# Script entry point: build the command line parser, parse sys.argv and hand
# the resulting namespace to main().
if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description="Command line tool for DEL readcount denoising.")

    # ----- required arguments ------------------------------------------------
    parser.add_argument('--design', help=
    """ TSV file that encodes the number of synthesis
        cycles and the number of lanes in each cycle, with two columns:
            "cycle": cycle index (1,2,... cmax),
            "lanes": number of lanes in the corresponding cycle (must be >= 1)
    """,
                        required=True)

    parser.add_argument('--postselection_readcount', help=
    """ TSV file that encode the read counts
            obtained from sequencing done after the DEL selection steps,
            with cmax + 1 columns:
                "cycle_1_lane": lane index of cycle 1,
                "cycle_2_lane": lane index of cycle 2,
                ...
                "cycle_<cmax>_lane": lane index of cycle <cmax>,
                "readcount": number of reads of the DNA tag that identifies
                    the corresponding lane index combination
                    (non-negative integers)
    """,
                        required=True)

    parser.add_argument('--output_prefix', help=
    """ string (which can include the path) to name the output files.
    """,
                        required=True,
                        type=str)

    # ----- optional model / optimizer settings -------------------------------
    parser.add_argument('--dispersion', help=
    """ dispersion parameter of the dispersed Poisson noise, ratio of variance
        and expected value. (default: 1.0)
    """,
                        required=False,
                        type=float, default=1.0)

    parser.add_argument('--regularization_strength', help=
    """ strength of regularization towards sparse solutions,
        (default: 1.0)
    """,
                        required=False,
                        type=float, default=1.0)

    parser.add_argument('--maxiter', help=
    """ maximum number of coordinate descent iterations
        during fitting truncates. (default: 20)
    """,
                        required=False,
                        type=int, default=20)

    parser.add_argument('--inner_maxiter', help=
    """ maximum number of iterations for each coordinate descent step
        during fitting truncates. (default: 10)
    """,
                        required=False,
                        type=int, default=10)

    parser.add_argument('--tolerance', help=
    """ tolerance, if the intensity due to truncates changes less than this
        between consecutive iterations of coordinate descent, the the fitting
        is stopped, before reaching maxiter number of iterations (default: 0.1)
    """,
                        required=False,
                        type=float, default=0.1)

    parser.add_argument('--parallel_processes', help=
    """ max number of parallel processes to run during fitting truncates
        (default: number of system CPUs)
    """,
                        required=False,
                        type=int, default=cpu_count())

    # ----- optional input files ----------------------------------------------
    parser.add_argument('--yields', help=
    f""" TSV file that encodes the yields of the reactions
        during synthesis, with three columns:
            "cycle": cycle index (1,2,... cmax),
            "lane": lane index (1,2, ...
                [number of lanes in corresponding cycle]),
            "yield": yield of reaction in the corresponding lane
                (real number between 0.0 and 1.0)
        (default: all values = {DEFAULT_YIELDS} )
    """,
                        required=False)

    parser.add_argument('--preselection_readcount', help=
    """ Same structure as POSTSELECTION_READCOUNTS, for read counts obtained
        from sequencing before DEL selection step.
    """,
                        required=False)

    parser.add_argument('--minyield', help=
    """ Lowest allowed yield value, inputs lower than this are censored to
        this level. (default = 1e-10)
    """,
                        type=float, default=1e-10,
                        required=False)

    parser.add_argument('--maxyield', help=
    """ Highest allowed yield value, inputs higher than this are censored to
        this level. (default = 0.95)
    """,
                        type=float, default=0.95,
                        required=False)

    # type=float so that a value given on the command line arrives as a
    # number rather than a str; None (option omitted) is passed through.
    # NOTE(review): assumes fit_truncates expects a numeric F_init - confirm
    # against NullBlockModel.
    parser.add_argument('--F_init', help=
    """ Initial guess of truncate fintess values.
    """,
                        type=float, default=None, required=False)

    parser.add_argument('--max_downsteps', help=
    """ if log-likelihood decreases this number of times throughout 
        the optimization, the optimization is terminated
    """,
                        type=int, default=5,
                        required=False)

    if DEBUG:
        debugging_message('Reading command line arguments')
    args = parser.parse_args()

    main(args)
