
from snakemake.utils import Paramspace, validate, min_version
import pandas as pd
from os import makedirs
import os.path
from aux.checkConfig import readDefaults, makeConfig
from aux.checkPatients import checkPatientsBetas, checkPatientInfo
from aux.checkInput import checkInput

# Refuse to run on Snakemake releases older than 6.10.0.
min_version("6.10.0")

# Build the effective configuration: defaults are read from the schema file
# that sits next to this Snakefile, combined with the user-supplied config by
# makeConfig, and the result is validated against the same schema.
# NOTE(review): the exact merge semantics live in aux/checkConfig — confirm there.
DEFAULTS = readDefaults(os.path.join(workflow.basedir, "config.schema.yml"))
config = makeConfig(config, DEFAULTS)
validate(config, "config.schema.yml")

# Working directory for the run; defaults to the current directory.
workdir: config.get("workdir", ".")

# Both BEAST rules declare the same outputs, so a rule order is required to
# disambiguate them: prefer the containerised rule when "singularity" is set.
if config.get("singularity"):
    ruleorder: beast > beast_nosingularity 
else: 
    ruleorder: beast_nosingularity > beast

# Resolve helper scripts and bundled data files relative to the workflow source.
XMLgenerator = workflow.source_path("aux/methylationBetas2xml.py")
# The return values are discarded; the calls are made only for their effect.
# NOTE(review): presumably these modules are imported by methylationBetas2xml.py
# and must be available next to it — confirm.
_ = workflow.source_path("aux/createXML.py"), workflow.source_path("aux/readInputMethylation.py")
readIDAT = workflow.source_path("aux/readIDAT.R")
methCNV = workflow.source_path("aux/methCNV.R")
selectCpGs = workflow.source_path("aux/select_cpgs.py")
manifest = workflow.source_path("data/MethylationEPIC_v-1-0_B4.csv")
crossreactivefile = workflow.source_path("data/13059_2016_1066_MOESM1_ESM.csv")

# Names of the sample-sheet columns used throughout the workflow. The first
# three fall back to conventional headers; the last two are optional and are
# None when the configuration does not provide them.
patient_col = config.get("patient_col", "Patient")
sample_col = config.get("sample_col", "Sample")
age_col = config.get("age_col", "Age")
# These must be plain values, not 1-tuples: a trailing comma here would make
# the variable always truthy and render as "('name',)" when interpolated into
# a command-line flag.
age_diagnosis_col = config.get("age_diagnosis_col")
sample_type_col = config.get("sample_type_col")

# Load the sample sheet and pass it through the project's checkPatientInfo
# helper before deriving the patient list and the per-patient sample lists.
df = checkPatientInfo(pd.read_csv(config.get("patientInfo")), patient_col)
PATIENTS = df[patient_col].unique()
SAMPLES = df.groupby(patient_col)[sample_col].agg(list)

# "stemCells" is a dash-separated triple whose three integer fields are fed to
# range() as (start, stop, step); as with any range, the stop value is excluded.
stemCellsSweep = config.get("stemCells").split("-")
STEMCELLS = list(range(*(int(stemCellsSweep[i]) for i in range(3))))

# checkInput decides which flavour of the pipeline runs; only "complete" and
# "trees" are handled below, anything else aborts.
mode = checkInput(config)
activateCNV = False

if mode == "complete":
    # CNV calling is enabled only when requested in the config AND the sample
    # sheet's "Group" column contains more than two normal/control matches.
    # The `and` short-circuits, so "Group" is not touched when cnv is off.
    activateCNV = bool(
        config.get("cnv")
        and df['Group'].str.lower().str.count('normal|control').sum() > 2
    )

    ruleorder: split_flucpgs > split_flucpgs_trees
elif mode == "trees":
    ruleorder: split_flucpgs_trees > split_flucpgs
else:
    raise RuntimeError(f"Unexpected inferred mode {mode}")


# select_cpgs and force_cpgs produce the same file; a user-supplied CpG list
# ("fcpgs") gives force_cpgs priority, otherwise select_cpgs wins.
if not config.get("fcpgs"):
    ruleorder: select_cpgs > force_cpgs
else:
    ruleorder: force_cpgs > select_cpgs

# Default target: the three BEAST result files (.ops, .trees, .log) for every
# patient / stem-cell-count combination.
rule all:
    input:
        expand(
            "results/{patient}/{patient}.{stemcell}cells.{ext}",
            patient=PATIENTS,
            stemcell=STEMCELLS,
            ext=["ops", "trees", "log"],
        )

# Runs the readIDAT.R helper on the raw input directory and the sample sheet.
# The rule declares the directory "idat_processed" as its output; downstream
# rules read "<name>.betas.csv", "<name>.M.csv", "<name>.U.csv" and
# "<name>.RData" from it, where <name> is the configured "output" prefix.
rule readIDAT:
    input:
        dir = config.get("input"),
        samplesheet = config.get("patientInfo"),
    output:
        directory("idat_processed")
    # Ask for every core granted to the workflow.
    threads: workflow.cores
    params:
        readIDAT = readIDAT,
        name = config.get("output"),
        sample_col = sample_col,
    shell:
        """
        Rscript {params.readIDAT} -i {input.dir} -o {output} -n {params.name} -p {input.samplesheet} -c {threads} -s {params.sample_col}
        """

# Runs the methCNV.R helper on the .RData file found in readIDAT's output
# directory, then deletes that .RData file and moves the resulting
# "<output>.blacklist.csv" out of the directory into the working directory.
rule callCNVs:
    input:
        rules.readIDAT.output
    output:
        expand("{out}.blacklist.csv", out=config.get("output"))
    params:
        name = config.get("output"),
        methCNV = methCNV,
        # Optional flag, forwarded only when "genome_plot" is set in the config.
        genome_plot = "--genome_plot" if config.get("genome_plot") else "",
    threads: workflow.cores
    shell:
        """
        Rscript {params.methCNV} -i {input}/{params.name}.RData -c {threads} {params.genome_plot} 
        rm {input}/{params.name}.RData
        mv {input}/{output} {output}
        """

# Runs the select_cpgs.py helper on the beta / M / U tables written into
# "idat_processed" and writes "<output>.fluCpGs.csv".
# Every value used by the command is taken from `params`, so the rule does not
# depend on Snakemake's implicit lookup of Snakefile globals in shell strings.
rule select_cpgs:
    input: "idat_processed"
    output: expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output"))
    params:
        # A user-provided manifest overrides the bundled one.
        manifest = config.get("manifest", manifest),
        crossreactivefile = crossreactivefile,
        percent = config.get("percent", 5),
        patientInfo = config.get("patientInfo"),
        name = config.get("output"),
        patient_col = patient_col,
        sample_col = sample_col,
        selectCpGs = selectCpGs
    shell:
        """
        python3 {params.selectCpGs} -e {params.manifest} -p {params.percent} -c {params.crossreactivefile} {output} {input}/{params.name}.betas.csv {input}/{params.name}.M.csv {input}/{params.name}.U.csv {params.patientInfo} -P {params.patient_col} -S {params.sample_col}
        """

# Alternative to select_cpgs: keep exactly the CpGs listed in the user-supplied
# "fcpgs" file. The cg identifiers are extracted from that file, then the
# header line plus every matching row of the betas table are copied to
# "<output>.fluCpGs.csv".
rule force_cpgs:
    input: 
        betas = "idat_processed",
        # A missing or empty "fcpgs" setting falls back to a placeholder path so
        # the rule can still be parsed when it is not going to be used.
        cpgs = config.get("fcpgs") or "filenotfound"
    output: 
        flucpgs = expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output")),
        cpg_ids = temp("cpgs.tmp")
    params:
        # Prefix of the files inside the betas directory; the shell command
        # below needs it to build the betas table path.
        name = config.get("output")
    shell:
        # {input.betas} is used rather than {input}: the latter would expand to
        # both input paths separated by a space and yield a bogus file name.
        """
        grep -o 'cg[0-9]*' {input.cpgs} > {output.cpg_ids}
        head -n1 {input.betas}/{params.name}.betas.csv > {output.flucpgs}
        grep -F -f {output.cpg_ids} {input.betas}/{params.name}.betas.csv >> {output.flucpgs}
        """

def _write_patient_betas(betas_csv, outdir):
    """Split a betas table into one CSV per patient.

    The first column of ``betas_csv`` becomes the row index. For every patient
    in PATIENTS, the columns named in SAMPLES[patient] that exist in the table
    are written to ``<outdir>/<patient>.csv``. ``outdir`` is created if needed.
    """
    betas = pd.read_csv(betas_csv)
    betas.set_index(betas.columns[0], inplace=True)

    makedirs(outdir, exist_ok=True)
    for patient in PATIENTS:
        # filter(items=...) silently ignores sample names absent from the table.
        betas.filter(items=SAMPLES[patient]).to_csv(f"{outdir}/{patient}.csv")


# "complete" mode: split the selected fluctuating-CpG table per patient.
checkpoint split_flucpgs:
    input: expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output")),
    output: temp(directory("miniBetas"))
    run:
        _write_patient_betas(input[0], output[0])

# "trees" mode: the configured "input" is itself the table to split.
checkpoint split_flucpgs_trees:
    input: config.get("input")
    output: temp(directory("miniBetas"))
    run:
        _write_patient_betas(input[0], output[0])


def get_blacklists():
    """Return the flat list of blacklist files to merge.

    Always contains the bundled cross-reactive probe file. The output of
    callCNVs is added when CNV calling is active, and the user's "blacklist"
    file is added when one is configured.
    """
    files = [crossreactivefile]
    if activateCNV:
        # extend, not append: the rule output is itself a list of paths and
        # should contribute its elements rather than one nested list.
        files.extend(rules.callCNVs.output)

    user_blacklist = config.get("blacklist")
    if user_blacklist:
        files.append(user_blacklist)

    return files


# Concatenates every blacklist source and keeps only the cg identifiers found
# in them, one per line, in a temporary "blacklist.csv".
# The shell block is a regular (non-raw) Python string, so the \t below is
# already a literal tab character by the time the shell sees it.
rule merge_blacklists:
    input:
        get_blacklists()
    output:
        temp("blacklist.csv")
    shell:
        """
        cat {input} | tr ',' '\t' | grep -o 'cg[0-9]*' > {output}
        """

def aggregate_input(wildcards):
    """Return the per-patient betas file produced by the active checkpoint.

    Calling ``.get()`` on the checkpoint is what makes Snakemake wait for it
    (and re-evaluate the DAG afterwards). The returned path is built from the
    checkpoint's actual output directory and the job's ``patient`` wildcard.
    """
    if mode == "complete":
        checkpoint = checkpoints.split_flucpgs
    else:
        checkpoint = checkpoints.split_flucpgs_trees

    checkpoint_output = checkpoint.get(**wildcards).output[0]
    return f"{checkpoint_output}/{wildcards.patient}.csv"

# Builds the BEAST XML for one patient / stem-cell count. Blacklisted CpGs are
# first removed from the patient's betas table (written to a temporary file),
# then the XML generator script is run from inside the XML output directory.
rule getXML:
    input: 
        minibetas = aggregate_input,
        blacklist = rules.merge_blacklists.output,
        patientInfo = config.get("patientInfo"),
    output: 
        xml="results/XMLs/{patient}.{stemcell}cells.xml",
        tmp_input = temp("{patient}.{stemcell}cells.csv.tmp")
    params:
        # Values forwarded to the XML generator, with workflow-level fallbacks.
        delta = config.get("delta", .2),
        eta = config.get("eta", .7),
        kappa = config.get("kappa", 50),
        mu = config.get("mu", .1),
        gamma = config.get("gamma", .1),
        lam = config.get("lam", 1),
        iterations = config.get("iterations", 750_000),
        precision = config.get("precision", 6),
        sampling = config.get("sampling", 75),
        screenSampling = config.get("screenSampling", 500),
        # Boolean switches become either the flag itself or an empty string.
        stripRownames = "--stripRownames" if config.get("stripRownames") else "",
        mle_ps = "--mle-ps" if config.get("mle_ps") else "",
        mle_ss = "--mle-ss" if config.get("mle_ss") else "",
        hme = "--hme" if config.get("hme") else "",
        mle_steps = config.get("mle_steps", 100),
        # Unless set explicitly, derived from the iteration count:
        # iterations / mle_steps and iterations / 1000 respectively.
        mle_iterations = config.get("mle_iterations", int(config.get("iterations", 750_000)/config.get("mle_steps", 100))),
        mle_sampling = config.get("mle_sampling", int(config.get("iterations", 750_000)/1000)),
        luca_mode = config.get("luca_mode", "auto"),
        age_col = age_col,
        # NOTE(review): this flag string is built but the shell command below
        # never references {{params.age_diagnosis_col}} — confirm whether the
        # generator script is supposed to receive it.
        age_diagnosis_col = f"--age-diagnosis-col {age_diagnosis_col}" if age_diagnosis_col else "",
        sample_col = sample_col,
        sample_type_col = sample_type_col,        
        XMLgenerator = XMLgenerator
    shell:
        # All input paths are made absolute BEFORE the `cd`, otherwise relative
        # paths (notably the sample sheet) would no longer resolve.
        """
        input=$(realpath {input.minibetas})
        blacklist=$(realpath {input.blacklist})
        samplesheet=$(realpath {input.patientInfo})
        tmp="$(pwd)/{output.tmp_input}"
        grep -F -v -f "$blacklist" "$input" > "$tmp"

        cd "$(dirname {output.xml})"
        name=$(basename {output.xml})
        name=${{name%.xml}}

        python3 {params.XMLgenerator} --samplesheet "$samplesheet" --input "$tmp" --output "$name" --luca-mode {params.luca_mode} --age-col {params.age_col} --sample-col {params.sample_col} --sample-type-col {params.sample_type_col} --stemCells {wildcards.stemcell} --delta {params.delta} --eta {params.eta} --kappa {params.kappa} --mu {params.mu} --gamma {params.gamma} --lambda {params.lam} --iterations {params.iterations} --precision {params.precision} --sampling {params.sampling} --screenSampling {params.screenSampling} {params.stripRownames} {params.mle_ps} {params.mle_ss} {params.hme} --mle-steps {params.mle_steps} --mle-iterations {params.mle_iterations} --mle-sampling {params.mle_sampling}
        """

        

# Runs BEAST on one XML through the pisca Singularity/Docker image, then moves
# the result files it leaves in the working directory into results/<patient>/.
rule beast:
    input:
        "results/XMLs/{patient}.{stemcell}cells.xml"
    output:
        multiext("results/{patient}/{patient}.{stemcell}cells", ".trees", ".log", ".ops")
    log: 
        "beast_logs/{patient}/{stemcell}cells/{patient}.run.log"
    params: 
        # A seed is passed only when one is configured and non-negative; a
        # missing (None) seed must not be compared with 0.
        seed = f"-seed {config.get('seed')}" if (config.get("seed") is not None and config.get("seed") >= 0) else ""
    shell:
        # stderr is redirected INTO stdout (2>&1). Opening the log file twice
        # ("> log 2> log") makes the two streams overwrite each other.
        """
        out="results/{wildcards.patient}" && mkdir -p "$out"
        singularity run -B "$(pwd)":/mnt docker://rachelicr/pisca-branch-master {params.seed} {input} > {log} 2>&1
        for outfile in {output}
        do
            mv "$(basename "$outfile")" "$out"
        done
        """

# Same as the containerised BEAST rule but calls a locally installed `beast`
# binary (with BEAGLE disabled), then moves the result files it leaves in the
# working directory into results/<patient>/.
rule beast_nosingularity:
    input:
        "results/XMLs/{patient}.{stemcell}cells.xml"
    output:
        multiext("results/{patient}/{patient}.{stemcell}cells", ".trees", ".log", ".ops")
    log: 
        "beast_logs/{patient}/{stemcell}cells/{patient}.run.log"
    params: 
        # A seed is passed only when one is configured and non-negative; a
        # missing (None) seed must not be compared with 0.
        seed = f"-seed {config.get('seed')}" if (config.get("seed") is not None and config.get("seed") >= 0) else ""
    shell:
        # stderr is redirected INTO stdout (2>&1). Opening the log file twice
        # ("> log 2> log") makes the two streams overwrite each other.
        """
        out="results/{wildcards.patient}" && mkdir -p "$out"
        beast -beagle_off {params.seed} {input} > {log} 2>&1

        for outfile in {output}
        do
            mv "$(basename "$outfile")" "$out"
        done
        """
