#! /usr/bin/env Rscript

# Usage: run-dea <-config CONFIG> [-lfcThreshold L2F] [-alpha ALPHA] [-symbolCol COLUMN] [-orfCol COLUMN_NUMBER] [-delim TAB/CSV] [-batch]
# 1. <-config CONFIG>        Yaml config file (same as used for run-htseq-workflow)
# 2. [-lfcThreshold L2F]     LFC threshold - default log2(1.2)
# 3. [-alpha ALPHA]          FDR threshold - default 0.05
# 4. [-symbolCol COLUMN]     Column for feature symbol/names to add to results - default 2
# 5. [-orfCol COLUMN_NUMBER] Column for ORF type to add to results
# 6. [-delim TAB/CSV]        Field separator character for read.table - default ""
# 7. [-batch]                If present, uses "batch" from sample table

library(yaml)
library(DESeq2)
library(IHW)
library(ashr)
library(apeglm)

library(dplyr)
library(tibble)
library(purrr)

library(openxlsx)

# ---------------------------------------------------------
## Functions

run_analysis_fit_sub <- function(contrast, coldata, dds) {

    # Fit and test a single contrast on the samples belonging to it, then
    # write the result table to disk via write_results().
    #
    # contrast: vector whose 1st element is the numerator condition, 2nd the
    #           denominator (reference) condition, and 3rd the contrast name
    #           (used for the output directory, file and sheet names)
    # coldata:  sample table with (at least) columns sampleName and condition
    # dds:      DESeqDataSet holding all samples
    #
    # Relies on the globals lfcThreshold.set, altHypothesis.set and alpha.set.

    num.condition <- as.character(contrast[1])
    denom.condition <- as.character(contrast[2]) # reference

    # fail early, with a readable message, on conditions unknown to the sample table
    unknown <- setdiff(c(num.condition, denom.condition), as.character(coldata$condition))
    if (length(unknown) > 0) {
        stop("Condition(s) not found in sample table: ",
             paste(unknown, collapse=", "), call.=FALSE)
    }

    # subset the DESeqDataSet for the selected contrast
    coldata.sub <- coldata %>% dplyr::filter(condition %in% c(num.condition, denom.condition))
    dds.sub <- dds[,colnames(dds) %in% coldata.sub$sampleName]
    dds.sub$condition <- relevel(dds.sub$condition, ref=denom.condition)

    # Subsetting keeps the factor levels of the removed samples. DESeq2 refuses
    # design factors with empty levels, so drop them for every factor column
    # (condition, but also e.g. batch), not just for condition.
    for (v in colnames(colData(dds.sub))) {
        if (is.factor(colData(dds.sub)[[v]])) {
            colData(dds.sub)[[v]] <- droplevels(colData(dds.sub)[[v]])
        }
    }

    stopifnot(all(rownames(colData(dds.sub)) == colnames(dds.sub)))

    # fit the model for this contrast only
    dds.sub <- DESeq(dds.sub)

    # condition is the last term of the design, hence the last coefficient
    coef <- resultsNames(dds.sub)[length(resultsNames(dds.sub))]
    print(paste0("Testing for: ", coef, " @ lfcThreshold: ", lfcThreshold.set, " and alpha: ", alpha.set))
    res <- results(dds.sub,
                   contrast=c("condition", num.condition, denom.condition), # name=coef,
                   lfcThreshold=lfcThreshold.set,
                   altHypothesis=altHypothesis.set,
                   alpha=alpha.set,
                   filterFun=ihw)
    # features without an adjusted p-value are reported as not significant
    res$padj[is.na(res$padj)] <- 1
    res <- lfcShrink(dds.sub, coef=coef, type="apeglm", res=res)

    # write to disk
    write_results(dds.sub, res, as.character(contrast[3]))
}


write_results <- function(dds, res, contrast) {

    # Format the results of one contrast and save them as
    # <dirloc.out>/<contrast>/<contrast>.xlsx.
    #
    # dds:      DESeqDataSet the results were computed from (source of raw counts)
    # res:      results table for the contrast, with feature ids as row names
    # contrast: contrast name, used for the directory, file and worksheet names
    #
    # Relies on the globals id.mapping, orfType.mapping (only with -orfCol),
    # args and dirloc.out. Annotations are attached by position, i.e. they are
    # expected to be in the same order as the rows of res.

    res <- res %>%
        data.frame() %>%
        rownames_to_column(var="id") %>%
        as_tibble()

    # add raw counts, annotations and re-order columns
    cts <- counts(dds, normalized=FALSE) %>%
        data.frame() %>%
        rownames_to_column(var="id") %>%
        as_tibble()

    res$symbol <- id.mapping

    if (!is.null(args$orfCol)) {
        res$orfType <- orfType.mapping
        res <- res %>%
            dplyr::select(id, symbol, orfType, baseMean,
                    log2FoldChange, pvalue, padj)
    } else {
        res <- res %>%
                dplyr::select(id, symbol, baseMean,
                        log2FoldChange, pvalue, padj)
    }
    res <- res %>% left_join(cts, by="id")

    ## write to disk
    wb <- createWorkbook()

    # Excel worksheet names are limited to 31 characters and addWorksheet
    # fails on longer ones: truncate the sheet name only, the directory and
    # file keep the full contrast name.
    sheet.name <- substr(contrast, 1, 31)
    addWorksheet(wb, sheetName=sheet.name)
    writeDataTable(wb, sheet=1, x=res)

    dirloc.sub <- file.path(dirloc.out, contrast, fsep=.Platform$file.sep)
    for (dirloc in c(dirloc.out, dirloc.sub)) {
        if (dir.exists(dirloc)) {
            print(paste0("Using existing directory ", dirloc, " ..."))
        } else if (!dir.create(dirloc, recursive=TRUE)) {
            stop("Could not create directory ", dirloc, call.=FALSE)
        }
    }

    filen <- file.path(dirloc.sub, paste0(contrast, ".xlsx"), fsep=.Platform$file.sep)
    print(paste0("Writing results to: ", filen))
    saveWorkbook(wb, filen, overwrite = TRUE)
}


read_symbols <- function(samples, dirloc, col, sep="") {
    # Pull one annotation column (gene symbols or any other column) out of the
    # HTSeq output tables. Files are parsed the same way as in
    # DESeqDataSetFromHTSeqCount_delim.
    #
    # samples: sample table with columns fileName and sampleName
    # dirloc:  directory containing the files listed in samples$fileName
    # col:     index of the column to extract
    # sep:     field separator passed to read.table
    #
    # Returns the values of the requested column, named by the first column of
    # the first file, with the HTSeq special rows removed.
    paths <- file.path( dirloc, as.character( samples$fileName ) )
    tables <- lapply( paths, function(p) read.table( p, sep=sep, fill=TRUE ) )

    # the column must agree across all files (last 5 rows are left out of the check)
    agrees <- vapply( tables,
                      function(tb) all( head(tb[,col], -5) == head(tables[[1]][,col], -5) ),
                      logical(1) )
    if ( !all( agrees ) ) {
        stop( paste0("Column ", col, " differ between files!" ))
    }

    # one column per sample, one row per feature
    values <- sapply( tables, function(tb) tb[,col] )
    rownames(values) <- tables[[1]]$V1
    colnames(values) <- samples$sampleName

    # discard the special counters: underscore-prefixed rows and the old names
    legacyNames <- c("no_feature","ambiguous","too_low_aQual","not_aligned","alignment_not_unique")
    ids <- rownames(values)
    keep <- (substr(ids,1,1) != "_") & !(ids %in% legacyNames)
    values <- values[ keep, , drop=FALSE ]

    # all columns are identical, so the first one is enough - format as mapIds
    values[,1]
}

# same as https://rdrr.io/bioc/DESeq2/src/R/AllClasses.R except for sep passed to read.table
#
# Build a DESeqDataSet from one HTSeq count file per sample.
#
# sampleTable: data frame whose 1st column holds the sample names and whose
#              2nd column holds the count file names; the remaining columns
#              become the colData of the returned object
# directory:   directory containing the count files
# design:      design passed on to DESeqDataSetFromMatrix (required)
# sep:         field separator passed to read.table
# ignoreRank:  passed on to DESeqDataSetFromMatrix
# ...:         further arguments for DESeqDataSetFromMatrix
DESeqDataSetFromHTSeqCount_delim <- function(sampleTable, directory=".", design, sep="", ignoreRank=FALSE, ...)
{
  if (missing(design)) stop("design is missing")
  l <- lapply( as.character( sampleTable[,2] ), function(fn) read.table( file.path( directory, fn ), sep=sep, fill=TRUE ) )
  if( ! all( sapply( l, function(a) all( a$V1 == l[[1]]$V1 ) ) ) )
    stop( "Gene IDs (first column) differ between files." )
  # select last column of 'a', works even if htseq was run with '--additional-attr'
  tbl <- sapply( l, function(a) a[,ncol(a)] )
  colnames(tbl) <- sampleTable[,1]
  rownames(tbl) <- l[[1]]$V1
  rownames(sampleTable) <- sampleTable[,1]
  oldSpecialNames <- c("no_feature","ambiguous","too_low_aQual","not_aligned","alignment_not_unique")
  # special counter rows: the id either starts with an underscore
  # or is one of the old special names (htseq-count backward compatibility)
  specialRows <- (substr(rownames(tbl),1,1) == "_") | rownames(tbl) %in% oldSpecialNames
  tbl <- tbl[ !specialRows, , drop=FALSE ]
  # ignoreRank must be passed by name: positionally it would be matched to
  # the 'tidy' argument of DESeqDataSetFromMatrix
  object <- DESeqDataSetFromMatrix(countData=tbl, colData=sampleTable[,-(1:2),drop=FALSE], design=design, ignoreRank=ignoreRank, ...)
  return(object)
}


# ---------------------------------------------------------
## Call

# arguments
defaults <- list(lfcThreshold=log2(1.2), alpha=0.05, symbolCol=2, orfCol=NULL,
                 delim="", batch=FALSE)
args <- R.utils::commandArgs(trailingOnly=TRUE, asValues=TRUE, defaults=defaults)
if (is.null(args$config)) {
    stop("<-config CONFIG> [-lfcThreshold L2F] [-alpha ALPHA] [-symbolCol COLUMN] [-orfCol COLUMN_NUMBER]
        [-delim TAB/CSV] [-batch]\n",
         call.=FALSE)
}
# delimiter - anything other than TAB/CSV falls back to the read.table default
if (tolower(args$delim) == "tab") {
    sep <- "\t"
} else if (tolower(args$delim) == "csv") {
    sep <- ","
} else {
    sep <- ""
}
# defaults for calling DESEq results (read as globals by run_analysis_fit_sub)
lfcThreshold.set <- args$lfcThreshold
altHypothesis.set <- "greaterAbs"
alpha.set <- args$alpha
# config
params <- yaml::read_yaml(args$config)
# output directory (read as global by write_results)
dirloc.out <- params$tea_data
if (is.null(dirloc.out)) {
    stop("Output directory (tea_data) is missing from the config!\n", call.=FALSE)
}
# sample table: either given explicitly, or sample-table[-<project_name>].csv
# in the output directory
if (!is.null(params$sample_table)) {
    filen <- params$sample_table
} else {
    project <- ""
    if (!is.null(params$project_name)) {
        project <- paste0("-", params$project_name)
    }
    filen <- file.path(dirloc.out, paste0("sample-table", project, ".csv"))
}
print(paste0("Using ", filen, " as sample table..."))
samples <- read.delim(filen, sep=",")
# wrangle
samples <- as.data.frame(samples) # need dataframe for DESEq
samples$condition <- factor(samples$condition)
if ("assay" %in% colnames(samples)) {
    # a single assay is supported; the column itself is not used any further
    samples$assay <- factor(tolower(samples$assay))
    if (nlevels(samples$assay) > 1) {
        stop("Sample table contains more than 1 level for assay!\n", call.=FALSE)
    }
    samples$assay <- NULL
}
# count table (optional) - otherwise counts are read from the HTSeq files
from_htseq <- is.null(params$count_table)
if (!from_htseq) {
    count.table <- read.delim(params$count_table, sep=",", row.names=1)
    print(paste0("Using ", params$count_table, " as count table..."))
    if (!is.null(args$orfCol)) {
        # ORF types are only read from the HTSeq files; without this,
        # write_results would look for an undefined orfType.mapping
        print("Ignoring <-orfCol>! ORF types can only be read from HTSeq count files!")
        args$orfCol <- NULL
    }
}
# check if we use batch
if (args$batch) {
    has.batch <- "batch" %in% colnames(samples)
    if (has.batch) {
        samples$batch <- factor(samples$batch)
    }
    if (!has.batch || nlevels(samples$batch) < 2) {
        args$batch <- FALSE
        print("Ignoring <-batch>! Use 'batch' as column header, and at least 2 levels!")
    }
}
# create DESEq data object
model <- ~condition
if (args$batch) {
    model <- ~batch+condition
}
if (from_htseq) {
    # link all count files into one temporary directory for DESeq
    dirloc.tmp <- file.path(dirloc.out, uuid::UUIDgenerate(), fsep=.Platform$file.sep)
    if (!dir.create(dirloc.tmp, recursive=TRUE)) {
        stop("Could not create temporary directory ", dirloc.tmp, call.=FALSE)
    }
    tryCatch({
        count.files <- as.character(samples$fileName)
        links <- file.path(dirloc.tmp, basename(count.files), fsep=.Platform$file.sep)
        # a relative link target is resolved against the directory of the
        # link, not the working directory - always link to absolute paths
        linked <- file.symlink(normalizePath(count.files, mustWork=TRUE), links)
        if (!all(linked)) {
            stop("Could not link count files into ", dirloc.tmp, call.=FALSE)
        }
        # now clean names to remove path
        samples$fileName <- basename(count.files)

        ddsAll <- DESeqDataSetFromHTSeqCount_delim(samples,
                                                   directory=dirloc.tmp,
                                                   design=model,
                                                   sep=sep)
        # annotation
        # the order should be the same...
        id.mapping <- read_symbols(samples, dirloc.tmp, as.integer(args$symbolCol), sep=sep)
        if (!is.null(args$orfCol)) {
            # add ORF types
            orfType.mapping <- read_symbols(samples, dirloc.tmp, as.integer(args$orfCol), sep=sep)
        }
    }, finally={
        # remove the links, also when reading the files failed
        unlink(dirloc.tmp, recursive=TRUE)
    })
} else {
    ddsAll <- DESeqDataSetFromMatrix(countData=count.table,
                                     colData=samples,
                                     design=model)
    # no annotation available: use the feature ids themselves as symbols
    id.mapping <- rownames(count.table)
    names(id.mapping) <- id.mapping
}
stopifnot(all(samples$sampleName == colnames(ddsAll)))

## fit contrasts
if (length(params$contrast) == 0) {
    stop("No contrasts found in the config!\n", call.=FALSE)
}
# one row per contrast: numerator, denominator, contrast name
contrasts <- data.frame(t(sapply(params$contrast, c)))
contrasts$name <- rownames(contrasts)
# called for its side effects only - do not print the return values
invisible(apply(contrasts, 1, run_analysis_fit_sub, coldata=samples, dds=ddsAll))
print("Done!")
