Coverage for src/srunx/workflows/runner.py: 99%
75 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
1"""Workflow runner for executing YAML-defined workflows with SLURM and Prefect."""
3from pathlib import Path
4from typing import Any
6import yaml
7from prefect import flow
8from prefect.states import State
9from srunx.logging import get_logger
10from srunx.models import (
11 Job,
12 JobEnvironment,
13 JobResource,
14 ShellJob,
15 Workflow,
16 WorkflowTask,
17)
18from srunx.workflows.tasks import submit_and_monitor_job, submit_job_async
# Module-level logger used for the workflow loading/execution messages below.
logger = get_logger(__name__)
class WorkflowRunner:
    """Runner for executing workflows defined in YAML."""

    def __init__(self) -> None:
        """Initialize workflow runner."""
        # Results of the most recent ``execute_workflow`` call, keyed by task
        # name. Also used as a memo so each task is submitted at most once.
        self.executed_tasks: dict[str, State[Job | ShellJob]] = {}

    def load_from_yaml(self, yaml_path: str | Path) -> Workflow:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.

        Returns:
            Validated Workflow object.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            ValueError: If the YAML document is not a mapping (e.g. empty file).
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        return self._parse_workflow_data(data)

    def _parse_workflow_data(self, data: dict) -> Workflow:
        """Parse workflow data from a dictionary.

        Args:
            data: Top-level YAML document. ``name`` defaults to
                ``"unnamed_workflow"``; ``tasks`` is a list of task mappings.

        Returns:
            Workflow built from the parsed tasks.

        Raises:
            ValueError: If ``data`` is not a mapping. ``yaml.safe_load``
                returns ``None`` for an empty file, which lands here.
        """
        if not isinstance(data, dict):
            raise ValueError(
                "Workflow definition must be a mapping, "
                f"got {type(data).__name__}"
            )

        workflow_name = data.get("name", "unnamed_workflow")
        # A bare ``tasks:`` key parses as None; treat it like a missing key.
        tasks_data = data.get("tasks") or []

        tasks = [self._parse_task_data(task_data) for task_data in tasks_data]
        return Workflow(name=workflow_name, tasks=tasks)

    def _parse_task_data(self, task_data: dict) -> WorkflowTask:
        """Parse a single task from a dictionary.

        A task either references a script via ``file`` (resources and
        environment are passed as ``None``) or runs ``command`` with a
        ``JobResource`` and ``JobEnvironment`` assembled from the flat task
        keys.

        Args:
            task_data: One entry of the workflow's ``tasks`` list; ``name`` is
                required.

        Returns:
            WorkflowTask wrapping the validated job.

        Raises:
            KeyError: If the task has no ``name``.
            ValidationError: If the job fields fail model validation.
        """
        # Basic task properties
        name = task_data["name"]
        file = task_data.get("file")
        async_execution = task_data.get("async", False)
        # A bare ``depends_on:`` key parses as None; normalize to "no deps".
        depends_on = task_data.get("depends_on") or []

        job_data: dict[str, Any] = {"name": name}
        # Only forward directories that were actually set instead of None.
        for key in ("log_dir", "work_dir"):
            if task_data.get(key) is not None:
                job_data[key] = task_data[key]

        if file:
            # NOTE(review): file-based tasks are still validated as ``Job``
            # (not ``ShellJob``) — confirm ``Job`` accepts ``file`` and
            # ``None`` resources/environment.
            job_data |= {
                "file": file,
                "resources": None,
                "environment": None,
            }
        else:
            # Resource configuration
            resources = JobResource(
                nodes=task_data.get("nodes", 1),
                gpus_per_node=task_data.get("gpus_per_node", 0),
                ntasks_per_node=task_data.get("ntasks_per_node", 1),
                cpus_per_task=task_data.get("cpus_per_task", 1),
                memory_per_node=task_data.get("memory_per_node"),
                time_limit=task_data.get("time_limit"),
            )

            # Environment configuration; ``container`` is accepted as an
            # alias for ``sqsh``.
            environment = JobEnvironment(
                conda=task_data.get("conda"),
                venv=task_data.get("venv"),
                sqsh=task_data.get("sqsh") or task_data.get("container"),
                # A bare ``env_vars:`` key parses as None; use an empty dict.
                env_vars=task_data.get("env_vars") or {},
            )

            job_data |= {
                "command": task_data.get("command"),
                "resources": resources,
                "environment": environment,
            }

        job = Job.model_validate(job_data)

        return WorkflowTask(
            name=name,
            job=job,
            depends_on=depends_on,
            async_execution=async_execution,
        )

    def execute_workflow(
        self, workflow: Workflow
    ) -> dict[str, State[Job | ShellJob]]:
        """Execute a workflow using Prefect.

        Tasks are submitted in declaration order, each after its dependencies.
        Every task is submitted at most once per call. ``self.executed_tasks``
        is reset first, so results from an earlier run on the same runner
        never cause same-named tasks to be skipped.

        Args:
            workflow: Workflow to execute.

        Returns:
            Dictionary mapping task names to the values returned by the
            submission tasks (``submit_job_async`` for async tasks,
            ``submit_and_monitor_job`` otherwise).

        Raises:
            KeyError: If a task depends on a name that is not in the workflow.
            ValueError: If the task dependencies contain a cycle.
        """
        task_map = {task.name: task for task in workflow.tasks}
        self.executed_tasks = {}

        @flow(name=workflow.name)
        def workflow_flow() -> dict[str, State[Job | ShellJob]]:
            """Prefect flow for workflow execution."""
            # Tasks whose dependencies are currently being resolved; seeing
            # one again means the dependency graph has a cycle.
            in_progress: set[str] = set()

            def execute_task(task_name: str) -> State[Job | ShellJob]:
                """Execute a task after its dependencies, recursively."""
                if task_name in self.executed_tasks:
                    return self.executed_tasks[task_name]
                if task_name in in_progress:
                    raise ValueError(
                        "Circular dependency detected involving task "
                        f"'{task_name}'"
                    )
                if task_name not in task_map:
                    raise KeyError(f"Unknown task in workflow: '{task_name}'")

                task = task_map[task_name]

                # Execute dependencies first
                in_progress.add(task_name)
                for dependency in task.depends_on:
                    execute_task(dependency)
                in_progress.discard(task_name)

                # Execute the task
                if task.async_execution:
                    job_future = submit_job_async(task.job)
                else:
                    job_future = submit_and_monitor_job(task.job)

                self.executed_tasks[task_name] = job_future
                return job_future

            # Execute all tasks in declaration order
            return {task.name: execute_task(task.name) for task in workflow.tasks}

        return workflow_flow()

    def execute_from_yaml(
        self, yaml_path: str | Path
    ) -> dict[str, State[Job | ShellJob]]:
        """Load and execute a workflow from YAML file.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping task names to the per-task results produced by
            ``execute_workflow``.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        workflow = self.load_from_yaml(yaml_path)

        logger.info(
            f"Executing workflow '{workflow.name}' with {len(workflow.tasks)} tasks"
        )
        results = self.execute_workflow(workflow)

        logger.info("Workflow execution completed")
        return results
def run_workflow_from_file(yaml_path: str | Path) -> dict[str, State[Job | ShellJob]]:
    """Load a YAML workflow file and execute it with a fresh runner.

    Shortcut for ``WorkflowRunner().execute_from_yaml(yaml_path)``; a new
    runner is created on every call, so no state is shared between calls.

    Args:
        yaml_path: Path to YAML workflow file.

    Returns:
        Mapping of task name to the result recorded for that task by
        ``WorkflowRunner.execute_from_yaml``.
    """
    return WorkflowRunner().execute_from_yaml(yaml_path)