#!python

##############################################################################
# deldenoiser is a command line tool to denoise read counts of a
# DNA-encoded experiment.
#
# Copyright (C) 2020  Totient, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License v3.0
# along with this program.
# If not, see <https://www.gnu.org/licenses/gpl-3.0.en.html>.
#
# Developer:
# Peter Komar (peter.komar@totient.bio)
##############################################################################

"""deldenoiser

Command line tool to denoise read counts of a DNA-encoded experiment.

Usage:

    deldenoiser --design <DEL_design.tsv>  \
                --yields <yields.tsv>  \
                --preselection_readcount <readcounts_pre.tsv>  \
                --postselection_readcount <readcounts_post.tsv>  \
                --output_prefix <prefix> \
                --regularization_strength <alpha> \
                --dispersion <gamma> \
                --maxiter <maxiter> \
                --tolerance <tol> \
                --parallel_processes <processes>

Parameters:

    <DEL_design.tsv>: tab-separated values that encode the number of synthesis
        cycles and the number of lanes in each cycle, with two columns:
            * "cycle": cycle index (1,2,... cmax)
            * "lanes": number of lanes in the corresponding cycle (>= 1)

    <yields.tsv>: tab-separated values that encode the yields of the reactions
        during synthesis, with three columns:
            * "cycle": cycle index (1,2,... cmax)
            * "lane": lane index (1,2, ...
            [number of lanes in the corresponding cycle])
            * "yield": yield of reaction in the corresponding lane
            (real number between 0.0 and 1.0)

    <readcounts_pre.tsv>: tab-separated values that encode the read counts
        obtained from sequencing done before the DEL selection steps,
        with cmax + 1 columns:
            * "cycle_1_lane": lane index of cycle 1
            * "cycle_2_lane": lane index of cycle 2
            * ...
            * "cycle_<cmax>_lane": lane index of cycle cmax
            * "readcount": number of reads of the DNA tag that identifies
            the corresponding lane index combination (non-negative integers)

    <readcounts_post.tsv>, same structure as <readcounts_pre.tsv>,
        but for reads obtained from sequencing done after
        the DEL selection step.

    <prefix>: string (that can include the path) to name the output files.

    <alpha>: strength of L1 regularization, high alpha results in stronger
        background subtraction. (default = 1.0)

    <gamma>: dispersion parameter of the dispersed Poisson noise,
        ratio of variance and expected value. (default = 1.0)

    <maxiter>: maximum number of coordinate descent iterations during fitting
        truncate fitness coefficients (default = 20)

    <tol>: tolerance, if the intensity due to truncates changes less than this
        between consecutive iterations of coordinate descent, the fitting
        is stopped before reaching maxiter number of iterations
        (default = 0.1)

    <processes>: max number of parallel processes to start
        during fitting truncates (default = number of system CPUs)

Output files:

    1. <prefix>_inventory.tsv: tab-separated values with the same
        set of cycle lane index columns as the readcount input files,
        but instead of the readcounts this file contains one column per
        success pattern, taken from the inventory matrix of the model:
            * "cycle_<cid>_lane": lane index of cycle cid = 1,2,... cmax
            * "fraction_of_<pattern>": inventory matrix entry of the
            corresponding lane index combination and success pattern,
            which is a string of cmax "0" or "1" characters

    2. <prefix>_count_breakdown.tsv: tab-separated values with the
        same set of cycle lane index columns as the readcount input files,
        plus a breakdown of the read counts by success patterns:
            * `cycle_<cid>_lane`: lane index of cycle cid = 1,2,... cmax
            * `readcount_<pattern>`: fractional read count associated with
            lane combination and the success pattern, which is a string of
            cmax "0" or "1" characters. (E.g. if cmax = 3, one of the columns
            will be `readcount_010`, which will contain the read counts for
            the truncates due to the failure of both the first and the
            third cycle.)

    3. <prefix>_fitness.tsv: tab-separated values with the same set of cycle
        lane index columns as the readcount input files, and three columns
        summarizing the fitted fitness coefficients of the legitimate product
        associated with the lane combination.
            * `cycle_<cid>_lane`: lane index of cycle cid = 1,2,... cmax
            * `fitness_mode`: "mode" of the fitted fitness coefficient
            * `fitness_mean`: "mean" of the fitted fitness coefficient
            * `fitness_minimum_stdev`: "std" value reported by the fit

    4. <prefix>_truncate_fitness.tsv: tab-separated values encoding the
        fitness coefficients of the truncates, each identified by their
        extended lane index combination. The cmax + 1 columns contain
            * `cycle_<cid>_lane`: extended lane index (which can take 0
            as well, as an indication that the synthesis cycle failed)
            of cycle cid = 1,2,... cmax
            * `fitness`: fitness coefficient of the truncated compounds

    5. <prefix>_tag_imbalance.tsv: tab-separated values
        with the same set of cycle lane index columns as the readcount
        input files, and a single column
        containing the tag imbalance for each tag.
            * `cycle_<cid>_lane`: lane index of cycle cid = 1,2,... cmax
            * `tag_imbalance`: tag imbalance for each tag
"""

import os
import sys
import datetime
import argparse
import numpy as np
import pandas as pd
import warnings
from itertools import product
from multiprocessing import cpu_count

from deldenoiser.deldenoiser import NullblockModel

# Yield assigned to every lane of every cycle when no --yields file is given.
DEFAULT_YIELDS = 0.5
# When true, timestamped progress messages are printed to standard error and
# the flag is forwarded as `debug=` to the truncate fitting step of the model.
DEBUG = True


def debugging_message(msg):
    """Print ``msg`` to standard error, prefixed with the current time.

    The timestamp has second resolution and the form ``YYYY-MM-DD HH:MM:SS``;
    it is separated from the message by ``' : '``.
    """
    now = datetime.datetime.now()
    stamp = now.isoformat(sep=' ', timespec='seconds')
    print(stamp + ' : ' + msg, file=sys.stderr)


def preprocess_design_file(design_file):
    """Read and validate the DEL design table.

    Parameters
    ----------
    design_file : path or file-like object
        Tab-separated table with exactly the columns "cycle" and "lanes".

    Returns
    -------
    numpy.ndarray
        Integer array holding the number of lanes of each cycle, in the
        order the cycles are listed in the file.

    Raises
    ------
    ValueError
        If the columns are not ["cycle", "lanes"], if the file lists no
        cycle at all, if "cycle" is not 1, 2, 3, ... in this order, or if
        "lanes" contains anything but positive integers.
    """
    df_design = pd.read_csv(design_file, sep='\t')
    if list(df_design.columns) != ['cycle', 'lanes']:
        raise ValueError('design file must have columns ["cycle", "lanes"]')

    # A header-only file would otherwise pass every check below and yield an
    # empty design, which describes no library at all.
    if len(df_design) == 0:
        raise ValueError('design file must list at least one cycle')

    expected_cycles = np.arange(1, len(df_design) + 1, 1)
    if (df_design['cycle'].values != expected_cycles).any():
        raise ValueError('"cycle" column of design file must contain '
                         'consecutive positive integers starting from 1')

    error_message = '"lanes" column must contain positive integers'
    lanes = df_design['lanes'].values
    try:
        design = lanes.astype(int)
    except (TypeError, ValueError) as err:
        # Only conversion failures are translated; anything else (e.g. a
        # keyboard interrupt) must propagate unchanged.
        raise ValueError(error_message) from err
    if (df_design['lanes'] <= 0).any():
        raise ValueError(error_message)
    # The integer cast truncates, so any difference reveals a fractional
    # (or otherwise non-integer) entry.
    if (design != lanes).any():
        raise ValueError(error_message)

    return design


def preprocess_readcount_file(readcount_file, model):
    """Read and validate a read count table and return the counts.

    Parameters
    ----------
    readcount_file : path or file-like object
        Tab-separated table with the columns "cycle_1_lane", ...,
        "cycle_<cmax>_lane" followed by "readcount".
    model : object
        Only ``model.lane_indexes`` is read here; it is compared row by row
        with the sorted lane index columns of the file.

    Returns
    -------
    numpy.ndarray
        The "readcount" column, ordered by the lane index columns.

    Raises
    ------
    ValueError
        If the column names are wrong, a lane index is not a positive
        integer, a lane combination is duplicated, a read count is not a
        non-negative integer, or the lane combinations differ from
        ``model.lane_indexes``.
    """
    df_counts = pd.read_csv(readcount_file, sep='\t')
    cmax = len(df_counts.columns) - 1
    lane_index_cols = [f'cycle_{c}_lane' for c in range(1, cmax + 1, 1)]
    if list(df_counts.columns) != lane_index_cols + ['readcount']:
        raise ValueError(
            'readcount file must have columns: "cycle_<cid>_lane", '
            'where <cid> = 1,2,... cmax, and "readcount".')

    for col in lane_index_cols:
        error_message = f'column "{col}" must contain positive integers'
        lane = df_counts[col].values
        if (lane != lane.astype(int)).any():
            raise ValueError(error_message)
        if (lane <= 0).any():
            raise ValueError(error_message)

    # Rows may come in any order; sorting lets them be compared with the
    # lane combinations of the model.
    df_counts = df_counts.sort_values(by=lane_index_cols)

    lane_indexes = df_counts[lane_index_cols].values
    if len(set(map(tuple, lane_indexes))) < len(lane_indexes):
        raise ValueError('cycle lane columns in read count file must not '
                         'contain duplicate '
                         'lane index combinations')

    error_message = ('column "readcount" in read count file '
                     'must contain non-negative integers')
    counts = df_counts['readcount'].values
    if (counts != counts.astype(int)).any():
        raise ValueError(error_message)
    if (counts < 0).any():
        raise ValueError(error_message)

    error_message = 'cycle lane columns in read count file must contain a ' \
                    'exactly the lane combinations allowed by the design, ' \
                    'and must be sorted'
    # np.array_equal also compares the shapes.  An elementwise `!=` fails
    # with an unrelated broadcasting error (or a plain bool, depending on
    # the NumPy version) when the file has missing or extra rows.
    if not np.array_equal(lane_indexes, model.lane_indexes):
        raise ValueError(error_message)

    return counts


def preprocess_yields_file(yields_file, model):
    """Read and validate the reaction yields table.

    Parameters
    ----------
    yields_file : path or file-like object
        Tab-separated table with exactly the columns "cycle", "lane" and
        "yield".
    model : object
        Only ``model.design`` (number of lanes per cycle) is read here.

    Returns
    -------
    list of numpy.ndarray
        One array per cycle, holding the yields of its lanes in lane order.

    Raises
    ------
    ValueError
        If the column names are wrong, "cycle" or "lane" contain anything
        but positive integers, a yield is missing or not strictly between
        0.0 and 1.0, or the file does not list every cycle-lane combination
        of the design exactly once.
    """
    df_yields = pd.read_csv(yields_file, sep='\t')
    if list(df_yields.columns) != ['cycle', 'lane', 'yield']:
        # The required column is "yield"; the message must name it exactly.
        raise ValueError('yields file must have columns '
                         '"cycle", "lane", "yield"')

    for col in ('cycle', 'lane'):
        error_message = (f'"{col}" column of yields file '
                         'must contain positive integers')
        values = df_yields[col].values
        if (values != values.astype(int)).any():
            raise ValueError(error_message)
        if (values <= 0).any():
            raise ValueError(error_message)

    error_message = '"yield" column of yields file must contain ' \
                    'floating point values, 0.0 < yield < 1.0'
    if (df_yields['yield'] <= 0.0).any() or (df_yields['yield'] >= 1.0).any():
        raise ValueError(error_message)
    # NaN compares false in both range checks above, so it is tested apart.
    if np.isnan(df_yields['yield']).any():
        raise ValueError(error_message)

    df_yields = df_yields.sort_values(by=['cycle', 'lane'])
    cycles = df_yields['cycle'].values
    lanes = df_yields['lane'].values
    yields = df_yields['yield'].values

    error_message = '"cycle" and "lane" columns in yields file ' \
                    'must list every cycle-lane ' \
                    'combination in the design.'
    cmax = len(model.design)
    expected_cycles = np.concatenate([(c + 1) * np.ones(model.design[c])
                                      for c in range(cmax)])
    expected_lanes = np.concatenate([np.arange(1, model.design[c] + 1, 1)
                                     for c in range(cmax)])
    # np.array_equal also compares the lengths, so missing or extra rows
    # produce the message above instead of a broadcasting failure.
    if not (np.array_equal(cycles, expected_cycles)
            and np.array_equal(lanes, expected_lanes)):
        raise ValueError(error_message)

    # Cut the sorted yields into one chunk per cycle.
    yield_list = []
    idx_start = 0
    for lmax in model.design:
        yield_list.append(yields[idx_start:idx_start + lmax])
        idx_start += lmax
    return yield_list


def preprocess_output_prefix(prefix):
    """Check that output files can be created under ``prefix``.

    When ``prefix`` contains a path separator, its directory part is created
    if it does not exist yet (a notice is printed to standard error), and a
    temporary probe file ``<prefix>_test.file`` is written and removed again.
    A prefix without a path separator is returned without any check.

    Parameters
    ----------
    prefix : str
        Output prefix, optionally including a directory path.

    Returns
    -------
    str
        ``prefix``, unchanged.

    Raises
    ------
    ValueError
        If the directory cannot be created or the probe file cannot be
        written or removed; the underlying error is attached as the cause.
    """
    if os.path.sep in prefix:
        directory, _ = os.path.split(prefix)
        probe = prefix + '_test.file'
        try:
            if not os.path.isdir(directory):
                print(f'directory {directory} does not exist, it is created',
                      file=sys.stderr)
                # exist_ok guards against the directory appearing between
                # the isdir() check and this call.
                os.makedirs(directory, exist_ok=True)
            with open(probe, 'w'):
                pass
            os.remove(probe)
        except (OSError, ValueError) as err:
            # OSError covers file system failures; ValueError covers paths
            # that open() rejects outright (e.g. embedded null bytes).
            raise ValueError(
                'cannot write to path specified by prefix') from err

    return prefix


def preprocess_regularization_strength(alpha):
    """Return ``alpha`` if it is a positive number.

    Raises
    ------
    ValueError
        If ``alpha`` is zero, negative or NaN.
    """
    # `not alpha > 0.0` instead of `alpha <= 0.0`: every comparison with NaN
    # is false, so the latter would let NaN pass as a valid strength.
    if not alpha > 0.0:
        raise ValueError('regularization strength (alpha) must be positive')
    return alpha


def preprocess_dispersion(gamma):
    """Return ``gamma`` if it is a positive number.

    Raises
    ------
    ValueError
        If ``gamma`` is zero, negative or NaN.
    """
    # `not gamma > 0.0` instead of `gamma <= 0.0`: every comparison with NaN
    # is false, so the latter would let NaN pass as a valid dispersion.
    if not gamma > 0.0:
        raise ValueError('dispersion (gamma) must be positive')
    return gamma


def preprocess_maxiter(maxiter):
    """Return ``maxiter`` unchanged, rejecting values smaller than 1.

    Raises
    ------
    ValueError
        If ``maxiter`` is less than 1.
    """
    below_minimum = maxiter < 1
    if below_minimum:
        raise ValueError('maxiter must be positive integer')
    return maxiter


def preprocess_tolerance(tol):
    """Return ``tol`` if it is a positive number.

    Raises
    ------
    ValueError
        If ``tol`` is zero, negative or NaN.
    """
    # `not tol > 0.0` instead of `tol <= 0.0`: every comparison with NaN is
    # false, so the latter would let NaN pass as a valid tolerance.
    if not tol > 0.0:
        raise ValueError('tolerance (tol) must be positive')
    return tol


def preprocess_processes(processes):
    """Return ``processes`` unchanged after checking it against the CPU count.

    A warning is issued (the value is still accepted) when more processes
    are requested than ``cpu_count()`` reports.

    Raises
    ------
    ValueError
        If ``processes`` is less than 1.
    """
    if processes < 1:
        raise ValueError('processes must be positive integer')

    available = cpu_count()
    oversubscribed = processes > available
    if oversubscribed:
        notice = (f'parallel_processes ({processes}) is higher '
                  f'than the number of system CPUs found ({available}). '
                  f'This may not be optimal.')
        warnings.warn(notice)
    return processes


def main(args):
    """Run the complete denoising pipeline for parsed command line arguments.

    ``args`` is the namespace produced by the argument parser of this script.
    The inputs are validated, a ``NullblockModel`` is fitted step by step,
    and five tab-separated tables are written, each named
    ``<output_prefix>_<suffix>.tsv``.
    """

    def log(text):
        # Progress messages are emitted only while DEBUG is switched on.
        if DEBUG:
            debugging_message(text)

    # ---- Step 1: validate the command line inputs -------------------------
    log('Pre-processing input data')
    out_prefix = preprocess_output_prefix(args.output_prefix)
    reg_strength = preprocess_regularization_strength(
        args.regularization_strength)
    dispersion = preprocess_dispersion(args.dispersion)
    lanes_per_cycle = preprocess_design_file(args.design)
    max_iterations = preprocess_maxiter(args.maxiter)
    tolerance = preprocess_tolerance(args.tolerance)
    n_processes = preprocess_processes(args.parallel_processes)

    model = NullblockModel(lanes_per_cycle, reg_strength, dispersion)

    post_counts = preprocess_readcount_file(
        args.postselection_readcount, model)

    # Both of the remaining inputs are optional.
    pre_counts = None
    if args.preselection_readcount is not None:
        pre_counts = preprocess_readcount_file(
            args.preselection_readcount, model)

    yields = None
    if args.yields is not None:
        yields = preprocess_yields_file(args.yields, model)

    # ---- Step 2: load yields into the model -------------------------------
    log('Loading yields')
    if yields is None:
        # No yields file -> the same default yield for every lane.
        yields = [DEFAULT_YIELDS * np.ones(lmax) for lmax in model.design]
    model.load_yields(yields)

    # ---- Step 3: sequencing bias from pre-selection read counts -----------
    log('Fitting sequencing bias')
    if pre_counts is not None:
        sequencing_bias, _ = model.fit_sequencing_bias(pre_counts)
    else:
        # No pre-selection data -> assume uniform sequencing efficiency.
        sequencing_bias = np.ones_like(post_counts)

    # ---- Step 4: fitness of truncated compounds ---------------------------
    log('Fitting fitness of truncates')
    truncates_fitness, truncates_floor = model.fit_truncates(
        sequencing_bias,
        post_counts,
        max_processes=n_processes,
        maxiter=max_iterations,
        tol=tolerance,
        debug=DEBUG)

    # ---- Step 5: fitness of legitimate compounds --------------------------
    log('Fitting fitness of legitimate compounds')
    legitimates_fitness = model.fit_legitimates(
        sequencing_bias, post_counts, truncates_floor)

    # ---- Step 6: breakdown of read counts by success pattern --------------
    log('Computing read count breakdown')
    count_breakdown_matrix = model.compute_readcount_breakdown(
        sequencing_bias,
        truncates_fitness,
        legitimates_fitness['mean'],
        post_counts)

    # ---- Step 7: write the result tables ----------------------------------
    log('Compiling output tables')
    cmax = len(model.design)
    lane_columns = [f'cycle_{cid}_lane' for cid in range(1, cmax + 1)]
    # Success patterns as strings of "0"/"1", in the column order of the
    # inventory and count breakdown matrices.
    patterns = [''.join(map(str, bits))
                for bits in product((0, 1), repeat=cmax)]
    df_lanes = pd.DataFrame(dict(zip(lane_columns, model.lane_indexes.T)))

    # 7.1: inventory table, one column per success pattern.
    inventory_table = df_lanes.copy()
    for column, pattern in enumerate(patterns):
        inventory_table[f'fraction_of_{pattern}'] = \
            model.inventory_matrix[:, column]
    inventory_table.to_csv(f'{out_prefix}_inventory.tsv',
                           sep='\t', index=False, float_format='%.4f')

    # 7.2: read count breakdown, one column per success pattern.
    breakdown_table = df_lanes.copy()
    for column, pattern in enumerate(patterns):
        breakdown_table[f'readcount_{pattern}'] = \
            count_breakdown_matrix[:, column]
    breakdown_table.to_csv(f'{out_prefix}_count_breakdown.tsv',
                           sep='\t', index=False, float_format='%.4f')

    # 7.3: fitness of the legitimate compounds.
    legitimates_table = df_lanes.copy()
    legitimates_table['fitness_mode'] = legitimates_fitness['mode']
    legitimates_table['fitness_mean'] = legitimates_fitness['mean']
    legitimates_table['fitness_minimum_stdev'] = legitimates_fitness['std']
    legitimates_table.to_csv(f'{out_prefix}_fitness.tsv',
                             sep='\t', index=False)

    # 7.4: fitness of the truncates only, i.e. the rows of the building
    # block index table that contain at least one 0.
    is_truncate = (model.bb_indexes == 0).any(axis=1)
    truncate_indexes = model.bb_indexes[is_truncate, :]
    truncates_table = pd.DataFrame(
        dict(zip(lane_columns, truncate_indexes.T)))
    truncates_table['fitness'] = truncates_fitness[is_truncate]
    truncates_table.to_csv(f'{out_prefix}_truncate_fitness.tsv',
                           sep='\t', index=False)

    # 7.5: tag imbalance, the sequencing bias divided by its mean.
    imbalance_table = df_lanes.copy()
    imbalance_table['tag_imbalance'] = \
        sequencing_bias / np.mean(sequencing_bias)
    imbalance_table.to_csv(f'{out_prefix}_tag_imbalance.tsv',
                           sep='\t', index=False)


# Script entry point: build the command line parser, parse the arguments and
# hand them to main().
if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description="Command line tool for DEL readcount denoising.")

    # Required input: design table.  argparse.FileType('r') makes argparse
    # open the file for reading while the arguments are parsed.
    parser.add_argument('--design', help=
    """ TSV file that encodes the number of synthesis
        cycles and the number of lanes in each cycle, with two columns:
            "cycle": cycle index (1,2,... cmax),
            "lanes": number of lanes in the corresponding cycle (must be >= 1)
    """,
                        required=True,
                        type=argparse.FileType('r'))

    # Required input: read counts measured after selection.
    parser.add_argument('--postselection_readcount', help=
    """ TSV file that encode the read counts
            obtained from sequencing done after the DEL selection steps,
            with cmax + 1 columns:
                "cycle_1_lane": lane index of cycle 1,
                "cycle_2_lane": lane index of cycle 2,
                ...
                "cycle_<cmax>_lane": lane index of cycle cmax,
                "readcount": number of reads of the DNA tag that identifies
                    the corresponding lane index combination
                    (non-negative integers)
    """,
                        required=True,
                        type=argparse.FileType('r'))

    # Required: prefix (optionally with a directory) of all output files.
    parser.add_argument('--output_prefix', help=
    """ string (which can include the path) to name the output files.
    """,
                        required=True,
                        type=str)

    # Optional numeric settings; each one has a default value.
    parser.add_argument('--dispersion', help=
    """ dispersion parameter of the dispersed Poisson noise, ratio of variance
        and expected value. (default: 1.0)
    """,
                        required=False,
                        type=float, default=1.0)

    parser.add_argument('--regularization_strength', help=
    """ strength of regularization towards sparse solutions,
        (default: 1.0)
    """,
                        required=False,
                        type=float, default=1.0)

    parser.add_argument('--maxiter', help=
    """ maximum number of coordinate descent iterations
        during fitting truncates. (default: 20)
    """,
                        required=False,
                        type=int, default=20)
    parser.add_argument('--tolerance', help=
    """ tolerance, if the intensity due to truncates changes less than this
        between consecutive iterations of coordinate descent, the the fitting
        is stopped, before reaching maxiter number of iterations (default: 0.1)
    """,
                        required=False,
                        type=float, default=0.1)
    # The default is evaluated once, when the parser is built.
    parser.add_argument('--parallel_processes', help=
    """ max number of parallel processes to run during fitting truncates
        (default: number of system CPUs)
    """,
                        required=False,
                        type=int, default=cpu_count())
    # Optional input files; the parsed attribute is None when one is omitted.
    parser.add_argument('--yields', help=
    f""" TSV file that encodes the yields of the reactions
        during synthesis, with three columns:
            "cycle": cycle index (1,2,... cmax),
            "lane": lane index (1,2, ...
                [number of lanes in corresponding cycle]),
            "yield": yield of reaction in the corresponding lane
                (real number between 0.0 and 1.0)
        (default: all values = {DEFAULT_YIELDS} )
    """,
                        required=False,
                        type=argparse.FileType('r'))

    parser.add_argument('--preselection_readcount', help=
    """ Same structure as POSTSELECTION_READCOUNTS, for read counts obtained
        from sequencing before DEL selection step.
    """,
                        required=False,
                        type=argparse.FileType('r'))

    if DEBUG:
        debugging_message('Reading command line arguments')
    args = parser.parse_args()

    main(args)
