from glob import glob
import collections
import os
from singlecellmultiomics.utils import get_contig_list_from_fasta, is_main_chromosome
from singlecellmultiomics.bamProcessing import get_read_group_format
"""
This workflow:

"""
################## configuration ##################
configfile: "config.json"
# Contigs to scatter the tagging step over: every main chromosome becomes its
# own job. If the reference also holds non-main contigs, a single extra
# pseudo-contig name 'MISC_ALT_CONTIGS_SCMO' is appended; it is passed to
# bamtagmultiome.py via -contig (presumably meaning "all remaining contigs" —
# confirm against bamtagmultiome.py).
_reference_contigs = list(get_contig_list_from_fasta(config['reference_file']))
contigs = [name for name in _reference_contigs if is_main_chromosome(name)]
# Any contig filtered out above is a non-main contig.
has_misc = len(contigs) != len(_reference_contigs)
if has_misc:
    contigs.append('MISC_ALT_CONTIGS_SCMO')


# Every sub-directory of the working directory that contains a sorted.bam is a
# library; the directory name is used as the library name. os.path is used
# instead of splitting on '/' so the separator is platform-correct, and the
# result is sorted because glob() order depends on the filesystem.
libraries = sorted(
    os.path.basename(os.path.dirname(bam_path))
    for bam_path in glob(os.path.join('.', '*', 'sorted.bam'))
)
# Final targets: the recalibrated BAM with single-cell read groups, plus its index, for every library.
rule all:
    # Request the end product of the pipeline for each discovered library.
    input:
        b=expand(
            "{library}/recallibration/recallibrated.scrg.bam",
            library=libraries,
        ),
        i=expand(
            "{library}/recallibration/recallibrated.scrg.bam.bai",
            library=libraries,
        ),



rule SCMO_tagmultiome_NLA_parallel_scatter:
    # Tag a single contig of one library's sorted BAM with bamtagmultiome.py
    # (method nla). One job per (library, contig); the per-contig BAMs are
    # temporary and are consumed by SCMO_tagmultiome_NLA_parallel_gather.
    input:
        bam="{library}/sorted.bam",
        bam_index="{library}/sorted.bam.bai",
        mapfile=config['mapfile'],
    output:
        # The index is not built by a separate command here, so it is presumably
        # written by bamtagmultiome.py itself — confirm.
        bam=temp("{library}/processed/TEMP_CONTIG/{contig}.bam"),
        bam_index=temp("{library}/processed/TEMP_CONTIG/{contig}.bam.bai"),
        stats="{library}/processed/stats_{contig}.tsv",
    log:
        stdout="{library}/log/tag_scatter/{contig}.stdout",
        stderr="{library}/log/tag_scatter/{contig}.stderr",
    threads: 1
    params:
        runtime="60h",
        alleles=config['alleles'],
        ref=config['reference_file'],
    resources:
        # Memory request scales with the retry attempt number.
        mem_mb=lambda wildcards, attempt: 10000 * attempt
    shell:
        # -read_group_format 1 here; read groups are rewritten to format 0 at
        # the end of the pipeline (rule revert_read_groups_to_SC).
        "bamtagmultiome.py -ref {params.ref} "
        "-molecule_iterator_verbosity_interval 10 "
        "-stats_file_path {output.stats} "
        "-read_group_format 1 "
        "-method nla "
        "-alleles {params.alleles} "
        "-mapfile {input.mapfile} "
        "-contig {wildcards.contig} "
        "{input.bam} -o {output.bam} "
        "> {log.stdout} 2> {log.stderr}"


rule SCMO_tagmultiome_NLA_parallel_gather:
    # Merge all per-contig tagged BAMs of one library into processed/tagged.bam
    # and index the result.
    input:
        # Double braces keep {library} a wildcard while expand() fills {contig}.
        chr_bams=expand(
            "{{library}}/processed/TEMP_CONTIG/{contig}.bam", contig=contigs
        ),
        chr_bams_indices=expand(
            "{{library}}/processed/TEMP_CONTIG/{contig}.bam.bai", contig=contigs
        ),
    output:
        bam="{library}/processed/tagged.bam",
        bam_index="{library}/processed/tagged.bam.bai",
    log:
        stdout="{library}/log/tag_gather/library.stdout",
        stderr="{library}/log/tag_gather/library.stderr",
    threads: 1
    params:
        runtime="8h",
    message:
        "Merging contig BAM files"
    shell:
        "samtools merge -c {output.bam} {input.chr_bams} "
        "> {log.stdout} 2> {log.stderr}; "
        "samtools index {output.bam}"


rule gatk_BaseRecalibrator:
    # Build the GATK base quality recalibration table for a library's merged,
    # tagged BAM, using the configured known-variant sites and reference.
    input:
        scbam='{library}/processed/tagged.bam',
        # Declared explicitly (as gatk_BaseRecalibratorApply does for its BAM)
        # so Snakemake tracks the index as a dependency and regenerates it via
        # SCMO_tagmultiome_NLA_parallel_gather if it is missing.
        scbam_index='{library}/processed/tagged.bam.bai',

    output:
        recall_table = "{library}/recallibration/recall_table.tsv"

    log:
        stdout="{library}/log/BaseRecalibrator/BaseRecalibrator.stdout",
        stderr="{library}/log/BaseRecalibrator/BaseRecalibrator.stderr"

    threads: 1
    params:
        runtime="60h",
        reference=config['reference_file'],
        gatk_path=config['gatk_path'],
        known_variants_vcf=config['known_variants_vcf'],
        # Passed to BaseRecalibrator as -mcs.
        covariate_radius=config['covariate_radius']


    # NOTE(review): params.runtime says "60h" while resources.runtime is
    # attempt * 24; recent Snakemake versions interpret the `runtime` resource
    # in minutes — confirm which unit the cluster profile expects.
    resources:
        mem_mb = lambda wildcards, attempt, input: attempt * 10000,
        runtime = lambda wildcards, attempt, input: attempt * 24

    shell:
        "{params.gatk_path} BaseRecalibrator \
        --known-sites {params.known_variants_vcf} \
        --reference {params.reference} \
         --input {input.scbam} \
         --output {output.recall_table} \
         -mcs {params.covariate_radius} \
         > {log.stdout} 2> {log.stderr} \
         ;"

rule SCMO_bamMatchGATKBQSRReport:
    # Run bamMatchGATKBQSRReport.py on the tagged BAM together with the
    # recalibration table, producing a temporary BAM (plus index) that
    # gatk_BaseRecalibratorApply consumes.
    input:
        scbam='{library}/processed/tagged.bam',
        # Declared for consistency with gatk_BaseRecalibratorApply, so the BAM
        # index is a tracked dependency (produced by the gather rule).
        scbam_index='{library}/processed/tagged.bam.bai',
        recall_table='{library}/recallibration/recall_table.tsv'
    output:
        recall_matched_bam = temp("{library}/recallibration/recall_matched.bam"),
        recall_matched_bam_index = temp("{library}/recallibration/recall_matched.bam.bai"),

    log:
        stdout="{library}/log/matchGATKBQSRReport/matchGATKBQSRReport.stdout",
        stderr="{library}/log/matchGATKBQSRReport/matchGATKBQSRReport.stderr"

    threads: 1
    params:
        runtime="60h",
        # NOTE(review): not referenced by the shell command below.
        covariate_radius=config['covariate_radius']

    resources:
        mem_mb = lambda wildcards, attempt, input: attempt * 10000,
        runtime = lambda wildcards, attempt, input: attempt * 24

    shell:
        "bamMatchGATKBQSRReport.py {input.scbam} \
        {input.recall_table} -o {output.recall_matched_bam} \
         > {log.stdout} 2> {log.stderr}; samtools index {output.recall_matched_bam}"


rule gatk_BaseRecalibratorApply:
    # Apply the recalibration table to the matched BAM with GATK ApplyBQSR and
    # index the (temporary) recalibrated BAM with samtools.
    input:
        scbam='{library}/recallibration/recall_matched.bam',
        scbam_index='{library}/recallibration/recall_matched.bam.bai',
        recall_table='{library}/recallibration/recall_table.tsv',
    output:
        recall_bam=temp("{library}/recallibration/recallibrated.bam"),
        recall_bam_index=temp("{library}/recallibration/recallibrated.bam.bai"),
    log:
        stdout="{library}/log/ApplyBQSR/ApplyBQSR.stdout",
        stderr="{library}/log/ApplyBQSR/ApplyBQSR.stderr",
    threads: 1
    params:
        runtime="60h",
        gatk_path=config['gatk_path'],
        # NOTE(review): not referenced by the shell command below.
        covariate_radius=config['covariate_radius'],
    resources:
        # Both requests scale with the retry attempt number.
        mem_mb=lambda wildcards, attempt, input: 10000 * attempt,
        runtime=lambda wildcards, attempt, input: 24 * attempt,
    # No --reference is passed to ApplyBQSR here.
    shell:
        "{params.gatk_path} ApplyBQSR \
        --bqsr-recal-file {input.recall_table} \
         --input {input.scbam} \
         --output {output.recall_bam} \
         > {log.stdout} 2> {log.stderr}; samtools index {output.recall_bam} \
         ;"


rule revert_read_groups_to_SC:
    # Rewrite the read groups of the recalibrated BAM with bamReadGroupFormat.py
    # (format 0) and index the result; this is the pipeline's final output.
    input:
        recall_bam='{library}/recallibration/recallibrated.bam',
        recall_bam_index='{library}/recallibration/recallibrated.bam.bai',
    output:
        # BAM carrying single-cell read groups, plus its index.
        scrg_bam="{library}/recallibration/recallibrated.scrg.bam",
        scrg_bam_i="{library}/recallibration/recallibrated.scrg.bam.bai",
    log:
        stdout="{library}/log/SCMO_RG/SCMO_RG.stdout",
        stderr="{library}/log/SCMO_RG/SCMO_RG.stderr",
    threads: 1
    params:
        runtime="60h",
        # Read-group format passed as -format; presumably 0 selects per-cell
        # read groups (per the output comment) — confirm in bamReadGroupFormat.py.
        format=0,
    resources:
        mem_mb=lambda wildcards, attempt, input: 10000 * attempt,
        runtime=lambda wildcards, attempt, input: 24 * attempt,
    shell:
        "bamReadGroupFormat.py {input.recall_bam} -o {output.scrg_bam} "
        "-format {params.format} "
        "> {log.stdout} 2> {log.stderr}; "
        "samtools index {output.scrg_bam} "
