from os import listdir
from os.path import isfile

import gzip

def getConditions():
    """Collect the distinct conditions of all configured entries.

    Reads the workflow-level ``config`` mapping and returns a set holding the
    ``['modules']['condition']`` value of every item in ``config['entries']``.
    """
    conditions = set()
    for entry in config['entries'].values():
        conditions.add(entry['modules']['condition'])
    return conditions

def list_of_all_features(gff_path):
    """Return the feature types of a GFF file that are selected for counting.

    The third tab-separated column (the feature type) of every record in
    ``gff_path`` is collected. Paths ending in ``.gz`` are read through gzip.
    Comment lines, blank lines and lines with fewer than three columns are
    skipped, and reading stops at a ``##FASTA`` directive so an embedded
    sequence section is never parsed as annotation records.

    The collected types are then intersected with either the comma-separated
    selection stored in the ``ADDITIONAL_FEATCOUNTS_TABLES`` template
    placeholder or, if that string is empty, a built-in default list.

    Args:
        gff_path: Path to a plain or gzip-compressed GFF file.

    Returns:
        list: The selected feature types found in the file, in arbitrary
        order.
    """
    opener = gzip.open if gff_path.endswith(".gz") else open
    features = set()

    # The context manager closes the handle even if parsing raises.
    with opener(gff_path, "rt") as gff:
        # Iterate lazily; annotation files can be large.
        for line in gff:
            if line.startswith("##FASTA"):
                # Everything after this directive is sequence data, not records.
                break
            if line.startswith("#"):
                continue
            columns = line.strip().split("\t")
            if len(columns) < 3:
                # Blank or malformed line: there is no feature type column.
                continue
            features.add(columns[2])

    # Features to check for either user selected or Curare pre-selected
    if "%%ADDITIONAL_FEATCOUNTS_TABLES%%":
        selected = "%%ADDITIONAL_FEATCOUNTS_TABLES%%".split(",")
    else:
        selected = ["CDS", "exon", "intron", "gene", "mRNA", "tRNA", "rRNA", "ncRNA", "operon", "snRNA", "snoRNA",
            "miRNA", "pseudogene", "small regulatory ncRNA", "rasiRNA", "guide RNA", "siRNA", "stRNA", "sRNA"]

    return list(features.intersection(selected))

def get_main_feature_name():
    use_parent_as_id = %%USE_PARENT_INSTEAD_OF_ID%%
    if use_parent_as_id:
        return "%%GFF_FEATURE_PARENT%%"
    else:
        return "%%GFF_FEATURE_NAME%%"

# Default target: requests every final product of the edgeR DGE module.
rule all:
    input:
        # Main count table and the R data dump
        'analysis/dge_analysis_edgeR/counts.txt',
        'analysis/dge_analysis_edgeR/edgeR.RData',
        # One spreadsheet per condition
        expand("analysis/dge_analysis_edgeR/summary/{COND}.xlsx", COND=getConditions()),
        # One gene body coverage table per selected feature type of the GFF
        expand(
            'analysis/dge_analysis_edgeR/gene_body_coverage/{feature}/{feature}.geneBodyCoverage.txt',
            feature=list_of_all_features("%%GFF_PATH%%")
        ),
        # Report page of this module
        '.report/modules/dge_analysis_edgeR.html'

# Convert one per-condition summary table (TSV) into an .xlsx file by calling
# lib/edgeR_summary_tsv_to_xlsx.py. {COND} is the condition wildcard.
rule summary_tsv_to_xslx:
    input:
        'analysis/dge_analysis_edgeR/summary/{COND}.tsv'
    output:
        'analysis/dge_analysis_edgeR/summary/{COND}.xlsx'
    conda:
        "../lib/conda_env.yaml"
    params:
        # Number of distinct conditions minus one; handed to the script as --conditions.
        number_conditions = len(getConditions())-1
    group:
        "dge_analysis_edgeR"
    shell:
        # GFF path, identifier attribute and attribute columns are %%...%% template placeholders.
        "python3 lib/edgeR_summary_tsv_to_xlsx.py --tsv \"{input}\" --conditions \"{params.number_conditions}\" --gff '%%GFF_PATH%%' --identifier \"%%GFF_FEATURE_NAME%%\" "
        "--attributes \"%%ATTRIBUTE_COLUMNS%%\" --output \"{output}\""

# Run featureCounts over all mapped BAM files for one GFF feature type
# ({feature} wildcard) and write a count table for that type.
rule make_count_tables:
    input:
        mappings=expand("mapping/{A}.bam", A=config['entry_order'])
    output:
        featcounts="analysis/dge_analysis_edgeR/count_tables/{feature}.txt",
        # Presumably written by featureCounts next to the -o target; it is not passed on the command line.
        featcounts_summary="analysis/dge_analysis_edgeR/count_tables/{feature}.txt.summary"
    conda:
        "../lib/conda_env.yaml"
    group:
        "count_reads"
    threads:
        1
    log:
        log="analysis/dge_analysis_edgeR/logs/count_tables/{feature}.log"
    shell:
        # The feature type, the GFF path and the log path are quoted because feature
        # types such as "small regulatory ncRNA" contain spaces and would otherwise
        # be split into several shell words.
        "featureCounts -T {threads} %%STRAND_SPECIFICITY%% %%ADDITIONAL_FEATCOUNTS_OPTIONS%% -t '{wildcards.feature}' -g '%%GFF_FEATURE_NAME%%' -a '%%GFF_PATH%%' -o '{output.featcounts}' {input.mappings} 2>&1 | "
        "tee '{log.log}';"

# Write the paths of all per-feature count tables, one per line, into a marker
# file so that downstream rules can depend on the complete set of tables.
rule collect_count_tables:
    input:
        count_tables=expand("analysis/dge_analysis_edgeR/count_tables/{feature}.txt", feature = list_of_all_features("%%GFF_PATH%%"))
    output:
        "analysis/dge_analysis_edgeR/count_tables/.count_tables"
    log:
        log="analysis/dge_analysis_edgeR/logs/count_tables/collect_count_tables.log"
    group:
        "count_reads"
    shell:
        # ':q' shell-quotes every path, so feature types containing spaces stay
        # intact. A single redirect always creates the marker file, even when no
        # feature type was selected (it then holds one empty line).
        "printf '%s\\n' {input.count_tables:q} > {output:q}"

# Run lib/visualize_assignments.py on the count table directory once the
# marker file listing all count tables exists.
rule visualize_assignments:
    input:
        'analysis/dge_analysis_edgeR/count_tables/.count_tables'
    output:
        directory("analysis/dge_analysis_edgeR/visualization/feature_assignments")
    conda:
        '../lib/conda_env.yaml'
    params:
        # Directory handed to the script; the marker file above only orders the jobs.
        input='analysis/dge_analysis_edgeR/count_tables/'
    shell:
        "mkdir -p {output}; "
        "python3 lib/visualize_assignments.py -i {params.input} -o {output}"

# Build the main count table: featureCounts over all mapped BAM files for the
# configured main feature type, grouped by the attribute chosen in
# get_main_feature_name().
rule count_reads:
    input:
        mappings=expand("mapping/{A}.bam", A=config['entry_order'])
    output:
        counts = "analysis/dge_analysis_edgeR/counts.txt",
        # Presumably written by featureCounts next to the -o target; it is not passed on the command line.
        summary = "analysis/dge_analysis_edgeR/counts.txt.summary"
    params:
        feature_name = get_main_feature_name()
    conda:
        "../lib/conda_env.yaml"
    group:
        "count_reads"
    threads:
        1
    log:
        log="analysis/dge_analysis_edgeR/logs/count_tables/counts.log"
    shell:
        # Feature type, grouping attribute, GFF path and log path are quoted, as in
        # make_count_tables, so values containing spaces stay single shell words.
        "featureCounts -T {threads} %%STRAND_SPECIFICITY%% -t '%%GFF_FEATURE_TYPE%%' -g '{params.feature_name}' -a '%%GFF_PATH%%' %%ADDITIONAL_FEATCOUNTS_OPTIONS%% -o '{output.counts}' {input.mappings} 2>&1 | "
        "tee '{log.log}';"

# Write a two-column, tab-separated table (entry name, condition) with one
# line per entry, following the order given in config['entry_order'].
rule create_conditions:
    output:
        "analysis/dge_analysis_edgeR/conditions.txt"
    run:
        with open(output[0], 'w') as table:
            table.writelines(
                '{}\t{}\n'.format(sample, config['entries'][sample]['modules']['condition'])
                for sample in config['entry_order']
            )

# Run lib/edgeR_analysis_normalize_counts.R on the main count table and the
# conditions table. Only the R data dump is passed explicitly; the remaining
# declared outputs are presumably written by the script below params.output_dir.
rule edgeR_normalize_counts:
    input:
        count_table = "analysis/dge_analysis_edgeR/counts.txt",
        conditions = "analysis/dge_analysis_edgeR/conditions.txt",
        feature_counts_log = "analysis/dge_analysis_edgeR/counts.txt.summary",
    output:
        dump="analysis/dge_analysis_edgeR/edgeR.RData",
        correlation_heatmap="analysis/dge_analysis_edgeR/visualization/correlation_heatmap.svg",
        mds="analysis/dge_analysis_edgeR/visualization/mds.svg",
        img_assignment_rel="analysis/dge_analysis_edgeR/visualization/counts_assignment_relative.svg",
        img_assignment_abs="analysis/dge_analysis_edgeR/visualization/counts_assignment_absolute.svg",
        counts_rpkm="analysis/dge_analysis_edgeR/counts_rpkm.txt",
        counts_cpm="analysis/dge_analysis_edgeR/counts_cpm.txt",
        # One summary table per condition
        summary=expand('analysis/dge_analysis_edgeR/summary/{COND}.tsv', COND=getConditions()),
        comparisons=directory("analysis/dge_analysis_edgeR/edgeR_comparisons")
    params:
        output_dir='analysis/dge_analysis_edgeR/'
    conda:
        "../lib/conda_env.yaml"
    group:
        "dge_analysis_edgeR"
    log:
        log="analysis/dge_analysis_edgeR/logs/edgeR_normalize_counts.log"
    threads:
        1
    shell:
        # --featcounts-log is passed exactly once; stdout and stderr are appended to the log via tee -a.
        "R --vanilla --file=lib/edgeR_analysis_normalize_counts.R --args --threads {threads} --count-table {input.count_table} --conditions {input.conditions} --featcounts-log {input.feature_counts_log} --output {params.output_dir} --r-data {output.dump} "
        "2>&1 | tee -a {log.log}"

# Symlink a mapped BAM file into the gene body coverage directory and index
# the link with samtools. Both outputs are marked temp().
rule index_bam:
    input:
        bam='mapping/{bam}.bam'
    output:
        bam=temp('analysis/dge_analysis_edgeR/gene_body_coverage/{bam}.bam'),
        bai=temp('analysis/dge_analysis_edgeR/gene_body_coverage/{bam}.bam.bai')
    conda:
        '../lib/conda_env.yaml'
    shell:
        # The link target is relative to the link's directory, three levels below the working directory.
        "ln -s ../../../{input.bam} {output.bam} "
        "&& samtools index {output.bam}"

# Convert the records of one feature type ({feature_type} wildcard) of the GFF
# annotation into a BED file by calling lib/gff_to_rseqc_bed.py.
rule convert_gff_to_bed:
    input:
        reference = "%%GFF_PATH%%"
    output:
        bed = "analysis/dge_analysis_edgeR/gene_body_coverage/ref_{feature_type}.bed"
    params:
        # Pass -r to the converter only when the strand setting placeholder is "-s 2".
        reverse_strand = lambda wildcards: "-r" if "%%STRAND_SPECIFICITY%%" == "-s 2" else ""
    conda:
        "../lib/conda_env.yaml"
    shell:
        # Paths and the feature type are quoted because feature types such as
        # "guide RNA" contain spaces and appear in the output file name.
        "python3 lib/gff_to_rseqc_bed.py {params.reverse_strand} --gff '{input.reference}' --type '{wildcards.feature_type}' > '{output.bed}'"

# Run geneBody_coverage.py for one feature type over all linked and indexed
# BAM files, and leave a WARNING.txt next to the results when the strand
# setting placeholder is "-s 2".
rule gene_body_coverage:
    input:
        ref="analysis/dge_analysis_edgeR/gene_body_coverage/ref_{feature_type}.bed",
        bam_files=expand("analysis/dge_analysis_edgeR/gene_body_coverage/{name}.bam", name=config['entries'].keys()),
        # Not used on the command line; listed so the index files exist before the job starts.
        bam_index=expand("analysis/dge_analysis_edgeR/gene_body_coverage/{name}.bam.bai", name=config['entries'].keys())
    output:
        "analysis/dge_analysis_edgeR/gene_body_coverage/{feature_type}/{feature_type}.geneBodyCoverage.txt"
    params:
        # All BAM paths joined by commas into a single -i argument.
        in_bam=lambda wildcards, input: ",".join(input.bam_files),
        # Output prefix passed to -o.
        out_dir="analysis/dge_analysis_edgeR/gene_body_coverage/{feature_type}/{feature_type}",
        # "0" when the strand setting placeholder is "-s 2", otherwise "1".
        reverse_strand = lambda wildcards: "0" if "%%STRAND_SPECIFICITY%%" == "-s 2" else "1",
        warning="analysis/dge_analysis_edgeR/gene_body_coverage/{feature_type}/WARNING.txt"
    conda:
        "../lib/conda_env.yaml"
    shell:
        # Paths are quoted because the feature type is part of every path and may contain spaces.
        "geneBody_coverage.py -r '{input.ref}' -i '{params.in_bam}' -o '{params.out_dir}';"
        "if [ {params.reverse_strand} -eq 0 ]; then echo -e \"Warning\nThis module was run with \\\"reversely stranded\\\" settings. The plots must be interpreted as 3' to 5' and not as labeled!\" > '{params.warning}'; fi"

# Assemble the report files of this module: run lib/generate_report_data.py to
# write the data file, then copy the HTML/JS files from lib/report and the
# plots from the analysis directory into .report/.
rule generate_report_data:
    input:
        stats="analysis/dge_analysis_edgeR/counts.txt.summary",
        comparisons="analysis/dge_analysis_edgeR/edgeR_comparisons",
        img_assignment_rel="analysis/dge_analysis_edgeR/visualization/counts_assignment_relative.svg",
        img_assignment_abs="analysis/dge_analysis_edgeR/visualization/counts_assignment_absolute.svg",
        img_mds="analysis/dge_analysis_edgeR/visualization/mds.svg",
        correlation_heatmap="analysis/dge_analysis_edgeR/visualization/correlation_heatmap.svg",
        # Directory produced by rule visualize_assignments
        feature_assignment="analysis/dge_analysis_edgeR/visualization/feature_assignments",
        count_table="analysis/dge_analysis_edgeR/counts.txt"
    output:
        dge_analysis_data=".report/data/dge_analysis_edgeR_data.js",
        dge_analysis_html=".report/modules/dge_analysis_edgeR.html",
        dge_analysis_js=".report/js/modules/dge_analysis_edgeR.js",
        dge_analysis_img_assignment_rel=".report/img/modules/dge_analysis_edgeR/counts_assignment_relative.svg",
        dge_analysis_img_assignment_abs=".report/img/modules/dge_analysis_edgeR/counts_assignment_absolute.svg",
        dge_analysis_img_mds='.report/img/modules/dge_analysis_edgeR/mds.svg',
        dge_analysis_img_correlation='.report/img/modules/dge_analysis_edgeR/correlation_heatmap.svg',
        dge_analysis_img_feature_assignment=directory('.report/img/modules/dge_analysis_edgeR/feature_assignment')
    params:
        # Directory handed to the script via --visualization
        visualization="analysis/dge_analysis_edgeR/visualization"
    conda:
        "../lib/conda_env.yaml"
    group:
        "dge_analysis_edgeR_report"
    shell:
        # Steps are chained with '&&', so the first failing command aborts the rest.
        "python3 lib/generate_report_data.py --fc_stats {input.stats} --fc_main_feature '%%GFF_FEATURE_TYPE%%' --comparison_dir {input.comparisons} --visualization {params.visualization} --output {output.dge_analysis_data} --counttable '{input.count_table}' && "
        "cp lib/report/dge_analysis_edgeR.html {output.dge_analysis_html} && "
        "cp lib/report/dge_analysis_edgeR.js {output.dge_analysis_js} && "
        "cp {input.img_assignment_rel} {output.dge_analysis_img_assignment_rel} && "
        "cp {input.img_assignment_abs} {output.dge_analysis_img_assignment_abs} && "
        "cp {input.img_mds} {output.dge_analysis_img_mds} && "
        "cp {input.correlation_heatmap} {output.dge_analysis_img_correlation} && "
        # Copies every SVG from the feature assignment directory into the report image directory.
        "mkdir -p {output.dge_analysis_img_feature_assignment} && cp {input.feature_assignment}/*.svg {output.dge_analysis_img_feature_assignment}"
