import glob
import os


####LOADING VARIABLES####

# Experiment configuration consumed by every rule below
configfile: "experiment.yaml"

# The chosen library selects its entry under lib_info, which supplies the
# reference FASTA, the alignment index and the expected sgRNA length
library = config["library"]
_lib_entry = config["lib_info"][library]
fasta, index, sg_length = _lib_entry["fasta"], _lib_entry["index"], _lib_entry["sg_length"]

#determine sample names
# If the config asks for a rename that has not been performed yet (no
# rename.done marker), samples come from the new file names; otherwise they are
# discovered from the *.fq.gz files currently in reads/.
def _strip_fq_suffix(name):
    """Return name without a trailing ".fq.gz" (only the suffix is removed)."""
    suffix = ".fq.gz"
    return name[: -len(suffix)] if name.endswith(suffix) else name


if "rename" in config and not os.path.exists("rename.done"):
    OLD_FILES = list(config["rename"].keys())
    NEW_FILES = list(config["rename"].values())

    #create sample names
    SAMPLES = [_strip_fq_suffix(x) for x in NEW_FILES]
else:
    #create sample names; sorted so the sample order (and therefore the column
    #order of downstream tables) does not depend on filesystem listing order
    SAMPLES = sorted(
        _strip_fq_suffix(os.path.basename(x)) for x in glob.glob("reads/*.fq.gz")
    )
    #remove any pre-existing trimmed fq files from this list
    SAMPLES = [x for x in SAMPLES if "_trimmed" not in x]

#load mageck settings
# Each comparison value lists samples separated by ","; those commas become "-"
# so the comparison string can be used as a file-name wildcard.
COMPARISONS = [groups.replace(",", "-") for groups in config["stats"]["comparisons"].values()]

#set targets for target rule (Snakemake flattens the nested expand() lists)
TARGETS = [
    expand("qc/fastqc/{sample}.html", sample=SAMPLES),
    expand("qc/fastqc/{sample}_fastqc.zip", sample=SAMPLES),
    "qc/multiqc.html",
    "count/alignment-rates.pdf",
    "count/sequence-coverage.pdf",
]
# one output of each kind per MAGeCK comparison
TARGETS += [
    expand(pattern, comparison=COMPARISONS)
    for pattern in (
        "mageck/{comparison}.gene_summary.txt",
        "mageck/{comparison}.sgrna_summary.txt",
        "mageck_flute/{comparison}.volcano.pdf",
        "mageck_flute/{comparison}.dot_pos.pdf",
        "mageck_flute/{comparison}.dot_neg.pdf",
        "mageck_flute/{comparison}.sgrank.pdf",
    )
]



####SNAKEMAKE RULES####

#report location (template used by `snakemake --report`)
report: "report/report.rst"

#rules to be run on login node instead of compute node when running on HPC (only very small jobs)
# NOTE(review): "rename" is only defined while a rename is pending; once
# rename.done exists it is absent and Snakemake may warn about it — confirm.
localrules: all, rename, join, plot_alignment_rate, plot_coverage

#extend TARGETS with renamed files and the marker file if a rename is pending
if "rename" in config and not os.path.exists("rename.done"):
    TARGETS.extend(expand("reads/{new_file}", new_file=NEW_FILES))
    TARGETS.append("rename.done")


#target rule: requesting everything in TARGETS drives the whole workflow
rule all:
    """Aggregate target rule; its only inputs are the final outputs in TARGETS."""
    input: TARGETS


if "rename" in config and not os.path.exists("rename.done"):
    rule rename:
        """Rename raw read files in reads/ according to config["rename"] (old -> new).

        Only defined while no rename.done marker exists; the marker is touched
        on success so the rename happens once.
        """
        input:
            expand("reads/{old_file}", old_file=OLD_FILES)
        output:
            f=expand("reads/{new_file}", new_file=NEW_FILES),
            t=touch("rename.done"),
        run:
            # input and output.f are built from the same dict's keys/values,
            # so they pair up positionally
            for old_path, new_path in zip(input, output.f):
                os.rename(old_path, new_path)


rule fastqc:
    """Run FastQC on one sample's trimmed reads using the v1.31.1 wrapper."""
    input:
        "reads/{sample}_trimmed.fq.gz",
    output:
        html="qc/fastqc/{sample}.html",
        # MultiQC only picks up FastQC results carrying the "_fastqc.zip" suffix
        zip="qc/fastqc/{sample}_fastqc.zip",
    log:
        "logs/fastqc/{sample}.log",
    params:
        extra="--quiet",
    threads: config["resources"]["fastqc"]["cpu"]
    resources:
        runtime=config["resources"]["fastqc"]["time"],
    wrapper:
        "v1.31.1/bio/fastqc"


rule multiqc:
    """Aggregate the per-sample FastQC archives into one MultiQC HTML report."""
    input:
        expand("qc/fastqc/{sample}_fastqc.zip", sample=SAMPLES)
    output:
        report("qc/multiqc.html", caption="report/multiqc.rst", category="MultiQC analysis of fastq files")
    params:
        extra="",  # optional extra command-line options for multiqc
        use_input_files_only=True,  # scan only the listed input files, not their whole directories
    resources:
        # use a dedicated resources.multiqc entry when the config provides one;
        # otherwise fall back to the FastQC runtime (previous behaviour)
        runtime=config["resources"].get("multiqc", config["resources"]["fastqc"])["time"]
    log:
        "logs/multiqc/multiqc.log"
    wrapper:
        "v1.31.1/bio/multiqc"


rule trim:
    """Cut reads down to the sgRNA sequence with cutadapt.

    cutadapt's -u removes config["left_trim"] leading bases and -l shortens the
    remaining read to the library's sg_length. The trimmed file is temporary.
    """
    input:
        "reads/{sample}.fq.gz"
    output:
        temp("reads/{sample}_trimmed.fq.gz")
    params:
        sgl=sg_length,
        lt=config["left_trim"],
    threads: config["resources"]["trim"]["cpu"]
    conda:
        "envs/trim.yaml"
    log:
        "logs/trim/{sample}.log",
    resources:
        runtime=config["resources"]["trim"]["time"]
    # with -o set, cutadapt prints its trimming report to stdout; keep it (and
    # any error output) in the declared log file instead of discarding it
    shell:
        "cutadapt -j {threads} --quality-base 33 -u {params.lt} -l {params.sgl} -o {output} {input} > {log} 2>&1"


rule count:
    """Align trimmed reads with HISAT2 and count reads per guide.

    Output lines have the form "<count> <reference name>". HISAT2's stderr
    (timing and alignment summary) goes to the log, which plot_alignment_rate
    reads. Alignments whose SAM line contains "XS:" are dropped
    (NOTE(review): presumably to discard multi-mapping reads — confirm).
    """
    input:
        "reads/{sample}_trimmed.fq.gz"
    output:
        "count/{sample}.guidecounts.txt"
    params:
        idx=index,
        mm=config["mismatch"],
    threads: config["resources"]["count"]["cpu"]
    resources:
        runtime=config["resources"]["count"]["time"]
    log:
        "logs/count/{sample}.log"
    conda:
        "envs/count.yaml"
    # Column 3 of SAM is the reference name; unaligned reads carry "*" there.
    # They are removed explicitly instead of deleting the first counted line,
    # which relied on "*" sorting first (locale-dependent) and on at least one
    # read being unaligned. LC_ALL=C makes the sort order reproducible.
    shell:
        "zcat {input} | hisat2 --no-hd -p {threads} -t -N {params.mm} -x {params.idx} - 2> {log} | "
        "sed '/XS:/d' | cut -f3 | sed '/^\\*$/d' | LC_ALL=C sort | uniq -c | sed 's/^ *//' > {output}"


rule join:
    """Combine the per-sample guide count files into count/counts-aggregated.tsv.

    The library FASTA is passed to scripts/join.py as params.fa
    (NOTE(review): presumably to list every guide, including zero-count ones — confirm).
    """
    input:
        files=expand("count/{sample}.guidecounts.txt", sample=SAMPLES),
    params:
        fa=fasta,
    output:
        "count/counts-aggregated.tsv",
    script:
        "scripts/join.py"


def _mageck_groups(comparison):
    """Split "<treatment>_vs_<control>" into MAGeCK sample lists.

    Within each group, sample names are joined by "-" in the wildcard (see how
    COMPARISONS is built); every "-" is turned back into the "," MAGeCK expects.
    Raises ValueError if the comparison has no "_vs_" separator.
    """
    treatment, sep, control = comparison.partition("_vs_")
    if not sep or not treatment or not control:
        raise ValueError(
            f"comparison '{comparison}' must have the form <treatment>_vs_<control>"
        )
    return treatment.replace("-", ","), control.replace("-", ",")


rule mageck:
    """Run `mageck test` on the aggregated count table for one comparison."""
    input:
        "count/counts-aggregated.tsv"
    output:
        "mageck/{comparison}_summary.Rnw",
        report("mageck/{comparison}.gene_summary.txt", caption="report/mageck.rst", category="MAGeCK"),
        "mageck/{comparison}.log",
        "mageck/{comparison}.R",
        #"mageck/{comparison}.report.Rmd",
        "mageck/{comparison}.sgrna_summary.txt",
        "mageck/{comparison}.normalized.txt"
    params:
        # parsed in Python rather than with sed: the former shell parsing only
        # converted the first "-" and used a bracket expression ([^_vs_]) that
        # failed for treatment names containing "_", "v" or "s"
        treatment=lambda wildcards: _mageck_groups(wildcards.comparison)[0],
        control=lambda wildcards: _mageck_groups(wildcards.comparison)[1],
    resources:
        runtime=config["resources"]["mageck"]["time"]
    conda:
        "envs/mageck.yaml"
    log:
        "logs/mageck/{comparison}.log"
    shell:
        "mageck test --normcounts-to-file -k {input} -t {params.treatment} -c {params.control} "
        "-n mageck/{wildcards.comparison} 2> {log}"


rule mageck_flute:
    """Plot one comparison's MAGeCK results with scripts/flute.R (flute conda env).

    Produces a volcano plot, dot plots of enriched and depleted genes and an
    sgRNA rank plot, each registered in the report under "MAGeCK Flute".
    params.spc is the library's configured species
    (NOTE(review): presumably used for gene annotation in flute.R — confirm).
    """
    input:
        "mageck/{comparison}.gene_summary.txt",
        "mageck/{comparison}.sgrna_summary.txt",
    output:
        report(
            "mageck_flute/{comparison}.volcano.pdf",
            caption="report/volcano.rst",
            category="MAGeCK Flute",
            subcategory="{comparison}",
            labels={"Comparison": "{comparison}", "Figure": "volcano plot"},
        ),
        report(
            "mageck_flute/{comparison}.dot_pos.pdf",
            caption="report/dot-plot_pos.rst",
            category="MAGeCK Flute",
            subcategory="{comparison}",
            labels={"Comparison": "{comparison}", "Figure": "dot plot enriched genes"},
        ),
        report(
            "mageck_flute/{comparison}.dot_neg.pdf",
            caption="report/dot-plot_neg.rst",
            category="MAGeCK Flute",
            subcategory="{comparison}",
            labels={"Comparison": "{comparison}", "Figure": "dot plot depleted genes"},
        ),
        report(
            "mageck_flute/{comparison}.sgrank.pdf",
            caption="report/sgrank.rst",
            category="MAGeCK Flute",
            subcategory="{comparison}",
            labels={"Comparison": "{comparison}", "Figure": "sgRNA rank"},
        ),
    params:
        spc=config["lib_info"][library]["species"],
    resources:
        runtime=config["resources"]["mageck"]["time"],
    conda:
        "envs/flute.yaml"
    script:
        "scripts/flute.R"


rule plot_alignment_rate:
    """Plot per-sample alignment rates taken from the HISAT2 logs of rule count.

    params.name is passed to scripts/plot.py, which also serves plot_coverage
    (NOTE(review): presumably it selects which plot to draw — confirm).
    """
    input:
        expand("logs/count/{sample}.log", sample=SAMPLES),
    params:
        name="plot_alignment_rate",
    output:
        report("count/alignment-rates.pdf", caption="report/alignment-rates.rst", category="Alignment rates"),
    script:
        "scripts/plot.py"


rule plot_coverage:
    """Plot sequence coverage from the aggregated count table with scripts/plot.py.

    The library FASTA is passed as params.fa; params.name is shared with
    plot_alignment_rate (NOTE(review): presumably it selects the plot type — confirm).
    """
    input:
        "count/counts-aggregated.tsv",
    output:
        report("count/sequence-coverage.pdf", caption="report/plot-coverage.rst", category="Sequence coverage"),
    params:
        name="plot_coverage",
        fa=fasta,
    script:
        "scripts/plot.py"







