# ---- begin snakebids boilerplate ----------------------------------------------

from pathlib import Path

from snakebids import bids, generate_inputs, get_wildcard_constraints


configfile: workflow.source_path("../config/snakebids.yml")


# Keyword options passed unchanged to both generate_inputs() calls below.
# NOTE(review): the raw and AFIDs datasets share one pybids database directory
# and one derivatives setting -- confirm this is intended when pybids_db_dir
# is actually set.
_common_input_opts = {
    "pybids_database_dir": config.get("pybids_db_dir"),
    "pybids_reset_database": config.get("pybids_db_reset"),
    "derivatives": config.get("derivatives"),
    "participant_label": config.get("participant_label"),
    "exclude_participant_label": config.get("exclude_participant_label"),
}

# Inputs found under config["bids_dir"] (provides the "T1w" component used by
# the rules in this file).
inputs_raw = generate_inputs(
    bids_dir=config["bids_dir"],
    pybids_inputs=config["pybids_inputs"],
    **_common_input_opts,
)
# Inputs found under config["afids_dir"] (provides the "afids" component).
inputs_afids = generate_inputs(
    bids_dir=config["afids_dir"],
    pybids_inputs=config["pybids_inputs_afids"],
    **_common_input_opts,
)


# Apply the wildcard constraints derived from the raw-input specification.
wildcard_constraints:
    **get_wildcard_constraints(config["pybids_inputs"]),


# ---- end snakebids boilerplate ------------------------------------------------


# Turn a subject's AFIDs .fcsv file into a per-landmark ({afid}) text file that
# gen_sphere consumes. The conversion itself lives in scripts/extract_afids.py
# (not shown here; presumably it selects the row for the requested landmark).
#
# The script path is resolved relative to this Snakefile. It previously ended
# with a stray double quote inside the string literal, which made it refer to
# a non-existent file.
rule extract_afid:
    input:
        fcsv=inputs_afids["afids"].path,
    output:
        txt=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}.txt",
            **inputs_raw["T1w"].wildcards
        ),
    log:
        bids(
            root="logs",
            suffix="landmark.log",
            afid="{afid}",
            **inputs_raw["T1w"].wildcards
        ),
    script:
        "./scripts/extract_afids.py"


# Build a per-landmark image on the T1w grid: c3d zeroes the T1w (-scale 0)
# and then applies -landmarks-to-spheres with the landmark text file produced
# by extract_afid and a trailing argument of 1.
rule gen_sphere:
    input:
        t1w=inputs_raw["T1w"].path,
        txt=rules.extract_afid.output.txt,
    output:
        sphere=bids(
            root=str(Path(config["output_dir"], "c3d")),
            suffix="afid-{afid}.nii.gz",
            **inputs_raw["T1w"].wildcards
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.t1w} -scale 0 "
        "-landmarks-to-spheres {input.txt} 1 "
        "-o {output.sphere}"


# Run c3d's -sdt on the landmark sphere image of one AFID and store the result
# as that landmark's "prob" map.
#
# The input used to be the subject's T1w image. That made the output identical
# for every {afid} (nothing landmark-specific entered the command) and left the
# output of gen_sphere unused by any rule, so the sphere image is used instead.
# NOTE(review): confirm against the intended pipeline design.
rule gen_prob:
    input:
        sphere=rules.gen_sphere.output.sphere,
    output:
        prob=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}_prob.nii.gz",
            **inputs_raw["T1w"].wildcards
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.sphere} -sdt -o {output.prob}"


# Rescale a landmark's prob map by the factor given in the {k} wildcard and
# exponentiate it (c3d -scale {k} -exp).
#
# {k} only exists as a wildcard of the output path, so the shell command has to
# reference it as {wildcards.k}; a bare {k} is not a variable Snakemake makes
# available when formatting the command.
rule gen_norm:
    input:
        prob=rules.gen_prob.output.prob,
    output:
        norm=bids(
            root=str(Path(config["output_dir"]) / "c3d"),
            suffix="afid-{afid}_prob_norm_-{k}.nii.gz",
            **inputs_raw["T1w"].wildcards
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.prob} -scale {wildcards.k} -exp -o {output.norm}"


# Binarise a landmark's prob map with c3d "-threshold 0 10 1 0"; the result is
# the mask that gen_patches receives as its third image.
rule gen_mask:
    input:
        prob=rules.gen_prob.output.prob,
    output:
        bin=bids(
            root=str(Path(config["output_dir"], "c3d")),
            suffix="afid-{afid}_bin.nii.gz",
            **inputs_raw["T1w"].wildcards
        ),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.prob} "
        "-threshold 0 10 1 0 "
        "-o {output.bin}"


# Sample patches for one subject and landmark with c3d (-xpa / -xp), feeding it
# the T1w image, the prob map and the mask. The number of augmentations, angle
# standard deviation, patch radius and sampling frequency all come from the
# wildcards encoded in the output's desc entity. The output is temporary; it is
# only needed until cat_patches has merged it.
rule gen_patches:
    input:
        t1w=inputs_raw["T1w"].path,
        prob=rules.gen_prob.output.prob,
        mask=rules.gen_mask.output.bin,
    output:
        patch=temp(
            bids(
                root=str(Path(config["output_dir"], "c3d")),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_patch.dat",
                **inputs_raw["T1w"].wildcards
            )
        ),
    params:
        # c3d expects the radius as RxRxR, i.e. the same value for all three axes.
        radius_arg=lambda wildcards: "x".join([wildcards["radius"]] * 3),
    container:
        config["containers"]["c3d"]
    shell:
        "c3d {input.t1w} {input.prob} {input.mask} -xpa "
        "{wildcards.num_augment} {wildcards.angle_stdev} -xp "
        "{output.patch} {params.radius_arg} {wildcards.frequency}"


# Concatenate the per-subject patch files of one landmark / parameter
# combination into a single file. The input list is the gen_patches output
# expanded over the T1w inputs; the remaining wildcards (afid and the sampling
# parameters) are left open via allow_missing. The result is write-protected.
rule cat_patches:
    input:
        patch=inputs_raw["T1w"].expand(
            rules.gen_patches.output.patch,
            allow_missing=True,
        ),
    output:
        combined_patch=protected(
            bids(
                root=str(Path(config["output_dir"], "c3d")),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_allpatches.dat",
            )
        ),
    shell:
        "cat {input.patch} "
        "> {output.combined_patch}"


# Train one model per landmark / parameter combination by calling
# auto_afids_cnn_train on the combined patch file. The command receives the
# configured channel count and the radius wildcard, writes the model into a
# directory output and the loss history into a CSV.
rule train_model:
    input:
        combined_patch=rules.cat_patches.output.combined_patch,
    output:
        model=directory(
            bids(
                root=str(Path(config["output_dir"], "afids-cnn-train")),
                desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
                suffix="afid-{afid}_cnn.model",
            )
        ),
        history=bids(
            root=str(Path(config["output_dir"], "afids-cnn-train")),
            desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
            suffix="afid-{afid}_loss.csv",
        ),
    params:
        num_channels=config["num_channels"],
    shell:
        "auto_afids_cnn_train "
        "{params.num_channels} {wildcards.radius} {input.combined_patch} "
        "{output.model} --loss_out_path {output.history}"


# Default target: request the trained model of every landmark (two-digit ids
# 01..32) for the configured sampling parameters, then hand them, together
# with the configured radius, to scripts/assemble_models.py, which writes the
# combined model file.
rule all:
    input:
        models=expand(
            rules.train_model.output.model,
            afid=["{:02}".format(index) for index in range(1, 33)],
            num_augment=config["num_augment"],
            angle_stdev=config["angle_stdev"],
            radius=config["radius"],
            frequency=config["frequency"],
        ),
    output:
        combined_model=bids(
            root=str(Path(config["output_dir"], "afids-cnn-model")),
            suffix="model-combined.afidsmodel",
        ),
    params:
        radius=config["radius"],
    resources:
        # Path of the assembling script, located next to the main Snakefile.
        script=str(Path(workflow.basedir, "scripts", "assemble_models.py")),
    default_target: True
    shell:
        "python3 {resources.script} "
        "{input.models} {params.radius} {output.combined_model}"
