from pathlib import Path

import pandas as pd
from snakemake.utils import min_version

# Configuration

# Abort early when the installed Snakemake is older than 5.1.2.
min_version("5.1.2")

configfile: "config/config.yaml"

# Tab-separated sample sheet, indexed by its "Sample" column (which is also
# kept as a regular column because drop=False). Every column is read as text
# so that purely numeric sample names (e.g. "42" or "001") keep their exact
# spelling and can be looked up with the string-valued `wildcards.sample`.
sample_files = pd.read_csv(config["samples"], sep="\t", dtype=str).set_index("Sample", drop=False)
# Sample names, in sheet order, used to expand the per-sample targets.
samples = sample_files.index.tolist()
# Absolute path to the configured reference, so the symlink created by
# `prepare_reference` inside "reference/" still resolves.
reference_abs = Path(config["reference"]).absolute()

# Workflow

rule all:
    """
    Default target.

    Requests the GDI input file list and, only when the `include_mlst`
    configuration value is truthy, the combined MLST table as well.
    """
    input:
        "gdi-input.fofn",
        [] if not config["include_mlst"] else "mlst.tsv"


rule prepare_reference:
    """
    Make the configured reference available inside the working directory.

    Symlinks the reference to `reference/reference.fasta`, runs
    `samtools faidx` on the link, and builds a minimap2 index (`.mmi`)
    next to it. One log file is written per indexing tool.
    """
    input:
        reference=reference_abs,
    output:
        reference="reference/reference.fasta",
        reference_mm2_index="reference/reference.fasta.mmi"
    conda:
        "envs/main.yaml"
    threads: 1
    log:
        faidx="logs/prepare_reference.faidx.log",
        mm2_index="logs/prepare_reference.mm2_index.log"
    # Redirection order matters: stdout is pointed at the log first and stderr
    # is duplicated onto it afterwards ("> file 2>&1"). The reverse order
    # ("2>&1 1>file") would leave stderr on the terminal and out of the log.
    shell:
        "ln -s {input.reference} {output.reference} && "
        "samtools faidx {output.reference} > {log.faidx} 2>&1 && "
        "minimap2 -t {threads} -d {output.reference_mm2_index} {output.reference} > {log.mm2_index} 2>&1"


rule assembly_align:
    """
    Align one sample's assembly against the reference.

    The assembly path comes from the "File" column of the sample sheet row
    for `{sample}`. minimap2 (preset `asm5`, SAM output) is piped into
    `samtools sort`, which writes the sorted BAM. stderr of each tool goes
    to its own log file.
    """
    input:
        reference_mm2_index="reference/reference.fasta.mmi",
        sample=lambda wc: sample_files.loc[wc.sample, "File"],
    output:
        "align/{sample}.bam",
    log:
        mm2="logs/assembly_align.{sample}.minimap2.log",
        samsort="logs/assembly_align.{sample}.samsort.log",
    threads: 1
    conda:
        "envs/main.yaml"
    # NOTE(review): `--write-index` presumably makes samtools emit an index
    # file next to the BAM; it is not declared as an output here -- confirm.
    shell:
        "minimap2 -t {threads} -a -x asm5 "
        "{input.reference_mm2_index} {input.sample} "
        "2> {log.mm2} | "
        "samtools sort --threads {threads} "
        "--output-fmt BAM --write-index "
        "-o {output} 2> {log.samsort}"


rule assembly_variant:
    """
    Call variants for one sample from its alignment.

    `bcftools mpileup` (uncompressed BCF on stdout) is piped into
    `bcftools call` (ploidy 1, multiallelic caller, variant sites only),
    whose compressed VCF on stdout becomes `variant/{sample}.vcf.gz`.
    stderr of each stage is captured in its own log file.
    """
    input:
        reference="reference/reference.fasta",
        bam="align/{sample}.bam",
    output:
        "variant/{sample}.vcf.gz",
    conda:
        "envs/main.yaml"
    threads: 1
    log:
        mpileup="logs/assembly_variant.{sample}.mpileup.log",
        call="logs/assembly_variant.{sample}.call.log",
    # The `call` stage writes its diagnostics to {log.call}; without that
    # redirection the declared log would never be created and any
    # `bcftools call` error output would be lost.
    shell:
        "bcftools mpileup --threads {threads} -Ou -f {input.reference} {input.bam} 2> {log.mpileup} | "
        "bcftools call --threads {threads} --ploidy 1 -Oz -mv 1> {output} 2> {log.call}"


rule assembly_consensus:
    """
    Produce a gzip-compressed consensus FASTA for one sample.

    Runs `htsbox pileup` on the sample's BAM against the reference and
    compresses its stdout with gzip into `consensus/{sample}.fasta.gz`.
    Only htsbox's stderr is logged.
    """
    input:
        reference="reference/reference.fasta",
        bam="align/{sample}.bam",
    output:
        "consensus/{sample}.fasta.gz",
    log:
        consensus="logs/assembly_consensus.{sample}.consensus.log",
    conda:
        "envs/main.yaml"
    shell:
        "htsbox pileup -f {input.reference} -d -F {input.bam} "
        "2> {log.consensus} | gzip -c - > {output}"


rule sourmash_sketch:
    """
    Build a gzip-compressed sourmash DNA sketch for one sample.

    The sketch is named after the sample, written to a temporary
    `<output>.tmp` file with the parameter string taken from the
    `sourmash_params` configuration value, gzipped, and finally moved onto
    the declared output path.
    """
    input:
        sample=lambda wc: sample_files.loc[wc.sample, "File"],
    output:
        "sketch/{sample}.sig.gz"
    params:
        sourmash_params=config["sourmash_params"],
    log:
        sketch="logs/sourmash_sketch.{sample}.sketch.log"
    threads: 1
    conda:
        "envs/main.yaml"
    shell:
        "sourmash sketch dna -p {params.sourmash_params} "
        "--name {wildcards.sample} --output {output}.tmp {input.sample} "
        "2>{log.sketch} && "
        "gzip {output}.tmp && "
        "mv {output}.tmp.gz {output}"


rule basic_mlst:
    """
    Run `mlst` on one sample's assembly.

    The tool's stdout becomes `mlst/{sample}.tsv`; its stderr goes to the
    rule's log file. Runs in the dedicated `envs/mlst.yaml` environment.
    """
    input:
        sample=lambda wc: sample_files.loc[wc.sample, "File"],
    output:
        "mlst/{sample}.tsv"
    log:
        mlst="logs/basic_mlst.{sample}.log"
    threads: 1
    conda:
        "envs/mlst.yaml"
    shell:
        "mlst --threads {threads} --nopath {input.sample} "
        "> {output} 2> {log.mlst}"


rule gdi_input_fofn:
    """
    Create `gdi-input.fofn` via `scripts/gdi-create-fofn.py`.

    Depends on every sample's VCF and consensus FASTA, and on every sample's
    sketch only when the `include_kmer` configuration value is truthy.
    """
    input:
        sample_vcfs=expand("variant/{sample}.vcf.gz", sample=samples),
        sample_consensus=expand("consensus/{sample}.fasta.gz", sample=samples),
        sample_sketches=[] if not config["include_kmer"] else expand("sketch/{sample}.sig.gz", sample=samples)
    output:
        "gdi-input.fofn"
    log:
        "logs/gdi_input_fofn.log"
    script:
        "scripts/gdi-create-fofn.py"


rule mlst_full_table:
    """
    Concatenate all per-sample MLST tables into a single `mlst.tsv`.

    Files are joined in the order of the sample list; error output from
    `cat` (e.g. an unreadable input) is written to the rule's log.
    """
    # NOTE(review): the input is named `sample_vcfs` although it holds the
    # per-sample MLST tables; the name is kept so that any external reference
    # to `rules.mlst_full_table.input.sample_vcfs` keeps working.
    input:
        sample_vcfs=expand("mlst/{sample}.tsv",sample=samples),
    log:
        "logs/mlst_full_table.log"
    output:
        "mlst.tsv"
    shell:
        "cat {input} > {output} 2> {log}"
