Coverage for src/srunx/models.py: 82%
230 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-13 23:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-13 23:46 +0000
1"""Data models for SLURM job management."""
3import os
4import subprocess
5import time
6from enum import Enum
7from pathlib import Path
8from typing import Self
10import jinja2
11from pydantic import BaseModel, Field, PrivateAttr, model_validator
13from srunx.exceptions import WorkflowValidationError
14from srunx.logging import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
def _get_config_defaults():
    """Load the srunx configuration, or return ``None`` when it cannot be imported.

    The import happens inside the function (lazily) to avoid circular
    dependencies between this module and ``srunx.config``.
    """
    try:
        from srunx.config import get_config as load_config

        loaded = load_config()
    except ImportError:
        # Config module is not available: signal "no configuration".
        loaded = None
    return loaded
def _default_nodes():
    """Return the configured node count, or 1 when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.nodes
def _default_gpus_per_node():
    """Return the configured GPUs per node, or 0 when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return 0
    return cfg.resources.gpus_per_node
def _default_ntasks_per_node():
    """Return the configured tasks per node, or 1 when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.ntasks_per_node
def _default_cpus_per_task():
    """Return the configured CPUs per task, or 1 when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.cpus_per_task
def _default_memory_per_node():
    """Return the configured memory per node, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.memory_per_node
def _default_time_limit():
    """Return the configured time limit, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.time_limit
def _default_nodelist():
    """Return the configured nodelist, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.nodelist
def _default_partition():
    """Return the configured partition, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.partition
def _default_conda():
    """Return the configured conda environment, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.conda
def _default_venv():
    """Return the configured venv path, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.venv
def _default_sqsh():
    """Return the configured sqsh path, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.sqsh
def _default_env_vars():
    """Return a fresh copy of the configured environment variables.

    Returns an empty dict when no config is available. A copy is returned
    rather than the config's own mapping so that mutating one model's
    ``env_vars`` cannot leak into the configuration object or into other
    models built from the same default.
    """
    config = _get_config_defaults()
    if not config:
        return {}
    return dict(config.environment.env_vars or {})
def _default_log_dir():
    """Return the configured log directory.

    Without a config, the ``SLURM_LOG_DIR`` environment variable is used,
    and ``"logs"`` when that variable is not set.
    """
    cfg = _get_config_defaults()
    if cfg:
        return cfg.log_dir
    return os.getenv("SLURM_LOG_DIR", "logs")
def _default_work_dir():
    """Return the configured work directory, or ``None`` when no config is available."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.work_dir
class JobStatus(Enum):
    """States a job can be in; shared by SLURM jobs and workflow jobs.

    Each member's value is the upper-case state name as a string.
    """

    # No state information could be obtained.
    UNKNOWN = "UNKNOWN"

    # States the status property keeps refreshing.
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # States treated as terminal by ``BaseJob.status``.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"
class JobResource(BaseModel):
    """SLURM resource allocation requirements.

    Every field that is not given explicitly is filled in by the matching
    ``_default_*`` factory, which reads the srunx configuration when it is
    available and falls back to a built-in value otherwise.
    """

    nodes: int = Field(
        default_factory=_default_nodes, ge=1, description="Number of compute nodes"
    )
    gpus_per_node: int = Field(
        default_factory=_default_gpus_per_node,
        ge=0,
        description="Number of GPUs per node",
    )
    # Description previously said "jobs per node"; the field counts tasks.
    ntasks_per_node: int = Field(
        default_factory=_default_ntasks_per_node,
        ge=1,
        description="Number of tasks per node",
    )
    cpus_per_task: int = Field(
        default_factory=_default_cpus_per_task,
        ge=1,
        description="Number of CPUs per task",
    )
    memory_per_node: str | None = Field(
        default_factory=_default_memory_per_node,
        description="Memory per node (e.g., '32GB')",
    )
    time_limit: str | None = Field(
        default_factory=_default_time_limit, description="Time limit (e.g., '1:00:00')"
    )
    nodelist: str | None = Field(
        default_factory=_default_nodelist,
        description="Specific nodes to use (e.g., 'node001,node002')",
    )
    partition: str | None = Field(
        default_factory=_default_partition,
        description="SLURM partition to use (e.g., 'gpu', 'cpu')",
    )
class JobEnvironment(BaseModel):
    """Environment a job runs in: one activation backend plus exported variables."""

    conda: str | None = Field(
        description="Conda environment name", default_factory=_default_conda
    )
    venv: str | None = Field(
        description="Virtual environment path", default_factory=_default_venv
    )
    sqsh: str | None = Field(
        description="SquashFS image path", default_factory=_default_sqsh
    )
    env_vars: dict[str, str] = Field(
        description="Environment variables", default_factory=_default_env_vars
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject the model unless exactly one of conda/venv/sqsh is not ``None``."""
        configured = [
            backend
            for backend in (self.conda, self.venv, self.sqsh)
            if backend is not None
        ]
        if len(configured) == 1:
            return self
        raise ValueError("Exactly one of 'conda', 'venv', or 'sqsh' must be set")
class BaseJob(BaseModel):
    """Fields and status handling shared by every job kind.

    The last known state is cached in the private ``_status`` attribute.
    Reading ``status`` re-queries ``sacct`` for jobs that have a ``job_id``
    and have not yet reached a terminal state.
    """

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """
        Accessing ``job.status`` always triggers a lightweight refresh
        (only if we have a ``job_id`` and the status isn't terminal).
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        self._status = value

    def refresh(self, retries: int = 3) -> Self:
        """Query sacct and update ``_status`` in-place.

        Args:
            retries: Maximum number of ``sacct`` queries to make while the
                output is empty; one second is slept between attempts.
                Values below 1 are treated as 1.

        Returns:
            This job, with ``_status`` updated (unchanged if ``job_id`` is None).

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits with an error.
            ValueError: If the reported state is not a ``JobStatus`` value.
        """
        if self.job_id is None:
            return self

        # Always query at least once so ``line`` is defined even for retries <= 0.
        attempts = max(1, retries)
        line = ""
        for attempt in range(attempts):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            output = result.stdout.strip()
            # Only the first record is used.
            line = output.split("\n")[0] if output else ""
            if line:
                break
            if attempt < attempts - 1:
                time.sleep(1)

        if not line:
            # No output after all attempts: state cannot be determined.
            self._status = JobStatus.UNKNOWN
            return self

        # ``partition`` tolerates a record without the "|" separator.
        _, _, raw_state = line.partition("|")
        # Keep only the first word so that states with a suffix such as
        # "CANCELLED by <uid>" still map onto a JobStatus member.
        state_words = raw_state.split()
        if not state_words:
            self._status = JobStatus.UNKNOWN
            return self
        self._status = JobStatus(state_words[0])
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """All dependencies are completed & this job is still pending."""
        return self.status == JobStatus.PENDING and all(
            dep in completed_job_names for dep in self.depends_on
        )
class Job(BaseJob):
    """A command-based SLURM job together with its resources and environment."""

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        description="Resource requirements", default_factory=JobResource
    )
    environment: JobEnvironment = Field(
        description="Environment setup", default_factory=JobEnvironment
    )
    log_dir: str = Field(
        description="Directory for log files",
        default_factory=_default_log_dir,
    )
    # Falls back to the current working directory when no work dir is configured.
    work_dir: str = Field(
        description="Working directory",
        default_factory=lambda: _default_work_dir() or os.getcwd(),
    )
class ShellJob(BaseJob):
    """A job defined by the path of a shell script to execute."""

    path: str = Field(description="Shell script path to execute")
# Any job model defined in this module, including the shared base class.
type JobType = BaseJob | Job | ShellJob
# The job kinds a Workflow accepts (the base class is excluded).
type RunnableJobType = Job | ShellJob
class Workflow:
    """Represents a workflow containing multiple jobs with dependencies."""

    def __init__(self, name: str, jobs: list[RunnableJobType] | None = None) -> None:
        """Create a workflow.

        Args:
            name: Workflow name.
            jobs: Initial jobs; an empty list is used when omitted.
        """
        if jobs is None:
            jobs = []

        self.name = name
        self.jobs = jobs

    def add(self, job: RunnableJobType) -> None:
        """Append ``job`` after checking that its dependencies are already present.

        Raises:
            WorkflowValidationError: If ``job`` depends on a job name that is
                not in the workflow yet.
        """
        # Dependencies are job *names*, so compare against names rather than
        # against the job objects themselves.
        known_names = {existing.name for existing in self.jobs}
        for dep in job.depends_on or ():
            if dep not in known_names:
                raise WorkflowValidationError(
                    f"Job '{job.name}' depends on unknown job '{dep}'"
                )
        self.jobs.append(job)

    def remove(self, job: RunnableJobType) -> None:
        """Remove ``job`` from the workflow (``ValueError`` if it is absent)."""
        self.jobs.remove(job)

    def get(self, name: str) -> RunnableJobType | None:
        """Get a job by name.

        The matching job is returned via ``job.refresh()``, so its status is
        refreshed as a side effect. Returns ``None`` when no job matches.
        """
        for job in self.jobs:
            if job.name == name:
                return job.refresh()
        return None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get dependencies for a specific job."""
        job = self.get(job_name)
        return job.depends_on if job else []

    def show(self):
        """Print a human-readable plan of the workflow and its jobs."""
        msg = f"""\
{" PLAN ":=^80}
Workflow: {self.name}
Jobs: {len(self.jobs)}
"""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        for job in self.jobs:
            msg += add_indent(1, f"Job: {job.name}\n")
            if isinstance(job, Job):
                msg += add_indent(
                    2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                )
                msg += add_indent(
                    2,
                    f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                )
                if job.environment.conda:
                    msg += add_indent(
                        2, f"{'Conda env:': <13} {job.environment.conda}\n"
                    )
                if job.environment.sqsh:
                    msg += add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                if job.environment.venv:
                    msg += add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
            elif isinstance(job, ShellJob):
                msg += add_indent(2, f"{'Path:': <13} {job.path}\n")
            if job.depends_on:
                msg += add_indent(
                    2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                )

        msg += f"{'=' * 80}\n"
        print(msg)

    def validate(self):
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: On duplicate job names, a dependency on
                an unknown job, or a circular dependency.
        """
        # Name index used for all lookups below; this keeps validation purely
        # structural (no status refresh per lookup).
        jobs_by_name = {job.name: job for job in self.jobs}

        if len(jobs_by_name) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in jobs_by_name:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Depth-first search: a name met again while still on the recursion
        # stack means the dependency graph has a cycle.
        visited = set()
        rec_stack = set()

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            for dependency in jobs_by_name[job_name].depends_on:
                if has_cycle(dependency):
                    return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )
def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render a SLURM job script from a template.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
            It is created (including parents) when it does not exist.
        verbose: Whether to print the rendered content.

    Returns:
        Path to the generated SLURM batch script.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    template_content = template_file.read_text(encoding="utf-8")

    # StrictUndefined makes a missing template variable an error instead of
    # silently rendering as empty text.
    template = jinja2.Template(template_content, undefined=jinja2.StrictUndefined)

    # Prepare template variables
    template_vars = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    rendered_content = template.render(template_vars)

    if verbose:
        print(rendered_content)

    # Generate output file; make sure the target directory exists first so a
    # fresh output location does not fail the write.
    output_directory = Path(output_dir)
    output_directory.mkdir(parents=True, exist_ok=True)
    output_path = output_directory / f"{job.name}.slurm"
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)
def _build_environment_setup(environment: JobEnvironment) -> str:
    """Return the shell lines that prepare ``environment``, joined by newlines.

    Variable exports come first, followed by the activation lines of the
    first backend that is set (conda, then venv, then sqsh).
    """
    exports = [f"export {name}={val}" for name, val in environment.env_vars.items()]

    if environment.conda:
        activation = [
            f"source {Path.home()}/miniconda3/bin/activate",
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        activation = [f"source {environment.venv}/bin/activate"]
    elif environment.sqsh:
        activation = [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]
    else:
        activation = []

    return "\n".join(exports + activation)