# Slurm SDK

> Python SDK for containerized Slurm workflows and task orchestration. Define tasks in Python, package them into reproducible containers, and submit workflows to your existing Slurm cluster.

## Installation

```bash
pip install slurm-sdk
# or
uv add slurm-sdk
```

## Quick Start

```python
from slurm import Cluster, task

@task(time="00:10:00", mem="4G")
def train(lr: float) -> dict:
    return {"accuracy": 0.95}

with Cluster(backend_type="ssh", hostname="login.hpc.example.com") as cluster:
    job = train(lr=0.001)
    print(job.get_result())  # {'accuracy': 0.95}
```

## Core Concepts

- `@task` — Decorates a Python function to run as a Slurm job
- `@workflow` — Decorates a function that orchestrates other tasks on the cluster
- `Cluster` — Connection to a Slurm cluster (SSH or local backend); used as context manager
- `Job` — A submitted job; call `.wait()`, `.get_result()`, `.cancel()`, `.status()`
- `ArrayJob` — A batch of parallel jobs from `.map()`; call `.get_results()`
- `JobContext` — Runtime metadata injected into tasks (job_id, rank, world_size, GPUs)
- `WorkflowContext` — Injected into workflows (cluster, shared_dir, workflow_job_id)
- `BaseCallback` — Subclass to observe lifecycle events (packaging, submission, execution)

## Decision Tree: Which API to Use

- **One task, one input** → `my_task(args)` inside a `with Cluster(...) as cluster:` block
- **One task, N inputs (parallel)** → `my_task.map(items)` returns `ArrayJob`
- **Sequential pipeline (A then B)** → `task_b.after(job_a)(args)` or pass `Job` as argument
- **Complex orchestration** → `@workflow` decorator with `WorkflowContext`
- **Dynamic branching** → `@workflow` + `job.get_result()` mid-workflow to decide next steps
- **Local testing (no cluster)** → `my_task.unwrapped(args)` calls function directly
- **Override options at runtime** → `my_task.with_options(time="02:00:00")(args)`

## Recipes

### 1. Define and submit a simple task

```python
from slurm import Cluster, task

@task(time="01:00:00", mem="8G", cpus_per_task=4)
def process(data_path: str) -> dict:
    # Your processing code
    return {"status": "done", "rows": 1000}

with Cluster(backend_type="ssh", hostname="login.hpc.example.com") as cluster:
    job = process(data_path="/data/input.csv")
    job.wait()
    result = job.get_result()
    print(result)
```

### 2. Submit with container packaging

```python
from slurm import Cluster, task

@task(
    time="01:00:00",
    gpus_per_node=1,
    packaging="container:nvcr.io/nvidia/pytorch:24.01-py3",
)
def train(lr: float) -> dict:
    import torch
    return {"gpus": torch.cuda.device_count()}

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    job = train(lr=0.001)
    print(job.get_result())
```

### 3. Build container from Dockerfile

```python
@task(
    time="02:00:00",
    packaging="container",
    packaging_dockerfile="Dockerfile",
    packaging_registry="registry.example.com/my-project",
    packaging_platform="linux/amd64",
)
def train(config: dict) -> dict:
    ...
```

### 4. Run N items in parallel with map()

```python
from slurm import Cluster, task

@task(time="00:30:00", mem="4G")
def process(filepath: str) -> dict:
    return {"file": filepath, "status": "done"}

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    files = ["a.csv", "b.csv", "c.csv", "d.csv"]
    array_job = process.map(files)       # Submits one Slurm array job
    results = array_job.get_results()    # Wait and collect all results
    print(results)
```

### 5. Map with keyword arguments

```python
configs = [
    {"lr": 0.001, "epochs": 10},
    {"lr": 0.01, "epochs": 20},
    {"lr": 0.1, "epochs": 5},
]
array_job = train.map(configs)  # Each dict is unpacked as **kwargs
results = array_job.get_results()
```

### 6. Chain tasks with dependencies using after()

```python
from slurm import Cluster, task

@task(time="00:10:00")
def preprocess(raw: str) -> str:
    return f"cleaned_{raw}"

@task(time="01:00:00", gpus_per_node=1)
def train(data: str) -> dict:
    return {"model": "trained", "data": data}

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    prep_job = preprocess(raw="data.csv")
    train_job = train.after(prep_job)(data="dataset")
    # train_job waits for prep_job to complete before starting
    print(train_job.get_result())
```

### 7. Automatic dependencies by passing Job as argument

```python
with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    prep_job = preprocess(raw="data.csv")
    # Passing a Job object as an argument automatically creates a dependency.
    # At runtime, the Job is replaced with its result value.
    train_job = train(data=prep_job)
    print(train_job.get_result())
```
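
The substitution described in the comments above can be pictured with a toy model. This is only a sketch of the idea, not the SDK's implementation: `FakeJob` and `plan_call` are hypothetical names introduced for illustration, and the real `Job` class resolves results on the cluster rather than in memory.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class FakeJob:
    """Toy stand-in for a submitted job (hypothetical, NOT slurm.Job)."""
    job_id: str
    result: Any


def plan_call(kwargs: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any]]:
    """Split keyword arguments into dependency IDs and resolved values."""
    # Every job-valued argument becomes a dependency of the new call
    deps = [v.job_id for v in kwargs.values() if isinstance(v, FakeJob)]
    # When the call finally runs, each job is swapped for its result value
    resolved = {
        k: (v.result if isinstance(v, FakeJob) else v)
        for k, v in kwargs.items()
    }
    return deps, resolved


prep = FakeJob(job_id="101", result="cleaned_data.csv")
deps, resolved = plan_call({"data": prep, "epochs": 3})
print(deps)      # ['101']
print(resolved)  # {'data': 'cleaned_data.csv', 'epochs': 3}
```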

### 8. Array job with dependencies

```python
with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    prep_job = preprocess(raw="data.csv")
    configs = [{"lr": 0.001}, {"lr": 0.01}, {"lr": 0.1}]
    # All array elements wait for prep_job before starting
    sweep = train.after(prep_job).map(configs)
    results = sweep.get_results()
```

### 9. Build a multi-step workflow

```python
from slurm import Cluster, task
from slurm.decorators import workflow
from slurm.workflow import WorkflowContext

@task(time="00:30:00")
def extract(source: str) -> dict:
    return {"records": 1000}

@task(time="01:00:00")
def transform(extracted: dict) -> dict:
    # Receives extract()'s return value, so read the count from the dict
    return {"transformed": extracted["records"]}

@task(time="00:10:00")
def load(data: dict) -> str:
    return "success"

@workflow(time="02:00:00")
def etl_pipeline(source: str, ctx: WorkflowContext) -> str:
    extract_job = extract(source=source)
    transform_job = transform(extracted=extract_job)
    load_job = load(data=transform_job)
    return load_job.get_result()

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    wf = cluster.submit(etl_pipeline)
    job = wf(source="s3://bucket/data")
    print(job.get_result())
```

### 10. Use JobContext for distributed training info

```python
from slurm import task
from slurm.runtime import JobContext

@task(time="04:00:00", nodes=2, gpus_per_node=4)
def distributed_train(config: dict, ctx: JobContext) -> dict:
    # ctx is auto-injected when the task runs on the compute node
    env = ctx.torch_distributed_env()
    # env contains: MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, WORLD_SIZE
    import os
    os.environ.update(env)

    import torch.distributed as dist
    dist.init_process_group("nccl")

    return {"rank": ctx.rank, "world_size": ctx.world_size}
```
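
For intuition about where values like `RANK` and `WORLD_SIZE` can come from, the sketch below derives the same five variables from the environment Slurm sets for tasks launched through `srun`. It is an assumption-laden illustration, not a description of how `ctx.torch_distributed_env()` is implemented: `torch_env_from_slurm` is a hypothetical helper, and the master address is passed in explicitly rather than resolved from the node list.

```python
from typing import Dict, Mapping


def torch_env_from_slurm(
    environ: Mapping[str, str],
    master_addr: str,
    master_port: int = 29500,  # PyTorch's conventional default port
) -> Dict[str, str]:
    """Map Slurm per-task variables to torch.distributed variables."""
    return {
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        "RANK": environ["SLURM_PROCID"],         # global task index
        "LOCAL_RANK": environ["SLURM_LOCALID"],  # task index on this node
        "WORLD_SIZE": environ["SLURM_NTASKS"],   # total number of tasks
    }


# Fake environment for a task on the second node of a 2x4 job
fake_environ = {"SLURM_PROCID": "5", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8"}
print(torch_env_from_slurm(fake_environ, master_addr="node001"))
```

All values are strings, so the returned dict can be passed to `os.environ.update()` directly.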

### 11. Custom callback for observability

```python
from slurm.callbacks import BaseCallback, SubmitEndContext, CompletedContext

class MyCallback(BaseCallback):
    requires_pickling = False  # Client-side only, no serialization needed

    def on_end_submit_job_ctx(self, ctx: SubmitEndContext) -> None:
        print(f"Job {ctx.job_id} submitted to {ctx.target_job_dir}")

    def on_completed_ctx(self, ctx: CompletedContext) -> None:
        print(f"Job {ctx.job_id} finished with state {ctx.job_state}")

with Cluster(
    backend_type="ssh",
    hostname="hpc.example.com",
    callbacks=[MyCallback()],
) as cluster:
    job = train(lr=0.001)
```

### 12. Local testing without a cluster

```python
# Call the underlying function directly, bypassing Slurm
result = train.unwrapped(lr=0.001)
print(result)  # No cluster needed, runs locally
```

### 13. Override task options at runtime

```python
with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    # Submit with more time and memory than the decorator defaults
    big_job = train.with_options(time="08:00:00", mem="64G")(lr=0.001)
```
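
When the time limit is computed rather than hard-coded, a small helper can render it in the `HH:MM:SS` form used throughout this document (Slurm also accepts `D-HH:MM:SS` for multi-day limits). `format_slurm_time` below is a hypothetical helper, not part of the SDK; its output is meant to be passed as the `time` option.

```python
from datetime import timedelta


def format_slurm_time(limit: timedelta) -> str:
    """Render a timedelta as HH:MM:SS, or D-HH:MM:SS for a day or more."""
    total = int(limit.total_seconds())
    if total <= 0:
        raise ValueError("time limit must be positive")
    days, rem = divmod(total, 86400)
    hours, rem = divmod(rem, 3600)
    minutes, seconds = divmod(rem, 60)
    clock = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    return f"{days}-{clock}" if days else clock


print(format_slurm_time(timedelta(hours=8)))                       # 08:00:00
print(format_slurm_time(timedelta(days=1, hours=2, minutes=30)))   # 1-02:30:00
```

A call such as `train.with_options(time=format_slurm_time(timedelta(hours=8)))` would then be equivalent to the literal string in the recipe above.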

### 14. Configure cluster from Slurmfile

```toml
# Slurmfile (TOML) at project root
[environments.production]
backend_type = "ssh"
hostname = "login.hpc.example.com"
default_account = "my-account"
default_partition = "gpu"
job_base_dir = "~/slurm_jobs"

[environments.production.packaging]
type = "container"
registry = "registry.example.com/my-project"
platform = "linux/amd64"
```

```python
from slurm import Cluster

with Cluster.from_env("production") as cluster:
    job = train(lr=0.001)
```

### 15. Workflow with shared directory for inter-task data

```python
from pathlib import Path
from slurm.decorators import workflow
from slurm.workflow import WorkflowContext

@workflow(time="02:00:00")
def pipeline(ctx: WorkflowContext) -> str:
    shared = Path(ctx.shared_dir)
    (shared / "checkpoints").mkdir(exist_ok=True)

    prep_job = prepare(output_dir=str(shared / "data"))
    train_job = train.after(prep_job)(
        data_dir=str(shared / "data"),
        checkpoint_dir=str(shared / "checkpoints"),
    )
    return train_job.get_result()
```

### 16. Hyperparameter sweep with best result selection

```python
from slurm import Cluster, task

@task(time="01:00:00", gpus_per_node=1)
def experiment(lr: float, batch_size: int) -> dict:
    return {"lr": lr, "batch_size": batch_size, "accuracy": 0.9}

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    configs = [
        {"lr": lr, "batch_size": bs}
        for lr in [0.001, 0.01, 0.1]
        for bs in [32, 64, 128]
    ]
    sweep = experiment.map(configs)
    results = sweep.get_results()
    best = max(results, key=lambda r: r["accuracy"])
    print(f"Best config: {best}")
```
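
For larger grids, the nested comprehension above can be swapped for `itertools.product`, which yields the same combinations in the same order. The sketch below uses only the standard library; `best_by` is a hypothetical convenience wrapper around the `max(...)` call in the recipe, with an explicit error for an empty sweep.

```python
from itertools import product

learning_rates = [0.001, 0.01, 0.1]
batch_sizes = [32, 64, 128]

# One dict per (lr, batch_size) pair, in the same order as nested loops
configs = [
    {"lr": lr, "batch_size": bs}
    for lr, bs in product(learning_rates, batch_sizes)
]


def best_by(results, metric):
    """Return the result dict with the highest value for `metric`."""
    if not results:
        raise ValueError("no results to compare")
    return max(results, key=lambda r: r[metric])


print(len(configs))  # 9
print(configs[0])    # {'lr': 0.001, 'batch_size': 32}
```

The resulting `configs` list can be handed to `experiment.map(configs)` exactly as in the recipe, and `best_by(results, "accuracy")` applied to whatever `get_results()` returns.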

### 17. Stream job output while waiting

```python
from slurm import Cluster, task

@task(time="01:00:00")
def train(lr: float) -> dict:
    print("Training started...")
    return {"accuracy": 0.95}

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    job = train(lr=0.001)
    # tail() blocks until the job completes, streaming output in real time
    job.tail()
    # Job is now complete
    if job.is_successful():
        print(job.get_result())
```

### 18. Capture job output to a buffer or file

```python
import io
from slurm import Cluster, task

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    job = train(lr=0.001)

    # Capture to string buffer
    buf = io.StringIO()
    job.tail(output=buf, follow=True)
    log_content = buf.getvalue()

    # Or write to a file
    with open("job.log", "w") as f:
        job.tail(output=f, follow=True)
```

### 19. Get a job status snapshot

```python
from slurm import Cluster, task

with Cluster(backend_type="ssh", hostname="hpc.example.com") as cluster:
    job = train(lr=0.001)
    snap = job.snapshot(tail_lines=20)
    print(f"Job {snap.job_id}: {snap.state}")
    if snap.is_terminal and not snap.is_successful:
        print(snap.stderr_tail)
```

### 20. Parse a packaging config string

```python
from slurm import parse_packaging_config

config = parse_packaging_config("container:nvcr.io/nvidia/pytorch:24.01")
# {'type': 'container', 'image': 'nvcr.io/nvidia/pytorch:24.01'}

config = parse_packaging_config("wheel", {"registry": "my-registry.com"})
# {'type': 'wheel', 'registry': 'my-registry.com'}
```
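
To make the string format concrete, here is a minimal stand-in that reproduces the two outputs shown above. It is a sketch based only on those examples and is not the SDK's parser: `split_packaging_spec` is a hypothetical name, and the real `parse_packaging_config` may validate input or handle forms not covered here.

```python
from typing import Any, Dict, Optional


def split_packaging_spec(
    spec: str, extra: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Split 'type[:image]' at the first colon and merge extra keys."""
    # Only the first ':' separates type from image; image tags keep theirs
    kind, _, image = spec.partition(":")
    config: Dict[str, Any] = {"type": kind}
    if image:
        config["image"] = image
    config.update(extra or {})
    return config


print(split_packaging_spec("container:nvcr.io/nvidia/pytorch:24.01"))
print(split_packaging_spec("wheel", {"registry": "my-registry.com"}))
```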

## Public API Reference

### Exports from `slurm` (22 items)

| Name | Type | Description |
|------|------|-------------|
| `task` | decorator | Marks a function as a Slurm task with SBATCH options |
| `workflow` | decorator | Marks a function as a workflow orchestrator |
| `parse_packaging_config` | function | Parse packaging spec strings into config dicts |
| `Cluster` | class | Connection to a Slurm cluster; use as context manager |
| `Job` | class | A submitted Slurm job |
| `JobSnapshot` | dataclass | Frozen point-in-time snapshot of job state and output |
| `ArrayJob` | class | A submitted array job (from `.map()`) |
| `SlurmTask` | class | A decorated task function with submission methods |
| `JobContext` | dataclass | Runtime metadata injected into tasks on compute nodes |
| `WorkflowContext` | dataclass | Context injected into workflow functions |
| `PackagingConfig` | TypedDict | Valid keys for packaging configuration dicts |
| `BaseCallback` | class | Base class for lifecycle event callbacks |
| `LoggerCallback` | class | Logs job lifecycle events to Python logging |
| `BenchmarkCallback` | class | Records timing metrics for jobs |
| `RichLoggerCallback` | class | Rich terminal output for job progress |
| `SubmissionError` | exception | Job submission failed |
| `DownloadError` | exception | Failed to download job results |
| `BackendError` | exception | Backend communication error |
| `BackendTimeout` | exception | Backend operation timed out |
| `BackendCommandError` | exception | Backend command returned error |
| `PackagingError` | exception | Packaging preparation failed |
| `SlurmfileError` | exception | Slurmfile parsing or validation error |

### Exports from `slurm.callbacks` (16 items)

| Name | Type | Description |
|------|------|-------------|
| `BaseCallback` | class | Base class with 11 hook methods to override |
| `LoggerCallback` | class | Logs events to Python logging |
| `BenchmarkCallback` | class | Records timing metrics |
| `RichLoggerCallback` | class | Rich terminal progress display |
| `DebugCallback` | class | Enables debugpy debugging in Slurm jobs |
| `ExecutionLocus` | enum | CLIENT, RUNNER, or BOTH |
| `PackagingBeginContext` | dataclass | Context for `on_begin_package_ctx` |
| `PackagingEndContext` | dataclass | Context for `on_end_package_ctx` |
| `SubmitBeginContext` | dataclass | Context for `on_begin_submit_job_ctx` |
| `SubmitEndContext` | dataclass | Context for `on_end_submit_job_ctx` |
| `RunBeginContext` | dataclass | Context for `on_begin_run_job_ctx` |
| `RunEndContext` | dataclass | Context for `on_end_run_job_ctx` |
| `JobStatusUpdatedContext` | dataclass | Context for `on_job_status_update_ctx` |
| `CompletedContext` | dataclass | Context for `on_completed_ctx` |
| `WorkflowCallbackContext` | dataclass | Context for `on_workflow_begin_ctx` / `on_workflow_end_ctx` |
| `WorkflowTaskSubmitContext` | dataclass | Context for `on_workflow_task_submitted_ctx` |

### Key Method Signatures

```python
# Decorators
@task(time: str, mem: str = None, cpus_per_task: int = None,
      nodes: int = None, gpus_per_node: int = None,
      partition: str = None, account: str = None,
      packaging: str = None, **sbatch_options) -> SlurmTask

@workflow(time: str = "01:00:00", **sbatch_options) -> SlurmTask

# Cluster
Cluster(backend_type="ssh", callbacks=None, job_base_dir=None,
        default_packaging=None, default_account=None,
        default_partition=None, **backend_kwargs)
Cluster.from_env(env: str) -> Cluster          # Load from Slurmfile
Cluster.from_file(config_path: str) -> Cluster  # Load from TOML file
Cluster.submit(task_func, **overrides) -> Callable  # Two-phase submit

# SlurmTask (decorated function)
task_func(*args, **kwargs) -> Job               # Submit (inside Cluster context)
task_func.map(items, max_concurrent=None) -> ArrayJob  # Submit array job
task_func.after(*dependencies) -> SlurmTask     # Bind dependencies
task_func.with_options(**overrides) -> SlurmTask # Override SBATCH options
task_func.unwrapped(*args, **kwargs)            # Call directly (no Slurm)

# Job
job.wait(timeout=None) -> bool         # Block until done; True=success
job.get_result(timeout=None) -> T      # Wait + download + deserialize result
job.tail(follow=True, stderr=False, lines=10, output=sys.stdout) -> None
job.snapshot(tail_lines=80) -> JobSnapshot  # Frozen state + output tails
job.cancel() -> None                   # Cancel the job
job.status() -> dict                   # Query current Slurm status
job.id -> str                          # Slurm job ID

# Utilities
parse_packaging_config(packaging: str, kwargs=None) -> Optional[dict]

# ArrayJob
array_job.get_results(timeout=None) -> List[T]  # Wait + collect all results
array_job.wait(timeout=None) -> bool             # Block until all done
array_job[i] -> Job                              # Get individual job

# JobContext (auto-injected into tasks via type annotation)
ctx.job_id -> str
ctx.rank -> int
ctx.local_rank -> int
ctx.world_size -> int
ctx.num_nodes -> int
ctx.gpus_per_node -> int
ctx.master_addr -> str
ctx.torch_distributed_env() -> dict    # Returns env vars for torch.distributed

# WorkflowContext (auto-injected into workflows)
ctx.cluster -> Cluster
ctx.shared_dir -> str                  # Shared directory for inter-task data
ctx.workflow_job_id -> str
```

## Common Import Patterns

```python
# Basic task submission
from slurm import Cluster, task

# Workflow orchestration
from slurm import Cluster, task
from slurm.decorators import workflow
from slurm.workflow import WorkflowContext

# Distributed training with JobContext
from slurm import task
from slurm.runtime import JobContext

# Custom callbacks
from slurm.callbacks import BaseCallback, SubmitEndContext, CompletedContext

# Cluster from Slurmfile
from slurm import Cluster
cluster = Cluster.from_env("production")
```

## Error Types

| Error | When it occurs |
|-------|---------------|
| `SubmissionError` | `sbatch` command failed or backend rejected the job |
| `PackagingError` | Container build, wheel build, or image push failed |
| `DownloadError` | Result file not found, deserialization failed, or SSH transfer error |
| `BackendError` | SSH connection failed or backend is unreachable |
| `BackendTimeout` | Backend operation exceeded timeout |
| `BackendCommandError` | Remote command (squeue, sacct, etc.) returned an error |
| `SlurmfileError` | Slurmfile TOML is invalid or missing required fields |
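
Some of these errors, such as a timeout, may be transient, and a caller might choose to retry them. The helper below is a generic, standard-library sketch of retry with exponential backoff; `retry` is a hypothetical function, not part of the SDK, and whether a given SDK call is safe to repeat is an assumption you should verify before using it.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def retry(
    fn: Callable[[], T],
    retry_on: Tuple[Type[BaseException], ...],
    attempts: int = 3,
    base_delay: float = 1.0,
) -> T:
    """Call fn(), retrying on the given exception types with backoff."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except retry_on:
            if attempt == attempts:
                raise  # Out of attempts: surface the last error
            # Wait base_delay, 2*base_delay, 4*base_delay, ...
            time.sleep(base_delay * 2 ** (attempt - 1))
    raise RuntimeError("unreachable")


# Hypothetical usage with the SDK (assumes the call can be repeated safely):
#   from slurm import BackendTimeout
#   result = retry(lambda: job.get_result(), retry_on=(BackendTimeout,))
```

Errors such as `SlurmfileError` or `PackagingError` usually indicate a configuration or build problem, so they are left out of `retry_on` in the usage comment.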

## SBATCH Options Reference

Common options passed to `@task()`:

| Parameter | SBATCH flag | Example |
|-----------|-------------|---------|
| `time` | `--time` | `"01:00:00"` |
| `mem` | `--mem` | `"16G"` |
| `cpus_per_task` | `--cpus-per-task` | `4` |
| `nodes` | `--nodes` | `2` |
| `ntasks` | `--ntasks` | `8` |
| `ntasks_per_node` | `--ntasks-per-node` | `4` |
| `gpus_per_node` | `--gpus-per-node` | `4` |
| `gpus` | `--gpus` | `8` |
| `partition` | `--partition` | `"gpu"` |
| `account` | `--account` | `"my-project"` |
| `exclusive` | `--exclusive` | `None` (flag) |
| `job_name` | `--job-name` | `"train-v2"` |
| `array` | `--array` | `"0-99%10"` |

Python snake_case parameter names are converted to hyphenated SBATCH flags automatically (for example, `cpus_per_task` becomes `--cpus-per-task`).
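
The mapping in the table can be sketched in a few lines. This is an illustration of the naming rule only, not the SDK's code: `to_sbatch_flags` is a hypothetical function, and treating `None` as a value-less switch follows the `exclusive` row above.

```python
from typing import Any, Dict, List


def to_sbatch_flags(options: Dict[str, Any]) -> List[str]:
    """Render keyword options as sbatch-style flags (illustrative only)."""
    flags = []
    for name, value in options.items():
        flag = "--" + name.replace("_", "-")
        if value is None:
            # Value-less switch such as --exclusive
            flags.append(flag)
        else:
            flags.append(f"{flag}={value}")
    return flags


print(to_sbatch_flags({"cpus_per_task": 4, "mem": "16G", "exclusive": None}))
# ['--cpus-per-task=4', '--mem=16G', '--exclusive']
```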

## Links

- Source: https://github.com/ville-k/slurm_sdk
- Documentation: https://ville-k.github.io/slurm_sdk
- PyPI: https://pypi.org/project/slurm-sdk/
- Changelog: https://ville-k.github.io/slurm_sdk/CHANGELOG/
