import glob
import os


####LOADING VARIABLES####

#load config file
configfile: "experiment.yaml"

#load variables
library = config["library"]
fasta = config["lib_info"][library]["fasta"]
index = config["lib_info"][library]["index"]
sg_length = config["lib_info"][library]["sg_length"]
stats = config["stats"]["type"]

#check if read files need to be renamed
#(rule rename touches rename.done, so the renaming branch is only taken until it has run once)
if "rename" in config and not os.path.exists("rename.done"):

    #config["rename"] maps old read file names to new ones
    OLD_FILES = list(config["rename"].keys())
    NEW_FILES = list(config["rename"].values())

    #create sample names from the new file names
    SAMPLES = [x.replace(".fq.gz","") for x in NEW_FILES]
else:

    #create sample names from the .fq.gz files present in reads/;
    #the pattern requires the dot so every match really ends in ".fq.gz" and gets it stripped,
    #and the names are sorted so the sample order (and hence the column order of the
    #aggregated count table) does not depend on the file system's listing order
    SAMPLES = sorted(os.path.basename(x).replace(".fq.gz","") for x in glob.glob("reads/*.fq.gz"))
    #remove any pre-existing trimmed fq files from this list
    SAMPLES = [x for x in SAMPLES if "_trimmed" not in x]

#load stats comparisons
#commas are not compatible with snakemake report, so pooled samples are joined with "-" instead
COMPARISONS = [value.replace(",","-") for value in config["stats"]["comparisons"].values()]

if stats == "bagel2":

    #remove comparisons with pooled control samples (not supported by BAGEL2)
    COMPARISONS = [x for x in COMPARISONS if "-" not in x.split("_vs_")[1]]


#files requested by the target rule: QC and count plots are always produced,
#the remaining ones depend on the statistics backend selected in the config
TARGETS = [
    expand("qc/fastqc/{sample}.html", sample=SAMPLES),
    expand("qc/fastqc/{sample}_fastqc.zip", sample=SAMPLES),
    "qc/multiqc.html",
    "count/alignment-rates.pdf",
    "count/sequence-coverage.pdf",
]

if stats == "mageck":

    #one expanded file list per MAGeCK / MAGeCK Flute output pattern
    TARGETS.extend(
        expand(pattern, comparison=COMPARISONS)
        for pattern in (
            "mageck/{comparison}.gene_summary.txt",
            "mageck/{comparison}.sgrna_summary.txt",
            "mageck_flute/{comparison}.volcano.pdf",
            "mageck_flute/{comparison}.dot_pos.pdf",
            "mageck_flute/{comparison}.dot_neg.pdf",
            "mageck_flute/{comparison}.sgrank.pdf",
        )
    )

elif stats == "bagel2":

    #the BAGEL2-formatted count table, then one expanded file list per BAGEL2 output pattern
    TARGETS.append("count/counts-aggregated-bagel2.tsv")
    TARGETS.extend(
        expand(pattern, comparison=COMPARISONS)
        for pattern in (
            "bagel2/{comparison}.foldchange",
            "bagel2/{comparison}.bf",
            "bagel2/{comparison}.pr",
            "bagel2_plots/{comparison}.bf.pdf",
            "bagel2_plots/{comparison}.pr.pdf",
        )
    )



####SNAKEMAKE RULES####

#report location: top-level description page (report/report.rst) of the snakemake HTML report
report: "report/report.rst"

#rules to be run on login node instead of compute node when running on HPC (only very small jobs);
#note that rename, plot_bf and plot_pr are only defined for some configurations
localrules: all, rename, join, plot_alignment_rate, plot_coverage, plot_bf, plot_pr

#extend TARGETS with the renamed read files and the rename.done marker while renaming is still pending
if "rename" in config and not os.path.exists("rename.done"):

    #plain statements: the original trailing comma turned the extend() call into a discarded tuple
    TARGETS.extend(expand("reads/{new_file}", new_file=NEW_FILES))
    TARGETS.append("rename.done")


#target rule: requesting every file collected in TARGETS drives the whole workflow
rule all:
    input: TARGETS


#rename the raw read files as listed in config["rename"] (old name -> new name);
#the rule only exists while rename.done is absent, so renaming runs at most once
if "rename" in config and not os.path.exists("rename.done"):

    rule rename:
        input:
            expand("reads/{old_file}", old_file=OLD_FILES)
        output:
            f=expand("reads/{new_file}", new_file=NEW_FILES),
            t=touch("rename.done"),
        run:
            #input and output.f follow the key/value order of config["rename"],
            #so pairing them positionally maps each old file onto its new name
            for old, new in zip(input, output.f):
                os.rename(old, new)


#quality control of the trimmed reads of one sample with the FastQC wrapper
rule fastqc:
    input:
        "reads/{sample}_trimmed.fq.gz",
    output:
        html="qc/fastqc/{sample}.html",
        # the _fastqc.zip suffix is required for multiqc to find the archive
        zip="qc/fastqc/{sample}_fastqc.zip",
    log:
        "logs/fastqc/{sample}.log",
    params:
        extra="--quiet",
    threads: config["resources"]["fastqc"]["cpu"]
    resources:
        runtime=config["resources"]["fastqc"]["time"],
    wrapper:
        "v1.31.1/bio/fastqc"


#combine the FastQC archives of all samples into a single MultiQC report
rule multiqc:
    input:
        expand("qc/fastqc/{sample}_fastqc.zip", sample=SAMPLES),
    output:
        report("qc/multiqc.html", caption="report/multiqc.rst", category="MultiQC analysis of fastq files"),
    log:
        "logs/multiqc/multiqc.log",
    params:
        extra="",  # extra command line arguments for multiqc (none)
        use_input_files_only=True,  # wrapper option: analyse only the files listed under input
    resources:
        #NOTE(review): reuses the fastqc time budget - confirm this is intended
        runtime=config["resources"]["fastqc"]["time"],
    wrapper:
        "v1.31.1/bio/multiqc"


#trim the raw reads of one sample with cutadapt: -u removes the first left_trim bases and
#-l shortens what remains to sg_length bases; the trimmed reads are temporary files
rule trim:
    input:
        "reads/{sample}.fq.gz"
    output:
        temp("reads/{sample}_trimmed.fq.gz")
    params:
        sgl=sg_length,
        lt=config["lib_info"][library]["left_trim"],
    threads: config["resources"]["trim"]["cpu"]
    conda:
        "envs/trim.yaml"
    log:
        "logs/trim/{sample}.log",
    resources:
        runtime=config["resources"]["trim"]["time"]
    #reads go to -o, cutadapt prints its trimming report on stdout: redirect the report
    #(and any error messages) into the declared log file, which was previously never written
    shell:
        "cutadapt -j {threads} --quality-base 33 -u {params.lt} -l {params.sgl} -o {output} {input} > {log} 2>&1"


#count reads per sgRNA for one sample:
#  - align the trimmed reads with hisat2 (its stderr, incl. the alignment summary, goes to the log)
#  - drop alignments carrying an XS tag (hisat2 sets it when a read has another alignment)
#  - keep the reference name (SAM column 3) and remove unaligned reads, whose name is "*"
#  - count the occurrences of each name; output lines are "<count> <sgRNA>"
rule count:
    input:
        "reads/{sample}_trimmed.fq.gz"
    output:
        "count/{sample}.guidecounts.txt"
    params:
        idx=index,
        mm=config["mismatch"],
    threads: config["resources"]["count"]["cpu"]
    resources:
        runtime=config["resources"]["count"]["time"]
    log:
        "logs/count/{sample}.log"
    conda:
        "envs/count.yaml"
    #unaligned reads are filtered explicitly; the former `sed '1d'` assumed "*" was always the
    #first sorted line and silently dropped a real sgRNA whenever every read aligned
    shell:
        "zcat {input} | hisat2 --no-hd -p {threads} -t -N {params.mm} -x {params.idx} - 2> {log} | "
        "sed '/XS:/d' | cut -f3 | sed '/^\\*$/d' | sort | uniq -c | sed 's/^ *//' > {output}"


#merge the per-sample guide counts into one table with scripts/join.py (library fasta passed as params.fa)
rule join:
    input:
        files=expand("count/{sample}.guidecounts.txt", sample=SAMPLES),
    params:
        fa=fasta,
    output:
        "count/counts-aggregated.tsv",
    script:
        "scripts/join.py"


#statistical analysis of the aggregated counts with the backend chosen in config["stats"]["type"]
if stats == "mageck":

    #MAGeCK test for one comparison. The comparison wildcard has the form
    #"<treatment>_vs_<control>", where pooled samples are joined with "-" (commas were
    #replaced when COMPARISONS was built) and must be handed to mageck comma-separated.
    rule mageck:
        input:
            "count/counts-aggregated.tsv"
        output:
            "mageck/{comparison}_summary.Rnw",
            report("mageck/{comparison}.gene_summary.txt", caption="report/mageck.rst", category="MAGeCK"),
            "mageck/{comparison}.log",
            "mageck/{comparison}.R",
            "mageck/{comparison}.sgrna_summary.txt",
            "mageck/{comparison}.normalized.txt"
        params:
            #split in Python instead of sed: the old `[^_vs_]*` character class broke on sample
            #names containing "_", "v" or "s", and `s/-/,/` only converted the first "-" of a pool
            treatment=lambda wildcards: wildcards.comparison.split("_vs_")[0].replace("-", ","),
            control=lambda wildcards: wildcards.comparison.split("_vs_")[1].replace("-", ","),
        resources:
            runtime=config["resources"]["stats"]["time"]
        conda:
            "envs/stats.yaml"
        log:
            "logs/mageck/{comparison}.log"
        shell:
            "mageck test --normcounts-to-file -k {input} -t {params.treatment} -c {params.control} "
            "-n mageck/{wildcards.comparison} 2> {log}"


    #MAGeCK Flute plots (volcano, dot plots, sgRNA rank) for one comparison via scripts/flute.R
    rule mageck_flute:
        input:
            "mageck/{comparison}.gene_summary.txt",
            "mageck/{comparison}.sgrna_summary.txt"
        output:
            report("mageck_flute/{comparison}.volcano.pdf", caption="report/volcano.rst", category="MAGeCK Flute", subcategory="{comparison}", labels={"Comparison":"{comparison}","Figure": "volcano plot"}),
            report("mageck_flute/{comparison}.dot_pos.pdf", caption="report/dot-plot_pos.rst", category="MAGeCK Flute", subcategory="{comparison}", labels={"Comparison":"{comparison}","Figure": "dot plot enriched genes"}),
            report("mageck_flute/{comparison}.dot_neg.pdf", caption="report/dot-plot_neg.rst", category="MAGeCK Flute", subcategory="{comparison}", labels={"Comparison":"{comparison}","Figure": "dot plot depleted genes"}),
            report("mageck_flute/{comparison}.sgrank.pdf", caption="report/sgrank.rst", category="MAGeCK Flute", subcategory="{comparison}", labels={"Comparison":"{comparison}","Figure": "sgRNA rank"}),
        params:
            spc=config["lib_info"][library]["species"]
        resources:
            runtime=config["resources"]["stats"]["time"]
        conda:
            "envs/flute.yaml"
        script:
            "scripts/flute.R"


elif stats == "bagel2":

    #convert the aggregated count table into the table used by the BAGEL2 rules
    #(scripts/convert_count_table.py, given the library fasta and the BAGEL2 directory)
    rule convert_count_table:
        input:
            "count/counts-aggregated.tsv"
        output:
            "count/counts-aggregated-bagel2.tsv"
        params:
            fa=fasta,
            b2dir=config["stats"]["bagel2_dir"]
        resources:
            runtime=config["resources"]["stats"]["time"]
        conda:
            "envs/stats.yaml"
        script:
            "scripts/convert_count_table.py"


    #BAGEL2 fold changes for one comparison (scripts/bagel2fc.py)
    rule bagel2fc:
        input:
            "count/counts-aggregated-bagel2.tsv"
        output:
            "bagel2/{comparison}.foldchange"
        params:
            b2dir=config["stats"]["bagel2_dir"]
        resources:
            runtime=config["resources"]["stats"]["time"]
        conda:
            "envs/stats.yaml"
        log:
            "logs/bagel2/fc_{comparison}.log"
        script:
            "scripts/bagel2fc.py"


    #BAGEL2 Bayes factors (.bf) for one comparison from its fold changes (scripts/bagel2bf.py)
    rule bagel2bf:
        input:
            "bagel2/{comparison}.foldchange"
        output:
            "bagel2/{comparison}.bf"
        params:
            b2dir=config["stats"]["bagel2_dir"],
            species=config["lib_info"][library]["species"],
        resources:
            runtime=config["resources"]["stats"]["time"]
        conda:
            "envs/stats.yaml"
        log:
            "logs/bagel2/bf_{comparison}.log"
        script:
            "scripts/bagel2bf.py"


    #BAGEL2 precision-recall table (.pr) for one comparison (scripts/bagel2pr.py)
    rule bagel2pr:
        input:
            "bagel2/{comparison}.bf"
        output:
            report("bagel2/{comparison}.pr", caption="report/bagel2.rst", category="BAGEL2")
        params:
            b2dir=config["stats"]["bagel2_dir"],
            species=config["lib_info"][library]["species"]
        resources:
            runtime=config["resources"]["stats"]["time"]
        conda:
            "envs/stats.yaml"
        log:
            "logs/bagel2/pr_{comparison}.log"
        script:
            "scripts/bagel2pr.py"


    #plot of the .bf file of one comparison (scripts/plot_bf.py)
    rule plot_bf:
        input:
            "bagel2/{comparison}.bf"
        output:
            report("bagel2_plots/{comparison}.bf.pdf", caption="report/bagel2_plots.rst", category="BAGEL2 plots", subcategory="{comparison}", labels={"Comparison":"{comparison}", "Figure":"BF plot"})
        conda:
            "envs/stats.yaml"
        script:
            "scripts/plot_bf.py"


    #precision-recall plot of one comparison (scripts/plot_pr.py)
    rule plot_pr:
        input:
            "bagel2/{comparison}.pr"
        output:
            report("bagel2_plots/{comparison}.pr.pdf", caption="report/bagel2_plots.rst", category="BAGEL2 plots", subcategory="{comparison}", labels={"Comparison":"{comparison}", "Figure":"Precision-recall plot"})
        conda:
            "envs/stats.yaml"
        script:
            "scripts/plot_pr.py"


#plot alignment rates from the per-sample count logs (hisat2 stderr, see rule count) with scripts/plot.py
rule plot_alignment_rate:
    input:
        expand("logs/count/{sample}.log", sample=SAMPLES),
    params:
        name="plot_alignment_rate",
    output:
        report("count/alignment-rates.pdf", caption="report/alignment-rates.rst", category="Alignment rates"),
    script:
        "scripts/plot.py"


#plot sequence coverage from the aggregated count table and the library fasta with scripts/plot.py
rule plot_coverage:
    input:
        "count/counts-aggregated.tsv",
    output:
        report("count/sequence-coverage.pdf", caption="report/plot-coverage.rst", category="Sequence coverage"),
    params:
        name="plot_coverage",
        fa=fasta,
    script:
        "scripts/plot.py"







