Coverage for src/srunx/runner.py: 96%
141 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-21 09:02 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-21 09:02 +0000
1"""Workflow runner for executing YAML-defined workflows with SLURM"""
3import time
4from collections import defaultdict
5from collections.abc import Sequence
6from concurrent.futures import ThreadPoolExecutor
7from pathlib import Path
8from typing import Any, Self
10import jinja2
11import yaml
13from srunx.callbacks import Callback
14from srunx.client import Slurm
15from srunx.exceptions import WorkflowValidationError
16from srunx.logging import get_logger
17from srunx.models import (
18 Job,
19 JobEnvironment,
20 JobResource,
21 JobStatus,
22 RunnableJobType,
23 ShellJob,
24 Workflow,
25)
27logger = get_logger(__name__)
class WorkflowRunner:
    """Runner for executing workflows defined in YAML with dynamic job scheduling.

    Jobs are executed as soon as their dependencies are satisfied,
    rather than waiting for entire dependency levels to complete.
    """

    def __init__(
        self,
        workflow: Workflow,
        callbacks: Sequence[Callback] | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize workflow runner.

        Args:
            workflow: Workflow to execute.
            callbacks: List of callbacks for job notifications.
            args: Template variables from the YAML args section.
        """
        self.workflow = workflow
        self.slurm = Slurm(callbacks=callbacks)
        self.callbacks = callbacks or []
        self.args = args or {}

    @classmethod
    def from_yaml(
        cls, yaml_path: str | Path, callbacks: Sequence[Callback] | None = None
    ) -> Self:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.
            callbacks: List of callbacks for job notifications.

        Returns:
            WorkflowRunner instance with loaded workflow.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            WorkflowValidationError: If the top-level YAML document is not a
                mapping, or if template rendering / job parsing rejects it.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty (or comment-only) document parses to None; treat it like an
        # empty mapping instead of crashing on `None.get`.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise WorkflowValidationError(
                f"Workflow file must contain a mapping at the top level: {yaml_path}"
            )

        name = data.get("name", "unnamed")
        # `args:` / `jobs:` present but left blank parse to None as well.
        args = data.get("args") or {}
        jobs_data = data.get("jobs") or []

        # Render Jinja templates in jobs_data using args
        rendered_jobs_data = cls._render_jobs_with_args(jobs_data, args)

        jobs = [cls.parse_job(job_data) for job_data in rendered_jobs_data]
        return cls(
            workflow=Workflow(name=name, jobs=jobs), callbacks=callbacks, args=args
        )

    @staticmethod
    def _render_jobs_with_args(
        jobs_data: list[dict[str, Any]], args: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Render Jinja templates in job data using args.

        Args:
            jobs_data: List of job configurations from YAML.
            args: Template variables from the YAML args section.

        Returns:
            List of job configurations with rendered templates. When ``args``
            is empty the input is returned unchanged (no rendering happens).

        Raises:
            WorkflowValidationError: If Jinja fails to render the template,
                e.g. because a referenced variable is undefined.
        """
        if not args:
            return jobs_data

        # Convert jobs_data to YAML string, render as template, then parse back
        jobs_yaml = yaml.dump(jobs_data, default_flow_style=False)
        template = jinja2.Template(jobs_yaml, undefined=jinja2.StrictUndefined)

        try:
            rendered_yaml = template.render(args)
            return yaml.safe_load(rendered_yaml)
        except jinja2.TemplateError as e:
            logger.error(f"Jinja template rendering failed: {e}")
            raise WorkflowValidationError(f"Template rendering failed: {e}") from e

    def get_independent_jobs(self) -> list[RunnableJobType]:
        """Get all jobs that are independent of any other job."""
        return [job for job in self.workflow.jobs if not job.depends_on]

    def run(self) -> dict[str, RunnableJobType]:
        """Run a workflow with dynamic job scheduling.

        Jobs are executed as soon as their dependencies are satisfied.

        Returns:
            Dictionary mapping job names to completed Job instances.

        Raises:
            RuntimeError: If, after scheduling finishes, any job is FAILED or
                is neither COMPLETED nor FAILED.
            Exception: Any exception raised while executing a job is logged
                and re-raised.
        """
        logger.info(
            f"🚀 Starting Workflow {self.workflow.name} with {len(self.workflow.jobs)} jobs"
        )
        for callback in self.callbacks:
            callback.on_workflow_started(self.workflow)

        # Track all jobs and results
        all_jobs = self.workflow.jobs.copy()
        results: dict[str, RunnableJobType] = {}
        running_futures: dict[str, Any] = {}

        # Name -> job index for O(1) lookups. `setdefault` keeps the first job
        # for a given name, matching a first-match linear scan.
        jobs_by_name: dict[str, RunnableJobType] = {}
        for job in all_jobs:
            jobs_by_name.setdefault(job.name, job)

        # Build reverse dependency map for efficient lookups
        dependents: defaultdict[str, set[str]] = defaultdict(set)
        for job in all_jobs:
            for dep in job.depends_on:
                dependents[dep].add(job.name)

        def execute_job(job: RunnableJobType) -> RunnableJobType:
            """Execute a single job."""
            logger.info(f"🌋 {'SUBMITTED':<12} Job {job.name:<12}")
            return self.slurm.run(job)

        def on_job_complete(job_name: str, result: RunnableJobType) -> list[str]:
            """Handle job completion and return newly ready job names."""
            results[job_name] = result
            completed_job_names = list(results)

            # Find newly ready jobs
            newly_ready = []
            for dependent_name in dependents[job_name]:
                dependent_job = jobs_by_name[dependent_name]
                if (
                    dependent_job.status == JobStatus.PENDING
                    and dependent_job.dependencies_satisfied(completed_job_names)
                ):
                    newly_ready.append(dependent_name)

            return newly_ready

        # Execute workflow with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial ready jobs
            for job in self.get_independent_jobs():
                running_futures[job.name] = executor.submit(execute_job, job)

            # Process completed jobs and schedule new ones
            while running_futures:
                # Check for completed futures (snapshot the items because
                # entries are removed while scanning)
                completed = []
                for job_name, future in list(running_futures.items()):
                    if future.done():
                        completed.append((job_name, future))
                        del running_futures[job_name]

                if not completed:
                    time.sleep(0.1)  # Brief sleep to avoid busy waiting
                    continue

                # Handle completed jobs
                for job_name, future in completed:
                    try:
                        result = future.result()
                        newly_ready_names = on_job_complete(job_name, result)

                        # Schedule newly ready jobs
                        for ready_name in newly_ready_names:
                            if ready_name not in running_futures:
                                running_futures[ready_name] = executor.submit(
                                    execute_job, jobs_by_name[ready_name]
                                )

                    except Exception as e:
                        logger.error(f"❌ Job {job_name} failed: {e}")
                        raise

        # Verify all jobs completed successfully
        failed_jobs = [j.name for j in all_jobs if j.status == JobStatus.FAILED]
        incomplete_jobs = [
            j.name
            for j in all_jobs
            if j.status not in [JobStatus.COMPLETED, JobStatus.FAILED]
        ]

        if failed_jobs:
            logger.error(f"❌ Jobs failed: {failed_jobs}")
            raise RuntimeError(f"Workflow execution failed: {failed_jobs}")

        if incomplete_jobs:
            logger.error(f"❌ Jobs did not complete: {incomplete_jobs}")
            raise RuntimeError(f"Workflow execution incomplete: {incomplete_jobs}")

        logger.success(f"🎉 Workflow {self.workflow.name} completed!!")

        for callback in self.callbacks:
            callback.on_workflow_completed(self.workflow)

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, RunnableJobType]:
        """Load and execute a workflow from YAML file.

        The workflow is run by a new runner built from the file; this
        runner's callbacks are forwarded to it so notifications are not lost.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        runner = self.from_yaml(yaml_path, callbacks=self.callbacks or None)
        return runner.run()

    @staticmethod
    def parse_job(data: dict[str, Any]) -> RunnableJobType:
        """Build a job model from one entry of the YAML ``jobs`` list.

        An entry with a truthy ``path`` becomes a ``ShellJob``; otherwise a
        ``Job`` is built from ``command`` plus the optional ``resources``,
        ``environment``, ``log_dir`` and ``work_dir`` keys.

        Args:
            data: Job configuration mapping. ``name`` is required;
                ``depends_on`` defaults to an empty list.

        Returns:
            The validated ``ShellJob`` or ``Job`` instance.

        Raises:
            WorkflowValidationError: If both ``path`` and ``command`` are set.
            KeyError: If ``name`` is missing, or if ``path`` is not set and
                the ``command`` key is absent.
        """
        if data.get("path") and data.get("command"):
            raise WorkflowValidationError("Job cannot have both 'path' and 'command'")

        base = {"name": data["name"], "depends_on": data.get("depends_on", [])}

        if data.get("path"):
            return ShellJob.model_validate({**base, "path": data["path"]})

        resource = JobResource.model_validate(data.get("resources", {}))
        environment = JobEnvironment.model_validate(data.get("environment", {}))

        job_data = {
            **base,
            "command": data["command"],
            "resources": resource,
            "environment": environment,
        }
        # Only pass the directories through when set, so the model's own
        # defaults apply otherwise.
        if data.get("log_dir"):
            job_data["log_dir"] = data["log_dir"]
        if data.get("work_dir"):
            job_data["work_dir"] = data["work_dir"]

        return Job.model_validate(job_data)
def run_workflow_from_file(yaml_path: str | Path) -> dict[str, RunnableJobType]:
    """Load a workflow definition from a YAML file and execute it.

    Shorthand for building a ``WorkflowRunner`` via ``from_yaml`` (without
    callbacks) and immediately calling ``run`` on it.

    Args:
        yaml_path: Location of the YAML workflow file.

    Returns:
        Mapping of job name to the job instance returned by the run.
    """
    return WorkflowRunner.from_yaml(yaml_path).run()