from pathlib import Path

# ---------------------------------------------------------------------------
# Workflow settings, read once from the Snakemake ``config`` mapping.
# ---------------------------------------------------------------------------

_TRUE_WORDS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_WORDS = frozenset({"", "0", "false", "f", "no", "n", "off", "none", "null"})


def _config_flag(key, default):
    """Return ``config[key]`` as a real boolean.

    Native booleans/numbers behave as with ``bool()``.  Strings are parsed by
    meaning instead of by emptiness, so a quoted ``"false"`` or ``"no"`` does
    not silently evaluate to ``True``.

    Raises:
        ValueError: if a string value is not a recognised true/false word.
    """
    value = config.get(key, default)
    if not isinstance(value, str):
        return bool(value)
    word = value.strip().lower()
    if word in _TRUE_WORDS:
        return True
    if word in _FALSE_WORDS:
        return False
    raise ValueError(f"config[{key!r}] must be a boolean, got {value!r}")


def _config_text(key, default=""):
    """Return ``config[key]`` as a stripped string, or *default* if unset/blank.

    A key that is present but null (``key:`` in YAML) is treated as unset
    rather than being stringified to the literal text ``"None"``.
    """
    value = config.get(key)
    text = "" if value is None else str(value).strip()
    return text or default


# Input / output locations (all three keys are required).
OUTDIR = Path(config["outdir"])
QUERY = config["query_fasta"]
PROTEOMES = config["proteomes_dir"]
USER_CDS = Path(config["user_cds"])

# Tunable parameters with their defaults.
THREADS = int(config.get("threads", 4))
COVERAGE = float(config.get("coverage", 0.70))
OUTGROUP_QUERY = _config_text("outgroup_query")
ABSREL_P = float(config.get("absrel_p", 0.1))
ABSREL_DYNAMIC_START = float(config.get("absrel_dynamic_start", 0.05))
ABSREL_DYNAMIC_STEP = float(config.get("absrel_dynamic_step", 0.01))
ABSREL_DYNAMIC_MAX = float(config.get("absrel_dynamic_max", 1.0))
MEME_P = float(config.get("meme_p", 0.1))
USE_CLIPKIT = _config_flag("use_clipkit", True)
CLIPKIT_MODE_PROTEIN = _config_text("clipkit_mode_protein", "kpic-smart-gap")
CLIPKIT_MODE_CODON = _config_text("clipkit_mode_codon", "kpic-smart-gap")
IQTREE_MODEL = _config_text("iqtree_model", "MFP")
IQTREE_BOOTSTRAP = int(config.get("iqtree_bootstrap", 1000))
IQTREE_BNNI = _config_flag("iqtree_bnni", False)

# External executables.  Every entry is required; report all missing names at
# once instead of failing on the first lookup with a bare KeyError.
EXE = config.get("executables") or {}
_REQUIRED_EXECUTABLES = (
    "blastp",
    "makeblastdb",
    "iqtree",
    "hyphy",
    "codeml",
    "clipkit",
    "babappalign",
)
_missing_executables = [name for name in _REQUIRED_EXECUTABLES if name not in EXE]
if _missing_executables:
    raise KeyError(
        "config['executables'] is missing required entries: "
        + ", ".join(_missing_executables)
    )
BLASTP = EXE["blastp"]
MAKEBLASTDB = EXE["makeblastdb"]
IQTREE = EXE["iqtree"]
HYPHY = EXE["hyphy"]
CODEML = EXE["codeml"]
CLIPKIT = EXE["clipkit"]
BABAPPALIGN = EXE["babappalign"]

# The CDS-dependent part of the workflow is only enabled when the user CDS
# path is a non-empty regular file.  ``is_file()`` (rather than ``exists()``)
# keeps an empty setting -- which ``Path`` turns into "." -- or any other
# directory from being mistaken for a CDS FASTA.
HAVE_CDS = USER_CDS.is_file() and USER_CDS.stat().st_size > 0

# Default target.  With a usable CDS file the workflow is driven to the final
# summary and the ASR completion marker; otherwise it stops after writing the
# note that asks for the CDS file.
rule all:
    input:
        (
            [
                OUTDIR / "summary" / "episodic_selection_summary.txt",
                OUTDIR / "asr" / "asr_done.json",
            ]
            if HAVE_CDS
            else [OUTDIR / "orthogroup" / "WAITING_FOR_CDS.txt"]
        )


# Run babappasnake.scripts.run_rbh_pipeline on the query FASTA and the
# proteomes directory, writing its results into OUTDIR/orthogroup.
# The rule declares three products of that script: the orthogroup protein
# FASTA, a headers list and an RBH summary table.
rule rbh_orthogroup:
    input:
        query=QUERY,
        proteomes=PROTEOMES,
    output:
        proteins=OUTDIR / "orthogroup" / "orthogroup_proteins.fasta",
        headers=OUTDIR / "orthogroup" / "orthogroup_headers.txt",
        rbh=OUTDIR / "orthogroup" / "rbh_summary.tsv",
    threads: THREADS
    shell:
        # Path placeholders use Snakemake's ":q" flag so that directories or
        # file names containing spaces/shell metacharacters stay one argument.
        r'''
        mkdir -p {OUTDIR:q}/orthogroup
        python -m babappasnake.scripts.run_rbh_pipeline \
          --query {input.query:q} \
          --proteomes {input.proteomes:q} \
          --outdir {OUTDIR:q}/orthogroup \
          --coverage {COVERAGE} \
          --threads {threads} \
          --blastp "{BLASTP}" \
          --makeblastdb "{MAKEBLASTDB}"
        '''


# Write OUTDIR/orthogroup/WAITING_FOR_CDS.txt via
# babappasnake.scripts.write_waiting_note.  The script receives the orthogroup
# FASTA, its headers file and the path where the user CDS file is expected.
# This note is the `all` target when no usable CDS file is present.
rule waiting_note:
    input:
        proteins=rules.rbh_orthogroup.output.proteins,
        headers=rules.rbh_orthogroup.output.headers,
    output:
        note=OUTDIR / "orthogroup" / "WAITING_FOR_CDS.txt",
    shell:
        # ":q" shell-quotes each path so whitespace cannot split an argument.
        r'''
        python -m babappasnake.scripts.write_waiting_note \
          --orthogroup {input.proteins:q} \
          --headers {input.headers:q} \
          --expected-cds {USER_CDS:q} \
          --outfile {output.note:q}
        '''


if HAVE_CDS:

    # Pair the user-supplied CDS file with the orthogroup proteins using
    # babappasnake.scripts.map_cds_to_proteins.  Declared products: a mapped
    # CDS FASTA, the matching protein FASTA and a CDS/protein mapping table,
    # all under OUTDIR/mapped_cds.
    rule map_cds:
        input:
            proteins=rules.rbh_orthogroup.output.proteins,
            cds=USER_CDS,
        output:
            mapped=OUTDIR / "mapped_cds" / "mapped_orthogroup_cds.fasta",
            proteins=OUTDIR / "mapped_cds" / "mapped_orthogroup_proteins.fasta",
            table=OUTDIR / "mapped_cds" / "cds_protein_mapping.tsv",
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            mkdir -p {OUTDIR:q}/mapped_cds
            python -m babappasnake.scripts.map_cds_to_proteins \
              --proteins {input.proteins:q} \
              --cds {input.cds:q} \
              --out-cds {output.mapped:q} \
              --out-proteins {output.proteins:q} \
              --mapping {output.table:q}
            '''


    # Align the mapped orthogroup proteins with the babappalign executable in
    # protein mode.  The input is copied into OUTDIR/alignments as
    # orthogroup_proteins.fasta and the aligner is run from inside that
    # directory; the rule then checks that the expected alignment file exists.
    # NOTE(review): `threads` is reserved for the job but not forwarded to the
    # aligner -- confirm whether babappalign accepts a thread option.
    rule align_proteins:
        input:
            rules.map_cds.output.proteins
        output:
            aln=OUTDIR / "alignments" / "orthogroup_proteins.protein.aln.fasta"
        threads: THREADS
        shell:
            # The aligner runs in a subshell so the `cd` does not leak into the
            # rest of the script: otherwise a relative OUTDIR would make the
            # final `test -f` look for the output below OUTDIR/alignments.
            r'''
            mkdir -p {OUTDIR:q}/alignments
            cp {input:q} {OUTDIR:q}/alignments/orthogroup_proteins.fasta
            (
              cd {OUTDIR:q}/alignments && \
              "{BABAPPALIGN}" orthogroup_proteins.fasta --model babappascore --mode protein
            )
            test -f {output.aln:q}
            '''


    # Align the mapped CDS sequences with the babappalign executable in codon
    # mode.  The input is copied into OUTDIR/alignments as
    # mapped_orthogroup_cds.fasta and the aligner is run from inside that
    # directory; the rule then checks that both the protein-level and the
    # codon-level alignment files exist.
    # NOTE(review): `threads` is reserved for the job but not forwarded to the
    # aligner -- confirm whether babappalign accepts a thread option.
    rule align_cds:
        input:
            rules.map_cds.output.mapped
        output:
            protein_aln=OUTDIR / "alignments" / "mapped_orthogroup_cds.protein.aln.fasta",
            codon_aln=OUTDIR / "alignments" / "mapped_orthogroup_cds.codon.aln.fasta"
        threads: THREADS
        shell:
            # The aligner runs in a subshell so the `cd` does not leak into the
            # rest of the script: otherwise a relative OUTDIR would make the
            # final `test -f` checks resolve below OUTDIR/alignments.
            r'''
            mkdir -p {OUTDIR:q}/alignments
            cp {input:q} {OUTDIR:q}/alignments/mapped_orthogroup_cds.fasta
            (
              cd {OUTDIR:q}/alignments && \
              "{BABAPPALIGN}" mapped_orthogroup_cds.fasta --model babappascore --mode codon
            )
            test -f {output.protein_aln:q}
            test -f {output.codon_aln:q}
            '''


    if USE_CLIPKIT:
        # Trim the protein alignment with ClipKIT using the configured protein
        # trimming mode (`-s aa` is passed for the sequence type); the result
        # is written to OUTDIR/trimmed.
        rule trim_protein_alignment:
            input:
                rules.align_proteins.output.aln
            output:
                trimmed=OUTDIR / "trimmed" / "orthogroup_proteins.clipkit.fasta"
            shell:
                # Paths are shell-quoted with ":q", like the mode argument.
                r'''
                mkdir -p {OUTDIR:q}/trimmed
                "{CLIPKIT}" {input:q} -m {CLIPKIT_MODE_PROTEIN:q} -s aa -o {output.trimmed:q}
                '''


        # Trim the codon alignment with ClipKIT using the configured codon
        # trimming mode and the `--codon` flag; the result is written to
        # OUTDIR/trimmed.
        rule trim_codon_alignment:
            input:
                rules.align_cds.output.codon_aln
            output:
                trimmed=OUTDIR / "trimmed" / "mapped_orthogroup_cds.clipkit.fasta"
            shell:
                # Paths are shell-quoted with ":q", like the mode argument.
                r'''
                mkdir -p {OUTDIR:q}/trimmed
                "{CLIPKIT}" {input:q} -m {CLIPKIT_MODE_CODON:q} --codon -o {output.trimmed:q}
                '''


        # Post-process the trimmed codon alignment with
        # babappasnake.scripts.strip_terminal_stop_codon and store the result
        # as the ".nostop" FASTA used by the downstream selection analyses.
        rule strip_terminal_stop_codon:
            input:
                rules.trim_codon_alignment.output.trimmed
            output:
                cleaned=OUTDIR / "trimmed" / "mapped_orthogroup_cds.clipkit.nostop.fasta"
            shell:
                # ":q" shell-quotes each path so whitespace cannot split an argument.
                r'''
                python -m babappasnake.scripts.strip_terminal_stop_codon \
                  --input {input:q} \
                  --output {output.cleaned:q}
                '''
        # ClipKIT enabled: the tree is built from the trimmed protein alignment
        # and the codon-based analyses use the trimmed alignment after the
        # stop-codon stripping step.
        TREE_PROT_ALN = rules.trim_protein_alignment.output.trimmed
        TEST_CDS_ALN = rules.strip_terminal_stop_codon.output.cleaned
    else:
        # ClipKIT disabled: downstream rules consume the aligner output as-is.
        # NOTE(review): the stop-codon stripping rule only exists in the
        # ClipKIT branch, so this codon alignment is passed on unstripped --
        # confirm that codeml/HyPhy inputs are acceptable in this mode.
        TREE_PROT_ALN = rules.align_proteins.output.aln
        TEST_CDS_ALN = rules.align_cds.output.codon_aln


    # Infer the orthogroup tree with IQ-TREE from the selected protein
    # alignment.  The configured model, bootstrap replicate count and optional
    # `-bnni` flag are passed through; output files use the prefix
    # OUTDIR/tree/orthogroup and `-redo` overwrites a previous run.
    rule iqtree_ml:
        input:
            TREE_PROT_ALN
        output:
            tree=OUTDIR / "tree" / "orthogroup.treefile"
        threads: THREADS
        params:
            bnni="-bnni" if IQTREE_BNNI else ""
        shell:
            # Paths are shell-quoted with ":q".  {params.bnni} is deliberately
            # left unquoted: quoting an empty value would hand IQ-TREE a
            # literal empty argument.
            r'''
            mkdir -p {OUTDIR:q}/tree
            "{IQTREE}" -s {input:q} -nt {threads} -m {IQTREE_MODEL:q} -B {IQTREE_BOOTSTRAP} {params.bnni} -redo -pre {OUTDIR:q}/tree/orthogroup
            test -f {output.tree:q}
            '''

    # Produce the rooted tree by running
    # babappasnake.scripts.root_tree_outgroup on the IQ-TREE result with the
    # configured outgroup query (which may be an empty string).
    rule root_iqtree_outgroup:
        input:
            tree=rules.iqtree_ml.output.tree
        output:
            rooted=OUTDIR / "tree" / "orthogroup.rooted.treefile"
        params:
            outgroup=OUTGROUP_QUERY
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            python -m babappasnake.scripts.root_tree_outgroup \
              --tree {input.tree:q} \
              --output {output.rooted:q} \
              --outgroup {params.outgroup:q}
            '''
    # Tree consumed by every downstream selection/ASR rule.
    ROOTED_TREE = rules.root_iqtree_outgroup.output.rooted


    # Run babappasnake.scripts.run_codeml_asr on the selected codon alignment
    # and the rooted tree, writing into OUTDIR/asr.  The rule then verifies
    # that the completion marker, the mlc report and the rst file exist.
    rule codeml_asr:
        input:
            aln=TEST_CDS_ALN,
            tree=ROOTED_TREE,
        output:
            done=OUTDIR / "asr" / "asr_done.json",
            mlc=OUTDIR / "asr" / "mlc_asr.txt",
            rst=OUTDIR / "asr" / "rst",
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            mkdir -p {OUTDIR:q}/asr
            python -m babappasnake.scripts.run_codeml_asr \
              --alignment {input.aln:q} \
              --tree {input.tree:q} \
              --outdir {OUTDIR:q}/asr \
              --codeml "{CODEML}"
            test -f {output.done:q}
            test -f {output.mlc:q}
            test -f {output.rst:q}
            '''


    # Run babappasnake.scripts.run_hyphy on the selected codon alignment and
    # the rooted tree, writing into OUTDIR/hyphy.  The rule then verifies that
    # the completion marker and the absrel/meme JSON files exist.
    rule hyphy_exploratory:
        input:
            aln=TEST_CDS_ALN,
            tree=ROOTED_TREE,
        output:
            done=OUTDIR / "hyphy" / "hyphy_done.json",
            absrel=OUTDIR / "hyphy" / "absrel.json",
            meme=OUTDIR / "hyphy" / "meme.json",
        threads: THREADS
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            python -m babappasnake.scripts.run_hyphy \
              --cds-aln {input.aln:q} \
              --tree {input.tree:q} \
              --outdir {OUTDIR:q}/hyphy \
              --threads {threads} \
              --hyphy "{HYPHY}"
            test -f {output.done:q}
            test -f {output.absrel:q}
            test -f {output.meme:q}
            '''


    # Extract foreground branches from the absrel JSON with
    # babappasnake.scripts.parse_hyphy_foregrounds.  The script is invoked with
    # `--dynamic` and the configured start/step/max values, and writes a TSV,
    # a plain list and a threshold metadata JSON under OUTDIR/hyphy.
    # NOTE(review): the ABSREL_P setting is not passed here -- confirm that
    # the dynamic threshold is meant to replace it.
    rule parse_foregrounds:
        input:
            absrel=rules.hyphy_exploratory.output.absrel
        output:
            tsv=OUTDIR / "hyphy" / "significant_foregrounds.tsv",
            lst=OUTDIR / "hyphy" / "significant_foregrounds.txt",
            meta=OUTDIR / "hyphy" / "foreground_threshold.json",
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            python -m babappasnake.scripts.parse_hyphy_foregrounds \
              --absrel-json {input.absrel:q} \
              --dynamic \
              --dynamic-start {ABSREL_DYNAMIC_START} \
              --dynamic-step {ABSREL_DYNAMIC_STEP} \
              --dynamic-max {ABSREL_DYNAMIC_MAX} \
              --out-tsv {output.tsv:q} \
              --out-list {output.lst:q} \
              --out-meta {output.meta:q}
            '''


    # Run babappasnake.scripts.prepare_foreground_trees with the rooted tree
    # and the foreground list.  The script writes into OUTDIR/branchsite; only
    # its manifest (foreground_trees.tsv) is tracked as a rule output.
    rule prepare_foreground_trees:
        input:
            tree=ROOTED_TREE,
            lst=rules.parse_foregrounds.output.lst,
        output:
            manifest=OUTDIR / "branchsite" / "foreground_trees.tsv"
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            mkdir -p {OUTDIR:q}/branchsite
            python -m babappasnake.scripts.prepare_foreground_trees \
              --tree {input.tree:q} \
              --foreground-list {input.lst:q} \
              --outdir {OUTDIR:q}/branchsite \
              --manifest {output.manifest:q}
            '''


    # Run babappasnake.scripts.run_branchsite_batch with the codon alignment,
    # the OUTDIR/branchsite tree directory and the foreground list, collecting
    # the results in one TSV.  The manifest is declared as an input so this
    # rule waits for the foreground trees; the rule's thread count is handed
    # to the script as `--jobs`.
    rule branchsite_batch:
        input:
            aln=TEST_CDS_ALN,
            manifest=rules.prepare_foreground_trees.output.manifest,
            lst=rules.parse_foregrounds.output.lst,
        output:
            tsv=OUTDIR / "branchsite" / "branchsite_results.tsv"
        threads: THREADS
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            python -m babappasnake.scripts.run_branchsite_batch \
              --alignment {input.aln:q} \
              --tree-dir {OUTDIR:q}/branchsite \
              --foreground-list {input.lst:q} \
              --out-tsv {output.tsv:q} \
              --codeml "{CODEML}" \
              --jobs {threads}
            '''


    # Assemble the final text report with
    # babappasnake.scripts.summarize_results from the RBH summary, the CDS
    # mapping table, the foreground TSV and its threshold metadata, and the
    # branch-site results.  The script is also given the OUTDIR/hyphy directory
    # and the MEME p-value setting; `hyphy_done` is listed as an input so the
    # HyPhy run is finished before that directory is read.
    rule final_summary:
        input:
            rbh=rules.rbh_orthogroup.output.rbh,
            mapping=rules.map_cds.output.table,
            absrel=rules.parse_foregrounds.output.tsv,
            absrel_meta=rules.parse_foregrounds.output.meta,
            branchsite=rules.branchsite_batch.output.tsv,
            hyphy_done=rules.hyphy_exploratory.output.done,
        output:
            OUTDIR / "summary" / "episodic_selection_summary.txt"
        shell:
            # ":q" shell-quotes each path so whitespace cannot split an argument.
            r'''
            mkdir -p {OUTDIR:q}/summary
            python -m babappasnake.scripts.summarize_results \
              --rbh {input.rbh:q} \
              --mapping {input.mapping:q} \
              --absrel {input.absrel:q} \
              --absrel-meta {input.absrel_meta:q} \
              --branchsite {input.branchsite:q} \
              --hyphy-dir {OUTDIR:q}/hyphy \
              --meme-p {MEME_P} \
              --out {output:q}
            '''
