from collections import namedtuple, defaultdict
from configparser import ConfigParser
from snakemake.utils import R
import glob
import os
import numpy as np
import pysam

configfile: "config.json"

# Settings taken from the workflow configuration file.
genome = config["genome"]
ref = config["ref"]
dbsnp = config["dbsnp"]
track_path = config["track_path"]
raw = config["raw_path"]

# Lookup tables filled while reading units.txt; all are keyed by sample name
# except SAMPLES_OF_GROUP, which maps a group name to its member samples.
SAMPLES = set()
FILES_R1 = {}
FILES_R2 = {}
DISEASES = {}
TRACKS = {}
GROUP = {}
SAMPLES_OF_GROUP = defaultdict(list)

# Read the sample sheet: one tab-separated record per line with the columns
# sample name, disease, R1 glob, R2 glob, track and group.
# The file is opened in a `with` block so the handle is closed even when a
# malformed line aborts parsing.
with open("units.txt") as units:
    for line in units:
        if line.startswith("#"):  # skip comments
            continue

        if len(line) < 2:  # skip empty lines
            continue

        try:
            samplename, disease, r1, r2, track, group = line.strip("\n").split("\t")
        except ValueError:
            # Wrong number of columns: show the offending line, then fail.
            print(line)
            raise

        # A sample without an explicit group gets a private group of its own.
        if group == "":
            group = "_" + samplename

        SAMPLES.add(samplename)
        # The fastq columns are glob patterns relative to the raw data path.
        FILES_R1[samplename] = sorted(glob.glob(os.path.join(raw, r1)))
        FILES_R2[samplename] = sorted(glob.glob(os.path.join(raw, r2)))
        DISEASES[samplename] = disease
        TRACKS[samplename] = track
        GROUP[samplename] = group
        SAMPLES_OF_GROUP[group].append(samplename)

# Default target rule: for every sample it requests the meta text file, the
# BAM with its index and the final snps H5 file, plus one VCF per group.
# The commented-out expand() lines are additional targets that are currently
# disabled.
rule all:
    input:
#        expand("results/metrices/{S}.meancoverage.txt", S=SAMPLES),
#        expand("results/metrices/{S}.shortcoveragehistogramm.txt", S=SAMPLES),
#        expand("results/metrices/{S}.capturekitbasecount.txt", S=SAMPLES),
        expand("results/meta/{S}.txt", S=SAMPLES),
        expand("bam/{S}.bam", S=SAMPLES),
        expand("bam/{S}.bam.bai", S=SAMPLES),

        expand("vcf/{G}.vcf", G=GROUP.values()),
        expand("snps/{S}.h5", S=SAMPLES),
#        expand("meta/{S}.coveragehist.txt", S=SAMPLES),


# Runs scripts/meta.py on the merged H5 file, passing the sample's meta text
# file via --storelist, and then moves the merged file to snps/{S}.h5.
# NOTE(review): the script only receives the input path and that file is moved
# afterwards, so tmp/merged_h5/{S}.h5 no longer exists once this rule has run.
# Presumably meta.py edits the H5 file in place — confirm.
rule AnnotateMeta:
    input:
        sample="tmp/merged_h5/{S}.h5",
        meta="results/meta/{S}.txt",
    output:
        "snps/{S}.h5"
    shell:
        '''
        python3 scripts/meta.py {input.sample} --storelist {input.meta}
        mv {input.sample} {output}
        '''


# Calls scripts/merge_h5.py with the sample's SNP H5 file, its Delly H5 file
# and the path of the merged output file, in that order.
rule MergeH5:
    input:
        snps="tmp/snps/{S}.h5",
        delly="tmp/delly_h5/{S}.h5",
    output:
        "tmp/merged_h5/{S}.h5",
    shell:
        "python3 scripts/merge_h5.py {input.snps} {input.delly} {output}"


# Collects the per-sample metrics files into a single tab-separated key/value
# text file (results/meta/{S}.txt).
rule BuildMeta:
    input:
        coveragehistogramm="results/metrices/{S}.capturekitmeancoverage.txt",
        chromosome_stats="results/metrices/{S}.chromosomestats.csv",
        overall_basecount="results/metrices/{S}.overallbasecount.txt",
        capturekit_basecount="results/metrices/{S}.capturekitbasecount.txt",
        capturekit_shorthistogramm = "results/metrices/{S}.shortcoveragehistogramm.txt",
    output:
        meta="results/meta/{S}.txt"
    params:
        disease = lambda wc: DISEASES[wc.S],
        genome = genome
    run:
        def read_first_line(path):
            # The single-value metrics files hold their number on line one.
            with open(path) as handle:
                return handle.readline().strip()

        def read_rows(path):
            # Whitespace-split rows; blank lines and '#' header lines are skipped.
            with open(path) as handle:
                return [line.split() for line in handle
                        if line.strip() and not line.startswith("#")]

        mean_coverage = float(read_first_line(input.coveragehistogramm))
        # NOTE(review): column index 2 is reported as the read count, but the
        # OverallBaseCount rule treats that same column ($3 in awk) as the
        # region end. Confirm which column holds the read count.
        chrom_readcounts = [(row[0], row[2]) for row in read_rows(input.chromosome_stats)]
        # The counts may be written in float notation, hence int(float(...)).
        overall_basecount = int(float(read_first_line(input.overall_basecount)))
        capturekit_basecount = int(float(read_first_line(input.capturekit_basecount)))
        # Guard against division by zero when nothing was mapped.
        on_target_rate = capturekit_basecount / overall_basecount if overall_basecount else 0
        # Rows of (coverage threshold, value) written by CaptureKitShortHistogramm.
        capturekit_shorthistogramm = read_rows(input.capturekit_shorthistogramm)

        with open(output.meta, "w") as o:
            print("exome_sequencing_capture_kit_coverage", mean_coverage, sep="\t", file=o)
            print("exome_sequencing_ontargetrate", on_target_rate, sep="\t", file=o)
            print("exome_sequencing_mappedbases", overall_basecount, sep="\t", file=o)
            print("disease", params.disease, sep="\t", file=o)
            print("genome", params.genome, sep="\t", file=o)
            for chrom, count in chrom_readcounts:
                print("{chrom}/exome_sequencing_readcount".format(chrom=chrom.upper()), count, sep="\t", file=o)
            # Fixed: this loop used to iterate chrom_readcounts a second time,
            # emitting chromosome names instead of the coverage thresholds.
            for x, fraction in capturekit_shorthistogramm:
                print("exome_coverage_greater_{x}".format(x=x), fraction, sep="\t", file=o)


# Condenses the coverage histogram into one line per coverage threshold:
# the threshold and the share of histogram counts at or above it.
# An empty histogram file yields a 0 for every threshold.
rule CaptureKitShortHistogramm:
    input:
        "results/metrices/{S}.coveragehistogramm.txt"
    output:
        "results/metrices/{S}.shortcoveragehistogramm.txt"
    run:
        thresholds = (1, 5, 10, 20, 50, 100)
        # Start from an empty output file; every threshold appends one line.
        shell('> {output}')
        histogram_is_empty = os.stat(input[0]).st_size == 0
        # The loop variable must stay named `i`: the shell templates use {i}.
        for i in thresholds:
            if histogram_is_empty:
                shell("echo '{i}\t0' >> {output}")
            else:
                shell("""cat {input} | awk 'BEGIN {{OFS = "\t"}} {{total+=$2; if ($1 >= {i}){{s += $2}}}} END {{print {i}, s / total}}' >> {output}""")


# Computes the mean of column 1 of the coverage histogram weighted by
# column 2 (sum of $1*$2 divided by sum of $2) and writes that single number.
# NOTE(review): the declared log file is not referenced by the command, and an
# input without any lines leaves `total` at 0 for the division.
rule CaptureKitMeanCoverage:
    input:
        "results/metrices/{S}.coveragehistogramm.txt"
    output:
        "results/metrices/{S}.capturekitmeancoverage.txt"
    log:
        "log/metrices/{S}.capturekitmeancoverage.log"
    shell:
        """cat {input} | awk '{{total+=$2; mult+=$1*$2}} END {{print mult / total}}' > {output}"""


# Builds a two-column histogram (value, number of occurrences) from the third
# column of the gzipped per-base coverage table, skipping its first line.
# A zero-byte input produces the single line "0<TAB>1" instead.
rule CoverageHistogramm:
    input:
        "results/metrices/{S}.capturekitbasestats.cov.gz"
    output:
        "results/metrices/{S}.coveragehistogramm.txt"
    log:
        "log/metrices/{S}.coveragehistogramm.log"
    run:
        coverage_file_is_empty = os.stat(input[0]).st_size == 0
        if coverage_file_is_empty:
            # Nothing to count: emit a placeholder histogram.
            shell("echo -e '0\t1' > {output}")
        else:
            # uniq -c prints "count value"; awk swaps that to "value<TAB>count".
            shell("""zcat {input} | cut -f 3 | tail -n +2 | sort -n | uniq -c | awk 'BEGIN {{OFS = "\t"}} {{print $2, $1}}'> {output}""")


# Sums ($3 - $2) * $5 over all lines of the chromosome stats file and writes
# the total as a single number.
# NOTE(review): presumably columns 2/3 are region start/end and column 5 is the
# mean coverage of the sambamba output — confirm against the sambamba format.
# The declared log file is not referenced by the command.
rule OverallBaseCount:
    input:
        'results/metrices/{S}.chromosomestats.csv'
    output:
        'results/metrices/{S}.overallbasecount.txt'
    log:
        'log/metrices/{S}.metrices.overallbasecount.log'
    shell:
        '''cat {input} | awk '{{count += ($3-$2)*$5}} END {{print count}}' > {output}'''


# Sums column 1 times column 2 over all lines of the coverage histogram
# (value times number of occurrences) and writes the total as a single number.
# The declared log file is not referenced by the command.
rule CaptureKitBaseCount:
    input:
         "results/metrices/{S}.coveragehistogramm.txt"
    output:
         "results/metrices/{S}.capturekitbasecount.txt"
    log:
        'log/metrices/{S}.metrices.capturekitbasecount.log'
    shell:
        '''cat {input} | awk '{{total+=$1*$2}} END {{print total}}' > {output}'''


def _capturekit_track_input(wc):
    # Samples whose track is 'none' have no track file to depend on.
    track_name = TRACKS[wc.S]
    if track_name == 'none':
        return []
    return os.path.join(track_path, track_name)


# Runs `sambamba depth base` restricted to the sample's track file and stores
# the compressed result. Without a track, a zero-byte file is written instead.
rule CaptureKitBaseStats:
    threads: 2
    input:
        bam="bam/{S}.bam",
        track=_capturekit_track_input
    output:
        "results/metrices/{S}.capturekitbasestats.cov.gz"
    log:
        "log/metrices/{S}.capturekitbasestats.txt"
    run:
        if not input.track:
            shell("> {output}") #zero byte file
        else:
            shell('sambamba depth base -t {threads} -c 0 -m -L {input.track} {input.bam} 2> {log} | pigz -c > {output}')


# Runs `sambamba depth region` on the sample's BAM for the regions listed in
# the per-sample region file; stdout goes to the stats file, stderr to the log.
# The BAM index is declared as input so it exists before the command runs.
rule ChromosomeStats:
    threads: 2
    input:
        bam="bam/{S}.bam",
        bai="bam/{S}.bam.bai",
        region="tmp/regions/{S}.regions.txt"
    output:
        'results/metrices/{S}.chromosomestats.csv',
    log:
        'log/metrices/{S}.allcoverage.log'
    shell:
        "sambamba depth region -L {input.region} -m -t {threads} {input.bam} > {output} 2> {log}"


# Writes a temporary tab-separated region file from the @SQ lines of the BAM
# header: field 2 from its 4th character on, a literal 0, and field 3 from its
# 4th character on (each capped at 100 characters).
# NOTE(review): the substr offsets presumably strip the "SN:" and "LN:" tag
# prefixes, giving (name, 0, length) — confirm against the header layout.
rule CreateChromosomeRegionFile:
    input:
        "bam/{S}.bam",
    output:
        temp("tmp/regions/{S}.regions.txt")
    shell:
        '''sambamba view -H {input} | grep @SQ | awk 'BEGIN{{OFS="\\t"}}{{print substr($2,4,100), 0, substr($3, 4, 100)}}' > {output}'''


# Counts (-c) the reads of the sample's BAM that match the sambamba filter
# expression "duplicate" and writes the count; stderr goes to the log.
rule DuplicateReadCount:
    threads: 10
    input:
        'bam/{S}.bam'
    output:
        'results/metrices/{S}.duplicatecount.txt'
    log:
        'log/metrices/{S}.duplicatecount.log'
    shell:
        'sambamba view -c -F "duplicate" -t {threads} {input} > {output} 2> {log}'


# Converts the VCF of the sample's group (looked up through GROUP) into a
# per-sample H5 file by calling scripts/convert_vcf_2.py with --samples set to
# the sample name; the script's stdout is written to the log.
rule ConvertToH5:
    input:
        lambda wc: "vcf/{G}.vcf".format(G=GROUP[wc.S])
    output:
        "tmp/snps/{S}.h5"
    log:
        "log/{S}.log"
    shell:
        "python3 scripts/convert_vcf_2.py {input} --samples {wildcards.S} --output {output} > {log}"


# Pipes the group's raw VCF through scripts/annotatecontext.py (given the VCF
# and the reference) and then through `bcftools norm -f <reference>`.
rule NormVCFandContext:
    input:
        vcf="tmp/vcf/{G}.vcf",
        ref=ref
    output:
        "tmp/norm_vcf/{G}.vcf"
    shell:
        "python3 scripts/annotatecontext.py {input.vcf} {input.ref} | bcftools norm -f {input.ref} - > {output}"


# Drops every line containing "hs37d5" from the normalised VCF, annotates the
# rest with snpeff for the configured genome, then with `snpsift annotate`
# using the dbsnp file.
# NOTE(review): {dbsnp} is filled from the file-level variable and is not
# declared as an input or param; {threads} is declared but not referenced.
rule SnpEffandSnpSift:
    threads: 6
    input:
        "tmp/norm_vcf/{G}.vcf"
    output:
        "vcf/{G}.vcf"
    params:
        genome = genome
    shell:
        "cat {input} | grep -v hs37d5 | snpeff -noStats -t {params.genome} - | snpsift annotate -tabix {dbsnp} - > {output}"


# Converts the annotated Delly VCFs of the sample's group — one each for the
# types INS, DEL and INV — into a per-sample H5 file by calling
# scripts/convert_vcf_delly.py with the VCF paths, the sample name and the
# output path.
rule ConvertToH5_Delly:
    input:
        lambda wc: expand("tmp/vcf_delly_annotated/{G}.{T}.vcf", G=GROUP[wc.S], T=["INS", "DEL", "INV"])
    output:
        "tmp/delly_h5/{S}.h5"
    shell:
        "python3 scripts/convert_vcf_delly.py {input} {wildcards.S} {output}"


# Drops every line containing "hs37d5" from a Delly VCF and annotates the rest
# with the snpEff jar at a hard-coded path (Java heap limit 100g); stderr goes
# to the log.
# NOTE(review): the consumers request "{G}.{T}.vcf", so the {G} wildcard here
# presumably matches "<group>.<type>" — confirm. {threads} is declared but not
# referenced by the command.
rule SnpEff_Delly:
    threads: 32
    input:
        "tmp/vcf_delly/{G}.vcf"
    output:
        "tmp/vcf_delly_annotated/{G}.vcf"
    log:
        "log/snpeff/{G}.snpeff"
    params:
        genome = genome
    shell:
        "cat {input} | grep -v hs37d5 | java -Xmx100g -jar /vol/tools/lib/snpEff-4.2/snpEff.jar -noStats -t {params.genome} - > {output} 2> {log}"


# Calls structural variants of type {T} with Delly on all BAM files of group
# {G}. The BAM indexes are declared as inputs so they exist before Delly runs.
rule IndelCall:
    threads: 3
    input:
        bam=lambda wc: expand("bam/{S}.bam", S=SAMPLES_OF_GROUP[wc.G]),
        bai=lambda wc: expand("bam/{S}.bam.bai", S=SAMPLES_OF_GROUP[wc.G])
    output:
        "tmp/vcf_delly/{G}.{T}.vcf",
    log:
        # The log carries both wildcards of the output so that the jobs for
        # the different variant types of one group do not share a log file.
        "log/{G}.{T}.delly.log"
    shell:
        # The final `touch` guarantees the output file exists even when the
        # Delly call did not create it.
        """
        export OMP_NUM_THREADS={threads}
        delly_parallel_0.7.1 -g {ref} {input.bam} -t {wildcards.T} -o {output} > {log}
        touch {output}
        """


# Calls variants jointly on all BAM files of group {G} with
# scripts/freebayes-parallel over regions generated from the reference's .fai
# index (chunk argument 10000000), then keeps records passing the vcffilter
# expression "QUAL > 10". stderr of the caller goes to the log.
# The BAM indexes are declared as inputs so they exist before calling starts.
rule SnpCall:
    threads: 60
    input:
        bam=lambda wc: expand("bam/{S}.bam", S=SAMPLES_OF_GROUP[wc.G]),
        bai=lambda wc: expand("bam/{S}.bam.bai", S=SAMPLES_OF_GROUP[wc.G])
    output:
        "tmp/vcf/{G}.vcf",
    log:
        "log/calling/{G}.freebayes.log"
    shell:
        './scripts/freebayes-parallel <(fasta_generate_regions.py {ref}.fai 10000000) {threads} -u -f {ref} {input.bam} --min-alternate-fraction 0.10 2> {log} | vcffilter -f "QUAL > 10" > {output}'
        # TODO: make the calling parameters configurable.


# Generic rule: builds the .bai index next to any "{S}.bam" file with
# `sambamba index`. The pattern has no directory prefix, so {S} includes the
# path of the BAM file.
rule Index:
    threads: 20
    input:
        "{S}.bam"
    output:
        "{S}.bam.bai"
    shell:
        "sambamba index -t {threads} {input}"


# Aligns a sample's paired fastq files with `bwa mem` and writes a sorted BAM.
# All R1 files (and all R2 files) of the sample are concatenated into named
# pipes, so bwa reads exactly one R1 and one R2 stream.
# NOTE(review): the declared log file is not referenced by the command.
rule Map:
    threads:
        32
    input:
        ref = ref,
        index = ref+".fai",
        fastq_r1 = lambda wc: FILES_R1[wc.S],
        fastq_r2 = lambda wc: FILES_R2[wc.S]
    output:
        # The two FIFOs are declared as temp() outputs.
        tmp_r1 = temp("pipes/{S}.R1.fastq.gz"),
        tmp_r2 = temp("pipes/{S}.R2.fastq.gz"),
        bam = "bam/{S}.bam"
    params:
        # Read-group header line passed to bwa; ID and SM are the sample name.
        RG="\"@RG\\tID:{S}\\tSM:{S}\\tPL:Illumina\""
    log:
        "log/{S}.bwa.log"
    shell:
        # The `cat` writers run in the background and feed the FIFOs while bwa
        # reads them; bwa's output is piped through mbuffer, samblaster and
        # sambamba view into sambamba sort.
        """
        mkfifo {output.tmp_r1}
        mkfifo {output.tmp_r2}
        cat {input.fastq_r1} > {output.tmp_r1} &
        cat {input.fastq_r2} > {output.tmp_r2} &
        bwa mem -R {params.RG} -M -t {threads} {input.ref} {output.tmp_r1} {output.tmp_r2} | mbuffer -q \
        | samblaster -M | sambamba view -t {threads} -S -f bam -h /dev/stdin | sambamba sort --tmpdir tmp /dev/stdin -m 20GB -t {threads} -o {output.bam}
        """


# Prepares the reference: runs `bwa index` and `samtools faidx` on it.
# Only the .fai file is declared as output; whatever files `bwa index` writes
# are not tracked by the workflow.
rule BWAIndex:
    input:
        ref
    output:
        index = ref+".fai"
    shell:
        """
        bwa index {input}
        samtools faidx {input}
        """
