Coverage for src/srunx/runner.py: 96%

141 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-21 09:02 +0000

1"""Workflow runner for executing YAML-defined workflows with SLURM""" 

2 

3import time 

4from collections import defaultdict 

5from collections.abc import Sequence 

6from concurrent.futures import ThreadPoolExecutor 

7from pathlib import Path 

8from typing import Any, Self 

9 

10import jinja2 

11import yaml 

12 

13from srunx.callbacks import Callback 

14from srunx.client import Slurm 

15from srunx.exceptions import WorkflowValidationError 

16from srunx.logging import get_logger 

17from srunx.models import ( 

18 Job, 

19 JobEnvironment, 

20 JobResource, 

21 JobStatus, 

22 RunnableJobType, 

23 ShellJob, 

24 Workflow, 

25) 

26 

27logger = get_logger(__name__) 

28 

29 

class WorkflowRunner:
    """Runner for executing workflows defined in YAML with dynamic job scheduling.

    Jobs are executed as soon as their dependencies are satisfied,
    rather than waiting for entire dependency levels to complete.
    """

    def __init__(
        self,
        workflow: Workflow,
        callbacks: Sequence[Callback] | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize workflow runner.

        Args:
            workflow: Workflow to execute.
            callbacks: List of callbacks for job notifications.
            args: Template variables from the YAML args section.
        """
        self.workflow = workflow
        self.slurm = Slurm(callbacks=callbacks)
        self.callbacks = callbacks or []
        self.args = args or {}

    @classmethod
    def from_yaml(
        cls, yaml_path: str | Path, callbacks: Sequence[Callback] | None = None
    ) -> Self:
        """Load and validate a workflow from a YAML file.

        The file may define ``name`` (defaults to ``"unnamed"``), ``args``
        (template variables) and ``jobs`` (list of job definitions). Job
        definitions are rendered as Jinja templates using ``args`` before
        being parsed with :meth:`parse_job`.

        Args:
            yaml_path: Path to the YAML workflow definition file.
            callbacks: List of callbacks for job notifications.

        Returns:
            WorkflowRunner instance with loaded workflow.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            WorkflowValidationError: If the top-level YAML value is not a
                mapping, or if template rendering fails.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty document parses to None; treat it as an empty workflow
        # instead of failing on `None.get(...)`.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise WorkflowValidationError(
                "Workflow file must contain a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )

        name = data.get("name", "unnamed")
        # Keys that are present but left empty (e.g. `jobs:`) parse to None,
        # so fall back to empty containers rather than iterating over None.
        args = data.get("args") or {}
        jobs_data = data.get("jobs") or []

        # Render Jinja templates in jobs_data using args
        rendered_jobs_data = cls._render_jobs_with_args(jobs_data, args)

        jobs = [cls.parse_job(job_data) for job_data in rendered_jobs_data]
        return cls(
            workflow=Workflow(name=name, jobs=jobs), callbacks=callbacks, args=args
        )

    @staticmethod
    def _render_jobs_with_args(
        jobs_data: list[dict[str, Any]], args: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Render Jinja templates in job data using args.

        When ``args`` is empty the job data is returned unchanged and no
        template processing takes place.

        Args:
            jobs_data: List of job configurations from YAML.
            args: Template variables from the YAML args section.

        Returns:
            List of job configurations with rendered templates.

        Raises:
            WorkflowValidationError: If the template cannot be compiled or
                rendered (including references to undefined variables).
        """
        if not args:
            return jobs_data

        # Convert jobs_data to YAML string, render as template, then parse back
        jobs_yaml = yaml.dump(jobs_data, default_flow_style=False)

        try:
            # Template compilation is inside the try block so that syntax
            # errors are wrapped the same way as rendering errors.
            template = jinja2.Template(jobs_yaml, undefined=jinja2.StrictUndefined)
            rendered_yaml = template.render(args)
        except jinja2.TemplateError as e:
            logger.error(f"Jinja template rendering failed: {e}")
            raise WorkflowValidationError(f"Template rendering failed: {e}") from e

        return yaml.safe_load(rendered_yaml)

    def get_independent_jobs(self) -> list[RunnableJobType]:
        """Get all jobs that are independent of any other job."""
        return [job for job in self.workflow.jobs if not job.depends_on]

    def run(self) -> dict[str, RunnableJobType]:
        """Run a workflow with dynamic job scheduling.

        Jobs are executed as soon as their dependencies are satisfied.
        Jobs without dependencies are submitted first; each time a job
        finishes, its dependents are checked and submitted once all of
        their dependencies have results.

        Returns:
            Dictionary mapping job names to completed Job instances.

        Raises:
            RuntimeError: If, after scheduling finishes, any job is in the
                FAILED state or has not reached COMPLETED.
            Exception: Any exception raised while executing a job is logged
                and re-raised.
        """
        logger.info(
            f"🚀 Starting Workflow {self.workflow.name} with {len(self.workflow.jobs)} jobs"
        )
        for callback in self.callbacks:
            callback.on_workflow_started(self.workflow)

        # Track all jobs and results
        all_jobs = self.workflow.jobs.copy()
        results: dict[str, RunnableJobType] = {}
        running_futures: dict[str, Any] = {}

        # Index jobs by name once instead of scanning the list on every
        # lookup. `setdefault` keeps the first job for a given name.
        jobs_by_name: dict[str, RunnableJobType] = {}
        for job in all_jobs:
            jobs_by_name.setdefault(job.name, job)

        # Build reverse dependency map for efficient lookups
        dependents: defaultdict[str, set[str]] = defaultdict(set)
        for job in all_jobs:
            for dep in job.depends_on:
                dependents[dep].add(job.name)

        def execute_job(job: RunnableJobType) -> RunnableJobType:
            """Execute a single job; exceptions propagate to the future."""
            logger.info(f"🌋 {'SUBMITTED':<12} Job {job.name:<12}")
            return self.slurm.run(job)

        def on_job_complete(job_name: str, result: RunnableJobType) -> list[str]:
            """Handle job completion and return newly ready job names."""
            results[job_name] = result
            completed_job_names = list(results)

            # Find newly ready jobs
            newly_ready = []
            for dependent_name in dependents[job_name]:
                dependent_job = jobs_by_name[dependent_name]
                if (
                    dependent_job.status == JobStatus.PENDING
                    and dependent_job.dependencies_satisfied(completed_job_names)
                ):
                    newly_ready.append(dependent_name)

            return newly_ready

        # Execute workflow with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial ready jobs
            for job in self.get_independent_jobs():
                running_futures[job.name] = executor.submit(execute_job, job)

            # Process completed jobs and schedule new ones
            while running_futures:
                # Check for completed futures (iterate over a snapshot
                # because entries are removed during the scan)
                completed = []
                for job_name, future in list(running_futures.items()):
                    if future.done():
                        completed.append((job_name, future))
                        del running_futures[job_name]

                if not completed:
                    time.sleep(0.1)  # Brief sleep to avoid busy waiting
                    continue

                # Handle completed jobs
                for job_name, future in completed:
                    try:
                        # Re-raises any exception raised inside execute_job
                        result = future.result()
                        newly_ready_names = on_job_complete(job_name, result)

                        # Schedule newly ready jobs
                        for ready_name in newly_ready_names:
                            if ready_name not in running_futures:
                                running_futures[ready_name] = executor.submit(
                                    execute_job, jobs_by_name[ready_name]
                                )

                    except Exception as e:
                        logger.error(f"❌ Job {job_name} failed: {e}")
                        raise

        # Verify all jobs completed successfully
        # NOTE(review): relies on job status being updated on the job objects
        # held by the workflow while they run — confirm against Slurm.run.
        failed_jobs = [j.name for j in all_jobs if j.status == JobStatus.FAILED]
        incomplete_jobs = [
            j.name
            for j in all_jobs
            if j.status not in [JobStatus.COMPLETED, JobStatus.FAILED]
        ]

        if failed_jobs:
            logger.error(f"❌ Jobs failed: {failed_jobs}")
            raise RuntimeError(f"Workflow execution failed: {failed_jobs}")

        if incomplete_jobs:
            logger.error(f"❌ Jobs did not complete: {incomplete_jobs}")
            raise RuntimeError(f"Workflow execution incomplete: {incomplete_jobs}")

        logger.success(f"🎉 Workflow {self.workflow.name} completed!!")

        for callback in self.callbacks:
            callback.on_workflow_completed(self.workflow)

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, RunnableJobType]:
        """Load and execute a workflow from YAML file.

        A new runner is built from the file; this instance's callbacks are
        forwarded to it so that notifications are not silently dropped.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        # `self.callbacks` is never None; pass None when it is empty so the
        # new runner is constructed exactly as it would be without callbacks.
        runner = self.from_yaml(yaml_path, callbacks=self.callbacks or None)
        return runner.run()

    @staticmethod
    def parse_job(data: dict[str, Any]) -> RunnableJobType:
        """Build a job model from a single job definition mapping.

        A definition with a truthy ``path`` becomes a ``ShellJob``; otherwise
        a ``Job`` is built from ``command`` together with the optional
        ``resources``, ``environment``, ``log_dir`` and ``work_dir`` entries.

        Args:
            data: Job definition with at least ``name`` and one of ``path``
                or ``command``; ``depends_on`` defaults to an empty list.

        Returns:
            The validated ``ShellJob`` or ``Job`` instance.

        Raises:
            WorkflowValidationError: If both ``path`` and ``command`` are set.
            KeyError: If ``name`` is missing, or if ``command`` is missing
                when no ``path`` is given.
        """
        path = data.get("path")
        if path and data.get("command"):
            raise WorkflowValidationError("Job cannot have both 'path' and 'command'")

        base = {"name": data["name"], "depends_on": data.get("depends_on", [])}

        if path:
            return ShellJob.model_validate({**base, "path": data["path"]})

        resource = JobResource.model_validate(data.get("resources", {}))
        environment = JobEnvironment.model_validate(data.get("environment", {}))

        job_data = {
            **base,
            "command": data["command"],
            "resources": resource,
            "environment": environment,
        }
        # Optional directories are only forwarded when set to a truthy value,
        # leaving the model's own defaults in place otherwise.
        for optional_key in ("log_dir", "work_dir"):
            if data.get(optional_key):
                job_data[optional_key] = data[optional_key]

        return Job.model_validate(job_data)

284 

285 

def run_workflow_from_file(yaml_path: str | Path) -> dict[str, RunnableJobType]:
    """Load the workflow described by a YAML file and execute it.

    Shorthand for building a ``WorkflowRunner`` with
    ``WorkflowRunner.from_yaml`` (no callbacks) and calling ``run`` on it.

    Args:
        yaml_path: Location of the YAML workflow definition.

    Returns:
        Mapping of job name to the job instance returned by the runner.
    """
    return WorkflowRunner.from_yaml(yaml_path).run()

296 return runner.run()