Coverage for src/srunx/models.py: 82%

230 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-13 23:46 +0000

1"""Data models for SLURM job management.""" 

2 

3import os 

4import subprocess 

5import time 

6from enum import Enum 

7from pathlib import Path 

8from typing import Self 

9 

10import jinja2 

11from pydantic import BaseModel, Field, PrivateAttr, model_validator 

12 

13from srunx.exceptions import WorkflowValidationError 

14from srunx.logging import get_logger 

15 

# Module-level logger, named after this module, used by the models below.
logger = get_logger(__name__)

17 

18 

def _get_config_defaults():
    """Return the srunx configuration object, or ``None`` if it cannot be imported.

    The import happens inside the function so that loading this module never
    triggers a circular import with ``srunx.config``.
    """
    try:
        from srunx.config import get_config as load_config

        settings = load_config()
    except ImportError:
        # Config module unavailable: callers fall back to built-in defaults.
        settings = None
    return settings

28 

29 

def _default_nodes():
    """Return the configured node count, or 1 when no config is available."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.resources.nodes
    return 1

34 

35 

def _default_gpus_per_node():
    """Return the configured GPUs-per-node value, or 0 when no config is available."""
    settings = _get_config_defaults()
    if not settings:
        return 0
    return settings.resources.gpus_per_node

40 

41 

def _default_ntasks_per_node():
    """Return the configured tasks-per-node value, or 1 when no config is available."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.resources.ntasks_per_node
    return 1

46 

47 

def _default_cpus_per_task():
    """Return the configured CPUs-per-task value, or 1 when no config is available."""
    settings = _get_config_defaults()
    if not settings:
        return 1
    return settings.resources.cpus_per_task

52 

53 

def _default_memory_per_node():
    """Return the configured memory-per-node value, or ``None`` without a config."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.resources.memory_per_node
    return None

58 

59 

def _default_time_limit():
    """Return the configured time limit, or ``None`` without a config."""
    settings = _get_config_defaults()
    if not settings:
        return None
    return settings.resources.time_limit

64 

65 

def _default_nodelist():
    """Return the configured node list, or ``None`` without a config."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.resources.nodelist
    return None

70 

71 

def _default_partition():
    """Return the configured partition, or ``None`` without a config."""
    settings = _get_config_defaults()
    if not settings:
        return None
    return settings.resources.partition

76 

77 

def _default_conda():
    """Return the configured conda environment, or ``None`` without a config."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.environment.conda
    return None

82 

83 

def _default_venv():
    """Return the configured virtualenv path, or ``None`` without a config."""
    settings = _get_config_defaults()
    if not settings:
        return None
    return settings.environment.venv

88 

89 

def _default_sqsh():
    """Return the configured sqsh image path, or ``None`` without a config."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.environment.sqsh
    return None

94 

95 

def _default_env_vars():
    """Return a fresh copy of the configured environment variables.

    Returns:
        A new dict on every call: the config's ``environment.env_vars`` when a
        config is available, otherwise an empty dict.
    """
    config = _get_config_defaults()
    if not config:
        return {}
    # Copy so that mutating one model's ``env_vars`` cannot leak into the
    # shared config object (and from there into every later default).
    return dict(config.environment.env_vars or {})

100 

101 

def _default_log_dir():
    """Return the log directory from config, else ``$SLURM_LOG_DIR``, else ``"logs"``."""
    cfg = _get_config_defaults()
    if cfg:
        return cfg.log_dir
    # The environment is only consulted when no config could be loaded.
    return os.getenv("SLURM_LOG_DIR", "logs")

106 

107 

def _default_work_dir():
    """Return the working directory from config, or ``None`` without a config."""
    settings = _get_config_defaults()
    if not settings:
        return None
    return settings.work_dir

112 

113 

class JobStatus(Enum):
    """Status values shared by SLURM jobs and workflow jobs.

    Every member's value is the member's own name as an upper-case string.
    """

    # Fallback used when no state could be obtained for a job.
    UNKNOWN = "UNKNOWN"

    # States in which a job may still change.
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # States after which ``BaseJob.status`` no longer re-queries SLURM.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"

124 

125 

class JobResource(BaseModel):
    """SLURM resource allocation requirements."""

    # --- Counts: each default comes from the config helpers above. ---
    nodes: int = Field(
        description="Number of compute nodes",
        default_factory=_default_nodes,
        ge=1,
    )
    gpus_per_node: int = Field(
        description="Number of GPUs per node",
        default_factory=_default_gpus_per_node,
        ge=0,
    )
    ntasks_per_node: int = Field(
        description="Number of jobs per node",
        default_factory=_default_ntasks_per_node,
        ge=1,
    )
    cpus_per_task: int = Field(
        description="Number of CPUs per task",
        default_factory=_default_cpus_per_task,
        ge=1,
    )

    # --- Optional settings: ``None`` means "not specified". ---
    memory_per_node: str | None = Field(
        description="Memory per node (e.g., '32GB')",
        default_factory=_default_memory_per_node,
    )
    time_limit: str | None = Field(
        description="Time limit (e.g., '1:00:00')",
        default_factory=_default_time_limit,
    )
    nodelist: str | None = Field(
        description="Specific nodes to use (e.g., 'node001,node002')",
        default_factory=_default_nodelist,
    )
    partition: str | None = Field(
        description="SLURM partition to use (e.g., 'gpu', 'cpu')",
        default_factory=_default_partition,
    )

162 

163 

class JobEnvironment(BaseModel):
    """Job environment configuration."""

    # Activation targets; the validator below requires exactly one to be set.
    conda: str | None = Field(
        description="Conda environment name",
        default_factory=_default_conda,
    )
    venv: str | None = Field(
        description="Virtual environment path",
        default_factory=_default_venv,
    )
    sqsh: str | None = Field(
        description="SquashFS image path",
        default_factory=_default_sqsh,
    )

    # Extra variables for the job, independent of the activation target.
    env_vars: dict[str, str] = Field(
        description="Environment variables",
        default_factory=_default_env_vars,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject models that do not select exactly one of conda, venv or sqsh."""
        selected = [
            target
            for target in (self.conda, self.venv, self.sqsh)
            if target is not None
        ]
        if len(selected) == 1:
            return self
        raise ValueError("Exactly one of 'conda', 'venv', or 'sqsh' must be set")

187 

188 

class BaseJob(BaseModel):
    # Common base for all job kinds: a name, an optional SLURM job ID, the
    # names of jobs it depends on, and a status refreshed on demand via sacct.
    # (Written as a comment, not a docstring, so the model's JSON schema is
    # unchanged.)

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    # Last known status; a private attribute, so it is not a model field.
    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """Return the job status, refreshing it from SLURM when it may have changed.

        A refresh is only attempted when a ``job_id`` is set and the cached
        status is not one of COMPLETED, FAILED, CANCELLED or TIMEOUT.
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        """Overwrite the cached status without querying SLURM."""
        self._status = value

    def refresh(self, retries: int = 3) -> Self:
        """Query ``sacct`` and update the cached status in place.

        Args:
            retries: Maximum number of ``sacct`` invocations. When a call
                returns no output, the next attempt follows a one-second
                pause.

        Returns:
            This job, to allow chaining.

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits with a non-zero
                status (logged, then re-raised).
        """
        if self.job_id is None:
            return self

        # Pre-bound so that ``retries <= 0`` falls through to UNKNOWN instead
        # of hitting an unbound local below.
        line = ""
        for attempt in range(retries):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            output = result.stdout.strip()
            # Only the first record of the output is inspected.
            line = output.split("\n")[0] if output else ""
            if line:
                break
            if attempt < retries - 1:
                time.sleep(1)

        if not line:
            self._status = JobStatus.UNKNOWN
            return self

        # Record layout is "JobID|State"; partition() tolerates a missing "|".
        _, _, state = line.partition("|")
        # sacct can decorate the state (e.g. "CANCELLED by 1000"); the first
        # word is the state name.
        words = state.split()
        state_name = words[0] if words else ""
        try:
            self._status = JobStatus(state_name)
        except ValueError:
            # A state this enum does not model must not crash a status read.
            logger.warning(
                f"Unrecognized SLURM state '{state}' for job {self.job_id}; "
                "treating as UNKNOWN"
            )
            self._status = JobStatus.UNKNOWN
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """Return True if this job is still pending and every dependency is completed.

        Args:
            completed_job_names: Names of jobs that have already completed.
        """
        return self.status == JobStatus.PENDING and all(
            dep in completed_job_names for dep in self.depends_on
        )

260 

261 

class Job(BaseJob):
    """Represents a SLURM job with complete configuration."""

    # Required: the command, as a list of arguments.
    command: list[str] = Field(description="Command to execute")

    # Nested models, each built from its own defaults when omitted.
    resources: JobResource = Field(
        description="Resource requirements",
        default_factory=JobResource,
    )
    environment: JobEnvironment = Field(
        description="Environment setup",
        default_factory=JobEnvironment,
    )

    # Filesystem locations.
    log_dir: str = Field(
        description="Directory for log files",
        default_factory=_default_log_dir,
    )
    work_dir: str = Field(
        description="Working directory",
        # A falsy configured value falls back to the current directory.
        default_factory=lambda: _default_work_dir() or os.getcwd(),
    )

280 

281 

class ShellJob(BaseJob):
    # A job defined by a shell script path instead of a command list.
    path: str = Field(
        description="Shell script path to execute",
    )

284 

285 

286type JobType = BaseJob | Job | ShellJob 

287type RunnableJobType = Job | ShellJob 

288 

289 

class Workflow:
    """Represents a workflow containing multiple jobs with dependencies.

    Attributes:
        name: Workflow name.
        jobs: Jobs in insertion order. A list passed to the constructor is
            stored as-is, not copied.
    """

    # Annotations that mention ``RunnableJobType`` are quoted so defining the
    # class does not require the alias to be evaluated.
    def __init__(
        self, name: str, jobs: "list[RunnableJobType] | None" = None
    ) -> None:
        """Create a workflow, optionally seeded with ``jobs``."""
        self.name = name
        self.jobs = [] if jobs is None else jobs

    def add(self, job: "RunnableJobType") -> None:
        """Append ``job`` after checking that its dependencies are already present.

        Raises:
            WorkflowValidationError: If ``job`` depends on a name that no job
                in the workflow has.
        """
        # Dependencies are job *names*, so compare against names rather than
        # against the job objects held in ``self.jobs``.
        known_names = {existing.name for existing in self.jobs}
        for dep in job.depends_on:
            if dep not in known_names:
                raise WorkflowValidationError(
                    f"Job '{job.name}' depends on unknown job '{dep}'"
                )
        self.jobs.append(job)

    def remove(self, job: "RunnableJobType") -> None:
        """Remove ``job`` from the workflow (``ValueError`` if it is absent)."""
        self.jobs.remove(job)

    def get(self, name: str) -> "RunnableJobType | None":
        """Return the first job called ``name`` after refreshing it, or ``None``."""
        for job in self.jobs:
            if job.name == name:
                return job.refresh()
        return None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Return the dependency names of ``job_name``, or ``[]`` if it is unknown."""
        job = self.get(job_name)
        return job.depends_on if job else []

    def show(self):
        """Print a plain-text plan of the workflow and its jobs to stdout."""
        msg = f"""\
{" PLAN ":=^80}
Workflow: {self.name}
Jobs: {len(self.jobs)}
"""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        for job in self.jobs:
            msg += add_indent(1, f"Job: {job.name}\n")
            if isinstance(job, Job):
                msg += add_indent(
                    2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                )
                msg += add_indent(
                    2,
                    f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                )
                if job.environment.conda:
                    msg += add_indent(
                        2, f"{'Conda env:': <13} {job.environment.conda}\n"
                    )
                if job.environment.sqsh:
                    msg += add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                if job.environment.venv:
                    msg += add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
            elif isinstance(job, ShellJob):
                msg += add_indent(2, f"{'Path:': <13} {job.path}\n")
            if job.depends_on:
                msg += add_indent(
                    2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                )

        msg += f"{'=' * 80}\n"
        print(msg)

    def validate(self):
        """Validate workflow job dependencies.

        Checks that job names are unique, that every dependency names a job
        in the workflow, and that the dependency graph has no cycle.

        Raises:
            WorkflowValidationError: On the first violation found.
        """
        job_names = {job.name for job in self.jobs}

        if len(job_names) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in job_names:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Read the graph from a plain mapping: ``self.get`` refreshes each job,
        # which a structural check should not trigger.
        dependencies_of = {job.name: job.depends_on for job in self.jobs}
        visited: set[str] = set()
        rec_stack: set[str] = set()  # names on the current DFS path

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            for dependency in dependencies_of.get(job_name, ()):
                if has_cycle(dependency):
                    return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )

404 

405 

def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render a SLURM job script from a template.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
            It is created (including parents) if it does not exist yet.
        verbose: Whether to print the rendered content.

    Returns:
        Path to the generated SLURM batch script, ``<output_dir>/<job.name>.slurm``.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    # StrictUndefined turns a variable missing from the context into an error.
    template = jinja2.Template(
        template_file.read_text(encoding="utf-8"),
        undefined=jinja2.StrictUndefined,
    )

    # Resource fields are spread in last, at the top level of the context.
    template_vars = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    rendered_content = template.render(template_vars)

    if verbose:
        print(rendered_content)

    # Create the target directory first so a fresh output location does not
    # make the write fail.
    output_directory = Path(output_dir)
    output_directory.mkdir(parents=True, exist_ok=True)

    output_path = output_directory / f"{job.name}.slurm"
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)

457 

458 

def _build_environment_setup(environment: JobEnvironment) -> str:
    """Assemble the shell lines that prepare a job's runtime environment.

    One ``export`` line per environment variable comes first, followed by the
    activation lines for the first of conda, venv or sqsh that is set.
    Lines are joined with newlines.
    """
    script = [f"export {key}={value}" for key, value in environment.env_vars.items()]

    if environment.conda:
        # Activation goes through the miniconda3 install in the home directory.
        script += [
            f"source {str(Path.home())}/miniconda3/bin/activate",
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        script.append(f"source {environment.venv}/bin/activate")
    elif environment.sqsh:
        # ``: "${IMAGE:=...}"`` only assigns IMAGE when it is unset or empty.
        script += [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]

    return "\n".join(script)