from pathlib import Path

# Tokens that must read as "off" when a flag arrives as text rather than as a
# real boolean (plain bool("false") would be True).
_FALSE_TOKENS = frozenset({"", "0", "false", "no", "off", "n", "f"})


def _config_text(key, default=""):
    """Return config[key] as a stripped string.

    A missing key, an explicit null, or a blank value all yield *default*,
    so a null entry is never turned into the literal string "None".
    """
    value = config.get(key)
    if value is None:
        return default
    return str(value).strip() or default


def _config_flag(key, default):
    """Return config[key] as a bool, understanding textual false values.

    Non-string values keep plain ``bool()`` semantics. Strings are compared
    case-insensitively against ``_FALSE_TOKENS``; any other string is true,
    as it was before.
    """
    value = config.get(key, default)
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_TOKENS
    return bool(value)


# --- Required locations -----------------------------------------------------
OUTDIR = Path(config["outdir"])
QUERY = config["query_fasta"]
PROTEOMES = config["proteomes_dir"]
USER_CDS = Path(config["user_cds"])

# --- Resources and orthogroup search -----------------------------------------
THREADS = int(config.get("threads", 4))
COVERAGE = float(config.get("coverage", 0.70))
OUTGROUP_QUERY = _config_text("outgroup_query", "")

# --- Selection-test thresholds -----------------------------------------------
# NOTE(review): ABSREL_P is not referenced by any rule shown in this file;
# foreground parsing is driven by the ABSREL_DYNAMIC_* values instead.
ABSREL_P = float(config.get("absrel_p", 0.1))
ABSREL_DYNAMIC_START = float(config.get("absrel_dynamic_start", 0.05))
ABSREL_DYNAMIC_STEP = float(config.get("absrel_dynamic_step", 0.01))
ABSREL_DYNAMIC_MAX = float(config.get("absrel_dynamic_max", 1.0))
MEME_P = float(config.get("meme_p", 0.1))

# --- Trimming ------------------------------------------------------------------
USE_CLIPKIT = _config_flag("use_clipkit", True)
CLIPKIT_MODE_PROTEIN = _config_text("clipkit_mode_protein", "kpic-smart-gap")
CLIPKIT_MODE_CODON = _config_text("clipkit_mode_codon", "kpic-smart-gap")

# --- Tree inference ------------------------------------------------------------
IQTREE_MODEL = _config_text("iqtree_model", "MFP")
IQTREE_BOOTSTRAP = int(config.get("iqtree_bootstrap", 1000))
IQTREE_BNNI = _config_flag("iqtree_bnni", False)

# --- HyPhy / codeml options ----------------------------------------------------
ABSREL_BRANCHES = _config_text("absrel_branches", "Leaves")
MEME_BRANCHES = _config_text("meme_branches", "Leaves")
CODEML_CODONFREQ = int(config.get("codeml_codonfreq", 2))

SUPPORTED_ALIGNMENT_METHODS = ("babappalign", "mafft", "prank")


def _normalize_alignment_methods(raw):
    """Turn the configured alignment-method selection into a clean list.

    Args:
        raw: Either a comma-separated string or a list/tuple of method
            names. Any other type (including ``None``) is treated as
            "nothing configured".

    Returns:
        Lower-cased, de-duplicated method names in first-seen order. When
        nothing usable is configured, every supported method is returned.

    Raises:
        ValueError: If any configured name is not in
            ``SUPPORTED_ALIGNMENT_METHODS``.
    """
    if isinstance(raw, str):
        tokens = raw.split(",")
    elif isinstance(raw, (list, tuple)):
        # Tuples are accepted too; previously they were silently discarded
        # and replaced by the full default set.
        tokens = [str(item) for item in raw]
    else:
        tokens = []

    methods = [token.strip().lower() for token in tokens if token.strip()]
    if not methods:
        return list(SUPPORTED_ALIGNMENT_METHODS)

    unknown = [name for name in methods if name not in SUPPORTED_ALIGNMENT_METHODS]
    if unknown:
        raise ValueError(
            f"Unsupported alignment methods in config: {', '.join(unknown)}. "
            f"Supported: {', '.join(SUPPORTED_ALIGNMENT_METHODS)}"
        )

    # dict.fromkeys drops repeats while keeping the first occurrence order.
    # The list cannot be empty here: `methods` is non-empty and fully valid.
    return list(dict.fromkeys(methods))


# Alignment methods to run, validated and de-duplicated; the first one is the
# "primary" method whose results are also copied to the top-level alias files.
ALIGNMENT_METHODS = _normalize_alignment_methods(
    config.get("alignment_methods", list(SUPPORTED_ALIGNMENT_METHODS))
)
PRIMARY_METHOD = ALIGNMENT_METHODS[0]
METHOD_COUNT = max(1, len(ALIGNMENT_METHODS))

# Cores per method-specific job: the configured value (default: an even split
# of THREADS across methods), at least 1 and never more than THREADS.
_even_share = max(1, THREADS // METHOD_COUNT)
_requested_cores = max(1, int(config.get("per_method_cores", _even_share)))
PER_METHOD_CORES = min(max(1, THREADS), _requested_cores)

# Regex alternations used as wildcard constraints for {method}.
METHOD_PATTERN = "|".join(ALIGNMENT_METHODS)
# Enabled methods whose codon alignment is derived with PAL2NAL.
PAL2NAL_METHODS = [name for name in ALIGNMENT_METHODS if name in ("mafft", "prank")]
PAL2NAL_PATTERN = "|".join(PAL2NAL_METHODS)

# Executable paths/commands used by the shell rules. `or {}` also covers an
# `executables:` entry that is present but null.
EXE = config.get("executables") or {}

# These tools have no fallback command name. Report every missing entry at
# once (still as a KeyError) instead of failing on the first bare key.
_REQUIRED_EXECUTABLES = ("blastp", "makeblastdb", "iqtree", "hyphy", "codeml", "clipkit")
_missing_executables = [name for name in _REQUIRED_EXECUTABLES if name not in EXE]
if _missing_executables:
    raise KeyError(
        "config['executables'] is missing required entries: "
        + ", ".join(_missing_executables)
    )

BLASTP = EXE["blastp"]
MAKEBLASTDB = EXE["makeblastdb"]
IQTREE = EXE["iqtree"]
HYPHY = EXE["hyphy"]
CODEML = EXE["codeml"]
CLIPKIT = EXE["clipkit"]

# Optional entries fall back to the bare command name.
BABAPPALIGN = EXE.get("babappalign", "babappalign")
MAFFT = EXE.get("mafft", "mafft")
PRANK = EXE.get("prank", "prank")
PAL2NAL = EXE.get("pal2nal", "pal2nal")

# The CDS stage is enabled only for a non-empty regular file. `exists()` alone
# would also accept a directory; e.g. Path("") is "." which exists with a
# non-zero size and would wrongly switch the workflow into CDS mode.
HAVE_CDS = USER_CDS.is_file() and USER_CDS.stat().st_size > 0

# Default target. HAVE_CDS is computed while the Snakefile is parsed, so the
# choice between the two `all` definitions is fixed for the whole run.
if HAVE_CDS:
    # Full analysis: per-method summaries and ASR markers, the cross-method
    # comparison, and the top-level alias files for the primary method.
    rule all:
        input:
            expand(OUTDIR / "summary" / "{method}" / "episodic_selection_summary.txt", method=ALIGNMENT_METHODS),
            OUTDIR / "summary" / "comparative_reproducibility_summary.txt",
            OUTDIR / "summary" / "episodic_selection_summary.txt",
            expand(OUTDIR / "asr" / "{method}" / "asr_done.json", method=ALIGNMENT_METHODS),
            OUTDIR / "asr" / "asr_done.json"
else:
    # No usable CDS file: stop after the orthogroup step and the waiting note.
    rule all:
        input:
            OUTDIR / "orthogroup" / "WAITING_FOR_CDS.txt"


# Builds the orthogroup from the query FASTA and the proteomes directory by
# calling babappasnake.scripts.run_rbh_pipeline with the COVERAGE cutoff, the
# thread count and the blastp/makeblastdb executables.
# Declared outputs: orthogroup protein FASTA, header list, and RBH summary table.
rule rbh_orthogroup:
    input:
        query=QUERY,
        proteomes=PROTEOMES,
    output:
        proteins=OUTDIR / "orthogroup" / "orthogroup_proteins.fasta",
        headers=OUTDIR / "orthogroup" / "orthogroup_headers.txt",
        rbh=OUTDIR / "orthogroup" / "rbh_summary.tsv",
    threads: THREADS
    shell:
        r'''
        mkdir -p {OUTDIR}/orthogroup
        python -m babappasnake.scripts.run_rbh_pipeline \
          --query {input.query} \
          --proteomes {input.proteomes} \
          --outdir {OUTDIR}/orthogroup \
          --coverage {COVERAGE} \
          --threads {threads} \
          --blastp "{BLASTP}" \
          --makeblastdb "{MAKEBLASTDB}"
        '''


# Writes WAITING_FOR_CDS.txt via babappasnake.scripts.write_waiting_note, given
# the orthogroup proteins, their headers and the path where the CDS file is
# expected. This note is the `all` target when HAVE_CDS is false.
rule waiting_note:
    input:
        proteins=rules.rbh_orthogroup.output.proteins,
        headers=rules.rbh_orthogroup.output.headers,
    output:
        note=OUTDIR / "orthogroup" / "WAITING_FOR_CDS.txt",
    shell:
        r'''
        python -m babappasnake.scripts.write_waiting_note \
          --orthogroup {input.proteins} \
          --headers {input.headers} \
          --expected-cds {USER_CDS} \
          --outfile {output.note}
        '''


if HAVE_CDS:

    # Pairs the user-supplied CDS with the orthogroup proteins through
    # babappasnake.scripts.map_cds_to_proteins. Declared outputs: the mapped
    # CDS FASTA, the corresponding protein FASTA and a CDS/protein mapping table.
    rule map_cds:
        input:
            proteins=rules.rbh_orthogroup.output.proteins,
            cds=USER_CDS,
        output:
            mapped=OUTDIR / "mapped_cds" / "mapped_orthogroup_cds.fasta",
            proteins=OUTDIR / "mapped_cds" / "mapped_orthogroup_proteins.fasta",
            table=OUTDIR / "mapped_cds" / "cds_protein_mapping.tsv",
        shell:
            r'''
            mkdir -p {OUTDIR}/mapped_cds
            python -m babappasnake.scripts.map_cds_to_proteins \
              --proteins {input.proteins} \
              --cds {input.cds} \
              --out-cds {output.mapped} \
              --out-proteins {output.proteins} \
              --mapping {output.table}
            '''


    if "babappalign" in ALIGNMENT_METHODS:
        rule align_proteins_babappalign:
            input:
                rules.map_cds.output.proteins
            output:
                aln=OUTDIR / "alignments" / "babappalign" / "orthogroup_proteins.protein.aln.fasta"
            threads: PER_METHOD_CORES
            shell:
                r'''
                mkdir -p {OUTDIR}/alignments/babappalign
                cp {input} {OUTDIR}/alignments/babappalign/orthogroup_proteins.fasta
                cd {OUTDIR}/alignments/babappalign && \
                "{BABAPPALIGN}" orthogroup_proteins.fasta --model babappascore --mode protein
                test -f {output.aln}
                '''


        rule align_cds_babappalign:
            input:
                rules.map_cds.output.mapped
            output:
                protein_aln=OUTDIR / "alignments" / "babappalign" / "mapped_orthogroup_cds.protein.aln.fasta",
                codon_aln=OUTDIR / "alignments" / "babappalign" / "mapped_orthogroup_cds.codon.aln.fasta"
            threads: PER_METHOD_CORES
            shell:
                r'''
                mkdir -p {OUTDIR}/alignments/babappalign
                cp {input} {OUTDIR}/alignments/babappalign/mapped_orthogroup_cds.fasta
                cd {OUTDIR}/alignments/babappalign && \
                "{BABAPPALIGN}" mapped_orthogroup_cds.fasta --model babappascore --mode codon
                test -f {output.protein_aln}
                test -f {output.codon_aln}
                '''


    # MAFFT protein alignment, defined only when "mafft" is enabled.
    # The alignment is taken from MAFFT's stdout; `test -s` fails the job if
    # the resulting file is empty.
    if "mafft" in ALIGNMENT_METHODS:
        rule align_proteins_mafft:
            input:
                rules.map_cds.output.proteins
            output:
                aln=OUTDIR / "alignments" / "mafft" / "orthogroup_proteins.protein.aln.fasta"
            threads: PER_METHOD_CORES
            shell:
                r'''
                mkdir -p {OUTDIR}/alignments/mafft
                "{MAFFT}" --auto --thread {threads} {input} > {output.aln}
                test -s {output.aln}
                '''


    # PRANK protein alignment, defined only when "prank" is enabled.
    # Runs single-threaded with the `-protein -F` options, checks that the
    # `<prefix>.best.fas` file exists and copies it to the declared output name.
    if "prank" in ALIGNMENT_METHODS:
        rule align_proteins_prank:
            input:
                rules.map_cds.output.proteins
            output:
                aln=OUTDIR / "alignments" / "prank" / "orthogroup_proteins.protein.aln.fasta"
            threads: 1
            shell:
                r'''
                mkdir -p {OUTDIR}/alignments/prank
                "{PRANK}" -d={input} -o={OUTDIR}/alignments/prank/orthogroup_proteins.prank -protein -F
                test -f {OUTDIR}/alignments/prank/orthogroup_proteins.prank.best.fas
                cp {OUTDIR}/alignments/prank/orthogroup_proteins.prank.best.fas {output.aln}
                '''


    # Codon alignments for the enabled methods listed in PAL2NAL_METHODS
    # (mafft/prank); {method} is constrained to exactly those names.
    # The method's protein alignment is copied as the CDS-side protein
    # alignment, and PAL2NAL produces the codon alignment from it plus the
    # mapped CDS. PAL2NAL is first invoked directly; if that command fails it
    # is retried through `perl`. `test -s` rejects an empty codon alignment.
    if PAL2NAL_METHODS:
        rule align_cds_pal2nal_method:
            input:
                protein_aln=OUTDIR / "alignments" / "{method}" / "orthogroup_proteins.protein.aln.fasta",
                cds=rules.map_cds.output.mapped,
            output:
                protein_aln=OUTDIR / "alignments" / "{method}" / "mapped_orthogroup_cds.protein.aln.fasta",
                codon_aln=OUTDIR / "alignments" / "{method}" / "mapped_orthogroup_cds.codon.aln.fasta"
            wildcard_constraints:
                method=PAL2NAL_PATTERN
            shell:
                r'''
                mkdir -p {OUTDIR}/alignments/{wildcards.method}
                cp {input.protein_aln} {output.protein_aln}
                ("{PAL2NAL}" {input.protein_aln} {input.cds} -output fasta > {output.codon_aln}) || \
                (perl "{PAL2NAL}" {input.protein_aln} {input.cds} -output fasta > {output.codon_aln})
                test -s {output.codon_aln}
                '''


    # Aggregate target: protein alignments for every enabled method.
    rule align_proteins_all_methods:
        input:
            expand(OUTDIR / "alignments" / "{method}" / "orthogroup_proteins.protein.aln.fasta", method=ALIGNMENT_METHODS)


    # Aggregate target: CDS-side protein and codon alignments for every
    # enabled method.
    rule align_cds_all_methods:
        input:
            expand(OUTDIR / "alignments" / "{method}" / "mapped_orthogroup_cds.protein.aln.fasta", method=ALIGNMENT_METHODS),
            expand(OUTDIR / "alignments" / "{method}" / "mapped_orthogroup_cds.codon.aln.fasta", method=ALIGNMENT_METHODS)


    if USE_CLIPKIT:
        # Trims a method's protein alignment with ClipKIT using the configured
        # protein mode (shell-quoted) and `-s aa`.
        rule trim_protein_alignment_method:
            input:
                OUTDIR / "alignments" / "{method}" / "orthogroup_proteins.protein.aln.fasta"
            output:
                trimmed=OUTDIR / "trimmed" / "{method}" / "orthogroup_proteins.clipkit.fasta"
            wildcard_constraints:
                method=METHOD_PATTERN
            shell:
                r'''
                mkdir -p {OUTDIR}/trimmed/{wildcards.method}
                "{CLIPKIT}" {input} -m {CLIPKIT_MODE_PROTEIN:q} -s aa -o {output.trimmed}
                '''


        # Trims a method's codon alignment with ClipKIT using the configured
        # codon mode (shell-quoted) and the `--codon` flag.
        rule trim_codon_alignment_method:
            input:
                OUTDIR / "alignments" / "{method}" / "mapped_orthogroup_cds.codon.aln.fasta"
            output:
                trimmed=OUTDIR / "trimmed" / "{method}" / "mapped_orthogroup_cds.clipkit.fasta"
            wildcard_constraints:
                method=METHOD_PATTERN
            shell:
                r'''
                mkdir -p {OUTDIR}/trimmed/{wildcards.method}
                "{CLIPKIT}" {input} -m {CLIPKIT_MODE_CODON:q} --codon -o {output.trimmed}
                '''


        # Post-processes the trimmed codon alignment with
        # babappasnake.scripts.strip_terminal_stop_codon; the ".nostop" file is
        # the codon alignment the downstream analyses use when ClipKIT is on.
        rule strip_terminal_stop_codon_method:
            input:
                OUTDIR / "trimmed" / "{method}" / "mapped_orthogroup_cds.clipkit.fasta"
            output:
                cleaned=OUTDIR / "trimmed" / "{method}" / "mapped_orthogroup_cds.clipkit.nostop.fasta"
            wildcard_constraints:
                method=METHOD_PATTERN
            shell:
                r'''
                python -m babappasnake.scripts.strip_terminal_stop_codon \
                  --input {input} \
                  --output {output.cleaned}
                '''


        # Aggregate target: trimmed protein alignments for every enabled method.
        rule trim_protein_alignment_all_methods:
            input:
                expand(OUTDIR / "trimmed" / "{method}" / "orthogroup_proteins.clipkit.fasta", method=ALIGNMENT_METHODS)


        # Aggregate target: trimmed codon alignments for every enabled method.
        rule trim_codon_alignment_all_methods:
            input:
                expand(OUTDIR / "trimmed" / "{method}" / "mapped_orthogroup_cds.clipkit.fasta", method=ALIGNMENT_METHODS)


        # Aggregate target: stop-codon-stripped codon alignments for every
        # enabled method.
        rule strip_terminal_stop_codon_all_methods:
            input:
                expand(OUTDIR / "trimmed" / "{method}" / "mapped_orthogroup_cds.clipkit.nostop.fasta", method=ALIGNMENT_METHODS)


    def tree_protein_alignment(wildcards):
        """Input function: protein alignment used for tree inference.

        Returns the ClipKIT-trimmed protein alignment of ``wildcards.method``
        when USE_CLIPKIT is set, otherwise the untrimmed protein alignment.
        """
        stage, filename = (
            ("trimmed", "orthogroup_proteins.clipkit.fasta")
            if USE_CLIPKIT
            else ("alignments", "orthogroup_proteins.protein.aln.fasta")
        )
        return OUTDIR.joinpath(stage, wildcards.method, filename)


    def test_cds_alignment(wildcards):
        """Input function: codon alignment fed to the selection/ASR analyses.

        Returns the trimmed, stop-codon-stripped codon alignment of
        ``wildcards.method`` when USE_CLIPKIT is set, otherwise the untrimmed
        codon alignment.
        """
        stage, filename = (
            ("trimmed", "mapped_orthogroup_cds.clipkit.nostop.fasta")
            if USE_CLIPKIT
            else ("alignments", "mapped_orthogroup_cds.codon.aln.fasta")
        )
        return OUTDIR.joinpath(stage, wildcards.method, filename)


    # Runs IQ-TREE on the method's protein alignment (trimmed or not, as chosen
    # by tree_protein_alignment) with the configured model and bootstrap count.
    # `-bnni` is appended only when IQTREE_BNNI is set; `-redo` is always
    # passed. Results use the prefix <OUTDIR>/tree/<method>/orthogroup, and the
    # job fails unless the .treefile exists afterwards.
    rule iqtree_ml_method:
        input:
            tree_protein_alignment
        output:
            tree=OUTDIR / "tree" / "{method}" / "orthogroup.treefile"
        wildcard_constraints:
            method=METHOD_PATTERN
        threads: PER_METHOD_CORES
        params:
            bnni="-bnni" if IQTREE_BNNI else ""
        shell:
            r'''
            mkdir -p {OUTDIR}/tree/{wildcards.method}
            "{IQTREE}" -s {input} -nt {threads} -m {IQTREE_MODEL:q} -B {IQTREE_BOOTSTRAP} {params.bnni} -redo -pre {OUTDIR}/tree/{wildcards.method}/orthogroup
            test -f {output.tree}
            '''


    # Aggregate target: IQ-TREE trees for every enabled method.
    rule iqtree_ml_all_methods:
        input:
            expand(OUTDIR / "tree" / "{method}" / "orthogroup.treefile", method=ALIGNMENT_METHODS)


    # Produces the rooted tree via babappasnake.scripts.root_tree_outgroup.
    # OUTGROUP_QUERY is passed shell-quoted and may be an empty string.
    # NOTE(review): how an empty outgroup is handled is up to the script -
    # confirm there.
    rule root_iqtree_outgroup_method:
        input:
            tree=OUTDIR / "tree" / "{method}" / "orthogroup.treefile"
        output:
            rooted=OUTDIR / "tree" / "{method}" / "orthogroup.rooted.treefile"
        wildcard_constraints:
            method=METHOD_PATTERN
        params:
            outgroup=OUTGROUP_QUERY
        shell:
            r'''
            python -m babappasnake.scripts.root_tree_outgroup \
              --tree {input.tree} \
              --output {output.rooted} \
              --outgroup {params.outgroup:q}
            '''


    # Aggregate target: rooted trees for every enabled method.
    rule root_iqtree_outgroup_all_methods:
        input:
            expand(OUTDIR / "tree" / "{method}" / "orthogroup.rooted.treefile", method=ALIGNMENT_METHODS)


    # Runs babappasnake.scripts.run_codeml_asr on the method's codon alignment
    # and rooted tree, with the codeml executable and CODEML_CODONFREQ.
    # The three declared outputs are verified with `test -f` after the script.
    rule codeml_asr_method:
        input:
            aln=test_cds_alignment,
            tree=OUTDIR / "tree" / "{method}" / "orthogroup.rooted.treefile",
        output:
            done=OUTDIR / "asr" / "{method}" / "asr_done.json",
            mlc=OUTDIR / "asr" / "{method}" / "mlc_asr.txt",
            rst=OUTDIR / "asr" / "{method}" / "rst",
        wildcard_constraints:
            method=METHOD_PATTERN
        shell:
            r'''
            mkdir -p {OUTDIR}/asr/{wildcards.method}
            python -m babappasnake.scripts.run_codeml_asr \
              --alignment {input.aln} \
              --tree {input.tree} \
              --outdir {OUTDIR}/asr/{wildcards.method} \
              --codeml "{CODEML}" \
              --codonfreq {CODEML_CODONFREQ}
            test -f {output.done}
            test -f {output.mlc}
            test -f {output.rst}
            '''


    # Aggregate target: ASR completion markers for every enabled method.
    rule codeml_asr_all_methods:
        input:
            expand(OUTDIR / "asr" / "{method}" / "asr_done.json", method=ALIGNMENT_METHODS)


    # Runs babappasnake.scripts.run_hyphy on the method's codon alignment and
    # rooted tree, passing the HyPhy executable, thread count and the
    # configured aBSREL/MEME branch selections (shell-quoted).
    # The done marker, absrel.json and meme.json must exist afterwards.
    rule hyphy_exploratory_method:
        input:
            aln=test_cds_alignment,
            tree=OUTDIR / "tree" / "{method}" / "orthogroup.rooted.treefile",
        output:
            done=OUTDIR / "hyphy" / "{method}" / "hyphy_done.json",
            absrel=OUTDIR / "hyphy" / "{method}" / "absrel.json",
            meme=OUTDIR / "hyphy" / "{method}" / "meme.json",
        wildcard_constraints:
            method=METHOD_PATTERN
        threads: PER_METHOD_CORES
        shell:
            r'''
            python -m babappasnake.scripts.run_hyphy \
              --cds-aln {input.aln} \
              --tree {input.tree} \
              --outdir {OUTDIR}/hyphy/{wildcards.method} \
              --threads {threads} \
              --hyphy "{HYPHY}" \
              --absrel-branches {ABSREL_BRANCHES:q} \
              --meme-branches {MEME_BRANCHES:q}
            test -f {output.done}
            test -f {output.absrel}
            test -f {output.meme}
            '''


    # Aggregate target: HyPhy completion markers for every enabled method.
    rule hyphy_exploratory_all_methods:
        input:
            expand(OUTDIR / "hyphy" / "{method}" / "hyphy_done.json", method=ALIGNMENT_METHODS)


    # Extracts foreground branches from absrel.json with
    # babappasnake.scripts.parse_hyphy_foregrounds in `--dynamic` mode, using
    # the ABSREL_DYNAMIC_START/STEP/MAX settings. Declared outputs: a TSV, a
    # plain list and a JSON file describing the threshold.
    rule parse_foregrounds_method:
        input:
            absrel=OUTDIR / "hyphy" / "{method}" / "absrel.json"
        output:
            tsv=OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.tsv",
            lst=OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.txt",
            meta=OUTDIR / "hyphy" / "{method}" / "foreground_threshold.json",
        wildcard_constraints:
            method=METHOD_PATTERN
        shell:
            r'''
            python -m babappasnake.scripts.parse_hyphy_foregrounds \
              --absrel-json {input.absrel} \
              --dynamic \
              --dynamic-start {ABSREL_DYNAMIC_START} \
              --dynamic-step {ABSREL_DYNAMIC_STEP} \
              --dynamic-max {ABSREL_DYNAMIC_MAX} \
              --out-tsv {output.tsv} \
              --out-list {output.lst} \
              --out-meta {output.meta}
            '''


    # Aggregate target: all foreground-parsing outputs for every enabled method.
    rule parse_foregrounds_all_methods:
        input:
            expand(OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.tsv", method=ALIGNMENT_METHODS),
            expand(OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.txt", method=ALIGNMENT_METHODS),
            expand(OUTDIR / "hyphy" / "{method}" / "foreground_threshold.json", method=ALIGNMENT_METHODS)


    # Calls babappasnake.scripts.prepare_foreground_trees with the rooted tree
    # and the foreground list; it works in <OUTDIR>/branchsite/<method> and
    # writes the manifest that is this rule's only declared output.
    rule prepare_foreground_trees_method:
        input:
            tree=OUTDIR / "tree" / "{method}" / "orthogroup.rooted.treefile",
            lst=OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.txt",
        output:
            manifest=OUTDIR / "branchsite" / "{method}" / "foreground_trees.tsv"
        wildcard_constraints:
            method=METHOD_PATTERN
        shell:
            r'''
            mkdir -p {OUTDIR}/branchsite/{wildcards.method}
            python -m babappasnake.scripts.prepare_foreground_trees \
              --tree {input.tree} \
              --foreground-list {input.lst} \
              --outdir {OUTDIR}/branchsite/{wildcards.method} \
              --manifest {output.manifest}
            '''


    # Aggregate target: foreground-tree manifests for every enabled method.
    rule prepare_foreground_trees_all_methods:
        input:
            expand(OUTDIR / "branchsite" / "{method}" / "foreground_trees.tsv", method=ALIGNMENT_METHODS)


    # Runs babappasnake.scripts.run_branchsite_batch over the trees in
    # <OUTDIR>/branchsite/<method>, with `--jobs` set to the rule's threads.
    # The manifest input is not passed on the command line; it only orders
    # this rule after prepare_foreground_trees_method.
    rule branchsite_batch_method:
        input:
            aln=test_cds_alignment,
            manifest=OUTDIR / "branchsite" / "{method}" / "foreground_trees.tsv",
            lst=OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.txt",
        output:
            tsv=OUTDIR / "branchsite" / "{method}" / "branchsite_results.tsv"
        wildcard_constraints:
            method=METHOD_PATTERN
        threads: PER_METHOD_CORES
        shell:
            r'''
            python -m babappasnake.scripts.run_branchsite_batch \
              --alignment {input.aln} \
              --tree-dir {OUTDIR}/branchsite/{wildcards.method} \
              --foreground-list {input.lst} \
              --out-tsv {output.tsv} \
              --codeml "{CODEML}" \
              --codonfreq {CODEML_CODONFREQ} \
              --jobs {threads}
            '''


    # Aggregate target: branch-site result tables for every enabled method.
    rule branchsite_batch_all_methods:
        input:
            expand(OUTDIR / "branchsite" / "{method}" / "branchsite_results.tsv", method=ALIGNMENT_METHODS)


    # Writes the per-method summary with babappasnake.scripts.summarize_results
    # from the RBH table, CDS mapping, foreground results and branch-site
    # results, using MEME_P and the method name.
    # `hyphy_done` is not passed on the command line; it ensures the HyPhy
    # directory given via --hyphy-dir has been produced first.
    rule final_summary_method:
        input:
            rbh=rules.rbh_orthogroup.output.rbh,
            mapping=rules.map_cds.output.table,
            absrel=OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.tsv",
            absrel_meta=OUTDIR / "hyphy" / "{method}" / "foreground_threshold.json",
            branchsite=OUTDIR / "branchsite" / "{method}" / "branchsite_results.tsv",
            hyphy_done=OUTDIR / "hyphy" / "{method}" / "hyphy_done.json",
        output:
            OUTDIR / "summary" / "{method}" / "episodic_selection_summary.txt"
        wildcard_constraints:
            method=METHOD_PATTERN
        shell:
            r'''
            mkdir -p {OUTDIR}/summary/{wildcards.method}
            python -m babappasnake.scripts.summarize_results \
              --rbh {input.rbh} \
              --mapping {input.mapping} \
              --absrel {input.absrel} \
              --absrel-meta {input.absrel_meta} \
              --branchsite {input.branchsite} \
              --hyphy-dir {OUTDIR}/hyphy/{wildcards.method} \
              --meme-p {MEME_P} \
              --method {wildcards.method} \
              --out {output}
            '''


    # Aggregate target: per-method summaries for every enabled method.
    rule final_summary_all_methods:
        input:
            expand(OUTDIR / "summary" / "{method}" / "episodic_selection_summary.txt", method=ALIGNMENT_METHODS)


    # Cross-method comparison via babappasnake.scripts.compare_alignment_methods.
    # The script receives only OUTDIR and the comma-joined method list; the
    # declared inputs make sure every per-method result exists beforehand.
    rule compare_alignment_methods:
        input:
            summaries=expand(OUTDIR / "summary" / "{method}" / "episodic_selection_summary.txt", method=ALIGNMENT_METHODS),
            absrel=expand(OUTDIR / "hyphy" / "{method}" / "significant_foregrounds.tsv", method=ALIGNMENT_METHODS),
            absrel_meta=expand(OUTDIR / "hyphy" / "{method}" / "foreground_threshold.json", method=ALIGNMENT_METHODS),
            branchsite=expand(OUTDIR / "branchsite" / "{method}" / "branchsite_results.tsv", method=ALIGNMENT_METHODS),
        output:
            OUTDIR / "summary" / "comparative_reproducibility_summary.txt"
        params:
            methods=",".join(ALIGNMENT_METHODS)
        shell:
            r'''
            mkdir -p {OUTDIR}/summary
            python -m babappasnake.scripts.compare_alignment_methods \
              --outdir {OUTDIR} \
              --methods {params.methods:q} \
              --out {output}
            '''


    # Copies the primary (first configured) method's summary to the top-level
    # summary path requested by `rule all`.
    rule final_summary_primary_alias:
        input:
            OUTDIR / "summary" / PRIMARY_METHOD / "episodic_selection_summary.txt"
        output:
            OUTDIR / "summary" / "episodic_selection_summary.txt"
        shell:
            r'''
            cp {input} {output}
            '''


    # Copies the primary (first configured) method's ASR completion marker to
    # the top-level asr/ path requested by `rule all`.
    rule asr_primary_alias:
        input:
            OUTDIR / "asr" / PRIMARY_METHOD / "asr_done.json"
        output:
            OUTDIR / "asr" / "asr_done.json"
        shell:
            r'''
            mkdir -p {OUTDIR}/asr
            cp {input} {output}
            '''
