#! /usr/bin/env Rscript

# Usage: run-dea [--help] [--batch] [--lfcThreshold LFCTHRESHOLD] [--alpha ALPHA] [--symbolCol SYMBOLCOL]
# [--orfCol ORFCOL] [--delim {TAB,CSV}] [--lfcShrinkMethod {apeglm,ashr}] config
# config                            Yaml config file (e.g. config used for 'run-htseq-workflow')
# [--batch]                         Use 'batch' from sample table
# [--lfcThreshold LFCTHRESHOLD]     LFC threshold (default: log2(1.2))
# [--alpha ALPHA]                   FDR threshold (default: 0.05)
# [--symbolCol SYMBOLCOL]           Column for features (symbol/names) -  HTSeq workflow only (default: 2)
# [--orfCol ORFCOL]                 Column for ORF types -  HTSeq workflow only
# [--delim {TAB,CSV}]               Field separator for the count table (default: TAB)
# [--lfcShrinkMethod {apeglm,ashr}] LFC shrinkage method (default: apeglm)

suppressPackageStartupMessages(library(argparser))
suppressPackageStartupMessages(library(yaml))
suppressPackageStartupMessages(library(DESeq2))
suppressPackageStartupMessages(library(IHW))
suppressPackageStartupMessages(library(ashr))
suppressPackageStartupMessages(library(apeglm))
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(tibble))
suppressPackageStartupMessages(library(purrr))
suppressPackageStartupMessages(library(openxlsx))

# ---------------------------------------------------------
## Functions

check_keys <- function(required, keys, func) {
  # Test which of the `required` names occur in `keys` and reduce the
  # resulting logical vector with `func` (e.g. `all` or `any`).
  #
  # required: character vector of key names to look for
  # keys:     character vector of available key names
  # func:     function applied to the logical vector of matches
  #
  # vapply() (instead of sapply()) guarantees that `func` receives a logical
  # vector for every input, including an empty `required`.
  present <- vapply(required, function(key) key %in% keys, logical(1))
  func(present)
}

run_analysis_fit_sub <- function(contrast, coldata, dds) {
  # Fit and test a single contrast on the samples that belong to it, then
  # hand the result to write_results().
  #
  # contrast: vector of (numerator condition, reference condition, contrast name)
  # coldata:  sample table with 'condition' and 'sampleName' columns
  # dds:      DESeqDataSet holding all samples
  #
  # Thresholds and the shrinkage type are taken from the script-level
  # settings lfcThreshold.set, altHypothesis.set, alpha.set and shrink.type.
  level.num <- as.character(contrast[1])
  level.ref <- as.character(contrast[2]) # denominator, i.e. the reference level
  contrast.name <- as.character(contrast[3])

  # keep only the samples of the two conditions involved
  keep.levels <- c(level.num, level.ref)
  samples.kept <- dplyr::filter(coldata, condition %in% keep.levels)
  in.contrast <- colnames(dds) %in% samples.kept$sampleName
  dds.sub <- dds[, in.contrast]

  # reference level first, then discard the levels that are no longer present
  dds.sub$condition <- droplevels(relevel(dds.sub$condition, ref=level.ref))

  stopifnot(all(rownames(colData(dds.sub)) == colnames(dds.sub)))

  # the model is fitted on the subset only
  dds.sub <- DESeq(dds.sub)

  # test the last coefficient reported by the fit
  fitted.names <- resultsNames(dds.sub)
  coef <- fitted.names[length(fitted.names)]
  cat("Testing for: ", coef, " @ lfcThreshold: ", lfcThreshold.set, " and alpha: ", alpha.set, "\n")
  res <- results(dds.sub,
                 name=coef,
                 lfcThreshold=lfcThreshold.set,
                 altHypothesis=altHypothesis.set,
                 alpha=alpha.set,
                 filterFun=ihw)
  # missing adjusted p-values are reported as 1
  res$padj[is.na(res$padj)] <- 1
  res <- lfcShrink(dds.sub, coef=coef, type=shrink.type, res=res)

  # write to disk, named after the contrast
  write_results(dds.sub, res, contrast.name)
}

write_results <- function(dds, res, contrast) {
  # Write the result table of one contrast, joined with the raw counts, to
  # <dirloc.out>/<contrast>/<contrast>.xlsx.
  #
  # dds:      DESeqDataSet the results were computed from (source of the counts)
  # res:      results object; converted with data.frame()
  # contrast: contrast name, used for the sheet, sub-directory and file names
  #
  # Reads the script-level variables id.mapping, orfType.mapping, args,
  # from_htseq and dirloc.out.

  # create `path` unless it exists already; report when nothing was created
  ensure_dir <- function(path) {
    created <- if (!dir.exists(path)) dir.create(path, recursive=TRUE) else FALSE
    if (!created) { cat("Info: using existing directory ", path, " ...\n") }
  }

  res <- res %>%
      data.frame() %>%
      rownames_to_column(var="id") %>%
      as_tibble()

  # add raw counts, annotations and re-order columns
  cts <- counts(dds, normalized=FALSE) %>%
      data.frame() %>%
      rownames_to_column(var="id") %>%
      as_tibble()

  # assigned by position: id.mapping is expected to be in the row order of res
  res$symbol <- id.mapping

  # scalar condition: use the short-circuit operator
  if (!is.null(args$orfCol) && from_htseq) {
    res$orfType <- orfType.mapping
    res <- res %>%
        dplyr::select(id, symbol, orfType, baseMean,
                log2FoldChange, pvalue, padj)
  } else {
    res <- res %>%
            dplyr::select(id, symbol, baseMean,
                    log2FoldChange, pvalue, padj)
  }
  res <- res %>% left_join(cts, by="id")

  ## write to disk
  wb <- createWorkbook()
  # sheet name is the contrast name cut to 30 characters
  addWorksheet(wb, sheetName=substr(contrast,1,30))
  writeDataTable(wb, sheet=1, x=res)

  ensure_dir(dirloc.out)
  dirloc.sub <- file.path(dirloc.out, contrast, fsep=.Platform$file.sep)
  ensure_dir(dirloc.sub)

  filen <- file.path(dirloc.sub, paste0(contrast, ".xlsx"), fsep=.Platform$file.sep)
  cat("Info: writing results to: ", filen, "\n")
  saveWorkbook(wb, filen, overwrite = TRUE)
}

read_symbols <- function(samples, dirloc, col, sep="") {
  # Extract one column (gene symbols or another annotation) from the HTSeq
  # output tables; mirrors DESeqDataSetFromHTSeqCount_delim for a given column.
  #
  # samples: sample table with 'fileName' and 'sampleName' columns
  # dirloc:  directory containing the files
  # col:     index of the column to extract
  # sep:     field separator handed to read.table
  #
  # Returns the values of the first sample, named by the first file column,
  # with the special counter rows removed.
  paths <- file.path(dirloc, as.character(samples$fileName))
  tables <- lapply(paths, function(p) read.table(p, sep=sep, fill=TRUE))

  # the column has to agree with the first file; the last five rows of each
  # table are left out of the comparison
  column_matches_first <- function(tab) {
    all(head(tab[, col], -5) == head(tables[[1]][, col], -5))
  }
  if (!all(sapply(tables, column_matches_first))) {
    stop(paste0("Column ", col, " differ between files!"))
  }

  values <- sapply(tables, function(tab) tab[, col])
  colnames(values) <- samples$sampleName
  rownames(values) <- tables[[1]]$V1

  # drop rows whose id starts with an underscore or is an old special name
  legacy.counters <- c("no_feature","ambiguous","too_low_aQual","not_aligned","alignment_not_unique")
  ids <- rownames(values)
  is.counter <- startsWith(ids, "_") | ids %in% legacy.counters
  kept <- values[!is.counter, , drop=FALSE]

  # only the first column is needed - same format as mapIds
  kept[, 1]
}

# Same as DESeqDataSetFromHTSeqCount (https://rdrr.io/bioc/DESeq2/src/R/AllClasses.R),
# except that `sep` is passed on to read.table.
#
# sampleTable: data frame; column 1 holds the sample names, column 2 the file
#              names, all remaining columns are used as colData
# directory:   directory the file names are relative to
# design:      design formula handed to DESeqDataSetFromMatrix (required)
# sep:         field separator handed to read.table
# ignoreRank:  handed to DESeqDataSetFromMatrix
# ...:         further arguments for DESeqDataSetFromMatrix
DESeqDataSetFromHTSeqCount_delim <- function(sampleTable, directory=".", design, sep="", ignoreRank=FALSE, ...)
{
  if (missing(design)) stop("design is missing")
  l <- lapply( as.character( sampleTable[,2] ), function(fn) read.table( file.path( directory, fn ), sep=sep, fill=TRUE ) )
  if( ! all( vapply( l, function(a) all( a$V1 == l[[1]]$V1 ), logical(1) ) ) )
    stop( "Gene IDs (first column) differ between files." )
  # select last column of 'a', works even if htseq was run with '--additional-attr'
  tbl <- sapply( l, function(a) a[,ncol(a)] )
  colnames(tbl) <- sampleTable[,1]
  rownames(tbl) <- l[[1]]$V1
  rownames(sampleTable) <- sampleTable[,1]
  oldSpecialNames <- c("no_feature","ambiguous","too_low_aQual","not_aligned","alignment_not_unique")
  # special rows either start with an underscore
  # or carry one of the old special names (htseq-count backward compatibility)
  specialRows <- (substr(rownames(tbl),1,1) == "_") | rownames(tbl) %in% oldSpecialNames
  tbl <- tbl[ !specialRows, , drop=FALSE ]
  # ignoreRank must be passed by name: as a positional argument it would be
  # matched to the next free formal of DESeqDataSetFromMatrix ('tidy')
  object <- DESeqDataSetFromMatrix(countData=tbl, colData=sampleTable[,-(1:2),drop=FALSE], design=design, ignoreRank=ignoreRank, ...)
  return(object)
}


# ---------------------------------------------------------
## Call

# arguments
parser <- arg_parser("Run differential expression (DE)", hide.opts = TRUE)
parser <- add_argument(parser, "config", help="Yaml config file (same as used for 'run-htseq-workflow')", type="character")
# used as lfcThreshold in results(); the default documented in the usage
# header is log2(1.2), not the natural logarithm
parser <- add_argument(parser, "--lfcThreshold", help="LFC threshold", type="numeric", default=log2(1.2))
parser <- add_argument(parser, "--alpha", help="FDR threshold", type="numeric", default=0.05)
parser <- add_argument(parser, "--symbolCol", help="Column for features (symbol/names) - HTSeq workflow only", type="integer", default=2)
parser <- add_argument(parser, "--orfCol", help="Column for ORF types - HTSeq workflow only", type="integer", default=NULL)
parser <- add_argument(parser, "--delim", help="Field separator for the count tables {TAB,CSV}", type="character", default="TAB")
parser <- add_argument(parser, "--lfcShrinkMethod", help="LFC shrinkage method {apeglm,ashr}", type="character", default="apeglm")
parser <- add_argument(parser, "--batch", help="Use 'batch' from sample table", flag=TRUE)
args <- parse_args(parser)

# force default to NULL (is.null guard: is.na(NULL) has length zero and
# cannot be used as an if-condition)
if (is.null(args$orfCol) || is.na(args$orfCol)) { args$orfCol <- NULL }

# delimiter: TAB or CSV (case-insensitive); any other value falls back to
# sep="" as before
sep <- switch(tolower(args$delim), tab = "\t", csv = ",", "")

# LFC shrinkage: "ashr" when requested, "apeglm" for every other value
shrink.type <- if (tolower(args$lfcShrinkMethod) == "ashr") "ashr" else "apeglm"

# defaults for calling DESeq results
lfcThreshold.set <- args$lfcThreshold
altHypothesis.set <- "greaterAbs"
alpha.set <- args$alpha

# config: 'dea_data' (output directory) and 'contrasts' are mandatory
params <- yaml::read_yaml(args$config)
keys <- c("dea_data", "contrasts")
if (!check_keys(keys, names(params), all)) {
  cat("Error: one or more keys were not found: ", paste(keys, collapse=", "), "\nExecution halted\n")
  # exit with a non-zero status so that callers can detect the failure
  quit(save = "no", status = 1, runLast = FALSE)
}

# output directory
dirloc.out <- params$dea_data

# sample and count tables
# the sample table is either given explicitly ('sample_table') or expected as
# <dea_data>/sample-table[-<project_name>].csv
if (!is.null(params$sample_table)) {
  filen <- params$sample_table
} else {
  project <- ""
  if (!is.null(params$project_name)) {
    project <- paste0("-", params$project_name)
  }
  filen <- file.path(dirloc.out, paste0("sample-table", project, ".csv"))
}
cat("Info: using ", filen, " as sample table...\n")
samples <- read.delim(filen, sep=",")
# a single column after splitting on ',' is taken as a sign of a wrong format
if (ncol(samples)==1) {
  cat("Error: sample table does not appear to be CSV-formatted\nExecution halted\n")
  # exit with a non-zero status so that callers can detect the failure
  quit(save = "no", status = 1, runLast = FALSE)
}

# wrangle
samples <- as.data.frame(samples) # need dataframe for DESeq
samples$condition <- factor(samples$condition)
if ("assay" %in% colnames(samples)) {
  # a single assay (compared case-insensitively) is allowed; the column is
  # removed once this has been checked
  samples$assay <- factor(tolower(samples$assay))
  if (nlevels(samples$assay) > 1) {
    cat("Error: sample table contains more than 1 level for assay\nExecution halted\n")
    # exit with a non-zero status so that callers can detect the failure
    quit(save = "no", status = 1, runLast = FALSE)
  }
  samples$assay <- NULL
}
# counts come either from a single count table ('count_table' in the config)
# or from the per-sample HTSeq files listed in the sample table
from_htseq <- TRUE
if (!is.null(params$count_table)) {
  from_htseq <- FALSE
  counts <- read.delim(params$count_table, sep=sep, row.names=1, check.names=FALSE)
  cat("Info: using ", params$count_table, " as count table...\n")
} else {
  # write tmp files to a uniquely named directory below the output directory
  tmp.loc <- uuid::UUIDgenerate()
  dirloc.tmp <- file.path(dirloc.out, tmp.loc, fsep=.Platform$file.sep)
  # recursive: the output directory itself may not exist yet
  dir.create(dirloc.tmp, recursive=TRUE)
  # link all files for DESeq
  for (from in samples$fileName) {
    to <- file.path(dirloc.tmp, basename(from), fsep=.Platform$file.sep)
    # link to the absolute path: a relative symlink target is resolved against
    # the directory of the link, not against the working directory
    file.symlink(normalizePath(from, mustWork=FALSE), to)
  }
  # now clean names to remove path
  samples$fileName <- basename(samples$fileName)
}
# check if we use batch
# '--batch' is only honoured when the sample table has a 'batch' column with
# at least 2 levels; otherwise the flag is switched off with a warning.
# (An error handler cannot do this: assignments inside a handler function are
# local to it, so the column is checked explicitly.)
if (args$batch) {
  msg <- "Info: adding batch effect...\n"
  has.batch <- "batch" %in% colnames(samples)
  if (has.batch) {
    samples$batch <- factor(samples$batch)
  }
  if (!has.batch || nlevels(samples$batch) < 2) {
    msg <- "Warning: '--batch' set to FALSE: use 'batch' as column header, and at least 2 levels\n"
    args$batch <- FALSE
  }
  cat(msg)
}
# create DESeq data object
# design: condition only, or batch + condition when '--batch' is in effect
model <- if (args$batch) ~batch+condition else ~condition

if (!from_htseq) {
  # single count table: feature ids double as symbols
  ddsAll <- DESeqDataSetFromMatrix(countData=counts,
                                   colData=samples,
                                   design=model)
  id.mapping <- setNames(rownames(counts), rownames(counts))
} else {
  # per-sample HTSeq files, read from the temporary link directory
  ddsAll <- DESeqDataSetFromHTSeqCount_delim(samples,
                                             directory=dirloc.tmp,
                                             design=model,
                                             sep=sep)
  # annotation, taken from the same files
  # the order should be the same as in the count data...
  symbol.col <- as.integer(args$symbolCol)
  id.mapping <- read_symbols(samples, dirloc.tmp, symbol.col, sep=sep)
  if (!is.null(args$orfCol)) {
    # ORF types are only read when '--orfCol' was given
    orf.col <- as.integer(args$orfCol)
    orfType.mapping <- read_symbols(samples, dirloc.tmp, orf.col, sep=sep)
  }
  # the links are no longer needed
  unlink(dirloc.tmp, recursive=TRUE)
}
# samples and count columns must be in the same order
stopifnot(all(samples$sampleName == colnames(ddsAll)))

## fit contrasts
cat("Fitting contrasts...\n")
# one row per contrast, with the contrast name appended as last column
# ('contrasts' is looked up by its exact name instead of relying on the
# partial matching of `$`, which returns NULL once another key shares the prefix)
contrasts <- data.frame(t(sapply(params[["contrasts"]], c)))
contrasts$name <- rownames(contrasts)
# called for the side effects only: invisible() keeps the collected return
# values from being printed at top level
invisible(apply(contrasts, 1, run_analysis_fit_sub, coldata=samples, dds=ddsAll))
cat("... done!\n")
