Coverage for src/srunx/models.py: 79%
113 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
1"""Data models for SLURM job management."""
3import os
4import subprocess
5import time
6from enum import Enum
7from pathlib import Path
8from typing import Self
10import jinja2
11from pydantic import BaseModel, Field, model_validator
12from srunx.logging import get_logger
14logger = get_logger(__name__)
class JobStatus(Enum):
    """SLURM job status enumeration.

    Members are looked up by their string value. Lookup is tolerant of
    decorated state strings: only the first whitespace-separated word is
    considered and a trailing ``+`` is ignored, so ``"CANCELLED by 1000"``
    and ``"CANCELLED+"`` both resolve to ``JobStatus.CANCELLED``. Values that
    still do not match a member raise ``ValueError`` as usual.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"

    @classmethod
    def _missing_(cls, value: object) -> "JobStatus | None":
        """Resolve a decorated state string to its member.

        Called by ``Enum`` only when the exact value lookup fails. Returning
        ``None`` lets ``Enum`` raise its standard ``ValueError``.
        """
        if not isinstance(value, str):
            return None
        words = value.split()
        if not words:
            return None
        # Keep only the leading state word and drop a trailing "+" marker.
        token = words[0].rstrip("+")
        for member in cls:
            if member.value == token:
                return member
        return None
class JobResource(BaseModel):
    """Resources requested from SLURM for one job.

    Every count has a lower bound enforced by pydantic; memory and time
    limit are optional free-form strings.
    """

    # Node and accelerator counts.
    nodes: int = Field(
        description="Number of compute nodes",
        ge=1,
        default=1,
    )
    gpus_per_node: int = Field(
        description="Number of GPUs per node",
        ge=0,
        default=0,
    )

    # Task / CPU layout.
    ntasks_per_node: int = Field(
        description="Number of tasks per node",
        ge=1,
        default=1,
    )
    cpus_per_task: int = Field(
        description="Number of CPUs per task",
        ge=1,
        default=1,
    )

    # Optional limits; None means the field is left unset.
    memory_per_node: str | None = Field(
        description="Memory per node (e.g., '32GB')",
        default=None,
    )
    time_limit: str | None = Field(
        description="Time limit (e.g., '1:00:00')",
        default=None,
    )
class JobEnvironment(BaseModel):
    """Runtime environment a job is executed in.

    Exactly one of ``conda``, ``venv`` or ``sqsh`` has to be provided;
    ``env_vars`` may be combined with any of them.
    """

    conda: str | None = Field(description="Conda environment name", default=None)
    venv: str | None = Field(description="Virtual environment path", default=None)
    sqsh: str | None = Field(description="SquashFS image path", default=None)
    env_vars: dict[str, str] = Field(
        description="Environment variables",
        default_factory=dict,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject configurations that select zero or several environments."""
        selected = [
            option
            for option in (self.conda, self.venv, self.sqsh)
            if option is not None
        ]
        if len(selected) == 1:
            return self
        raise ValueError("Exactly one of 'conda', 'venv', or 'sqsh' must be set")
class BaseJob(BaseModel):
    """Fields and status polling shared by every job type."""

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    status: JobStatus | None = Field(default=None, description="Current job status")
    dependencies: list[str] = Field(
        default_factory=list, description="Job dependencies"
    )

    def refresh(self, retries: int = 3) -> None:
        """Query ``sacct`` for this job and store the result in ``status``.

        Args:
            retries: Maximum number of ``sacct`` queries made while the
                command returns no rows, with a one second pause between
                queries. Values below 1 are treated as 1.

        Raises:
            ValueError: If ``job_id`` is not set, if ``sacct`` keeps returning
                no rows, if the first row has fewer than three fields, or if
                the reported state is not a ``JobStatus`` value.
            subprocess.CalledProcessError: If ``sacct`` exits with a non-zero
                status.
        """
        if self.job_id is None:
            error_msg = f"Cannot refresh job '{self.name}': no SLURM job ID is set"
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Always query at least once so `lines` is defined after the loop.
        attempts = max(retries, 1)
        lines: list[str] = []

        for attempt in range(attempts):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,JobName,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id} status: {e}")
                raise

            # str.split always yields at least one element, so an empty first
            # element is the "no rows" signal.
            lines = result.stdout.strip().split("\n")
            if lines[0]:
                break

            error_msg = f"No job information found for job {self.job_id}"
            if attempt < attempts - 1:
                error_msg += f" Retrying {attempt + 1} of {attempts}..."
                logger.error(error_msg)
                time.sleep(1)
                continue
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Parse the first line (main job entry)
        job_data = lines[0].split("|")
        if len(job_data) < 3:
            error_msg = f"Cannot parse job data for job {self.job_id}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        # sacct may decorate the state (e.g. "CANCELLED by 1000"); keep only
        # the leading state word and drop a trailing "+" marker.
        state_words = job_data[2].split()
        if not state_words:
            error_msg = f"Cannot parse job data for job {self.job_id}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        self.status = JobStatus(state_words[0].rstrip("+"))
class Job(BaseJob):
    """Represents a SLURM job with complete configuration."""

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        default_factory=JobResource, description="Resource requirements"
    )
    # NOTE(review): JobEnvironment() with none of conda/venv/sqsh set is
    # rejected by JobEnvironment.validate_environment, so this default cannot
    # be built and callers effectively have to pass `environment` - confirm
    # this is intended.
    environment: JobEnvironment = Field(
        default_factory=JobEnvironment, description="Environment setup"
    )
    # Resolved when the Job is created (like work_dir), not once at import
    # time, so a SLURM_LOG_DIR set after this module is imported is honoured.
    log_dir: str = Field(
        default_factory=lambda: os.getenv("SLURM_LOG_DIR", "logs"),
        description="Directory for log files",
    )
    work_dir: str = Field(default_factory=os.getcwd, description="Working directory")
class ShellJob(BaseJob):
    """Job that is defined by an existing shell script."""

    path: str = Field(
        description="Shell script path to execute",
    )
class WorkflowTask(BaseModel):
    """One named step of a workflow, wrapping the job it runs."""

    name: str = Field(description="Task name")
    job: BaseJob = Field(description="Job configuration")

    # Names of the tasks this one depends on; empty by default.
    depends_on: list[str] = Field(
        description="Task dependencies",
        default_factory=list,
    )
    async_execution: bool = Field(
        description="Whether to run asynchronously",
        default=False,
    )
class Workflow(BaseModel):
    """A named collection of tasks linked by dependencies."""

    name: str = Field(description="Workflow name")
    tasks: list[WorkflowTask] = Field(description="List of tasks in the workflow")

    def get_task(self, name: str) -> WorkflowTask | None:
        """Return the first task called ``name``, or ``None`` if there is none."""
        matches = (candidate for candidate in self.tasks if candidate.name == name)
        return next(matches, None)

    def get_task_dependencies(self, task_name: str) -> list[str]:
        """Return the ``depends_on`` list of ``task_name`` (empty if unknown)."""
        found = self.get_task(task_name)
        if not found:
            return []
        return found.depends_on
def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
) -> str:
    """Render a SLURM job script from a template.

    The script is written to ``<output_dir>/<job.name>.slurm``. The output
    directory is created, including missing parents, if it does not exist.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.

    Returns:
        Path to the generated SLURM batch script.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    template_content = template_file.read_text(encoding="utf-8")

    # StrictUndefined turns a reference to a variable missing from
    # template_vars into an error instead of rendering it as empty text.
    template = jinja2.Template(template_content, undefined=jinja2.StrictUndefined)

    # Prepare template variables
    template_vars = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    rendered_content = template.render(template_vars)

    # Generate output file; create the target directory first so a missing
    # output_dir does not surface as a FileNotFoundError from the write.
    output_directory = Path(output_dir)
    output_directory.mkdir(parents=True, exist_ok=True)
    output_path = output_directory / f"{job.name}.slurm"
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)
def _build_environment_setup(environment: JobEnvironment) -> str:
    """Return the shell snippet that prepares ``environment`` for a job.

    The snippet exports every entry of ``env_vars`` first, then adds the
    activation lines for whichever of conda, venv or sqsh is set (checked in
    that order). Lines are joined with newlines.
    """
    # One export statement per configured variable.
    script = [f"export {name}={val}" for name, val in environment.env_vars.items()]

    if environment.conda:
        script += ["conda deactivate", f"conda activate {environment.conda}"]
    elif environment.venv:
        script.append(f"source {environment.venv}/bin/activate")
    elif environment.sqsh:
        # Default IMAGE to the configured path and declare the container args.
        script += [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]

    return "\n".join(script)