Coverage for src/srunx/models.py: 79%

113 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 20:31 +0900

1"""Data models for SLURM job management.""" 

2 

3import os 

4import subprocess 

5import time 

6from enum import Enum 

7from pathlib import Path 

8from typing import Self 

9 

10import jinja2 

11from pydantic import BaseModel, Field, model_validator 

12from srunx.logging import get_logger 

13 

# Module-level logger named after this module; used by BaseJob.refresh.
logger = get_logger(__name__)

15 

16 

class JobStatus(Enum):
    """Job states recognised by this module.

    Each member's value is the state string itself, so a state name read
    from ``sacct`` output can be converted with ``JobStatus(state)``.
    """

    # Not yet finished
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # Finished
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"

26 

27 

class JobResource(BaseModel):
    """Resources requested from SLURM for a job.

    The integer fields are validated by pydantic against their lower bounds
    (``ge``); ``memory_per_node`` and ``time_limit`` are free-form strings
    and are not validated here.
    """

    # Counts: nodes, tasks and CPUs must be at least 1; GPUs may be 0.
    nodes: int = Field(description="Number of compute nodes", default=1, ge=1)
    gpus_per_node: int = Field(
        description="Number of GPUs per node", default=0, ge=0
    )
    ntasks_per_node: int = Field(
        description="Number of tasks per node", default=1, ge=1
    )
    cpus_per_task: int = Field(
        description="Number of CPUs per task", default=1, ge=1
    )

    # Optional limits; None means the value is not specified.
    memory_per_node: str | None = Field(
        description="Memory per node (e.g., '32GB')", default=None
    )
    time_limit: str | None = Field(
        description="Time limit (e.g., '1:00:00')", default=None
    )

43 

44 

class JobEnvironment(BaseModel):
    """Runtime environment for a job.

    Exactly one of ``conda``, ``venv`` or ``sqsh`` has to be provided;
    ``env_vars`` may be combined with any of them.
    """

    conda: str | None = Field(description="Conda environment name", default=None)
    venv: str | None = Field(description="Virtual environment path", default=None)
    sqsh: str | None = Field(description="SquashFS image path", default=None)
    env_vars: dict[str, str] = Field(
        description="Environment variables", default_factory=dict
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject the model unless exactly one environment kind is set.

        Returns:
            The validated model itself.

        Raises:
            ValueError: If none, or more than one, of ``conda``, ``venv``
                and ``sqsh`` is not ``None``.
        """
        # Only None counts as "unset"; an empty string still counts as set.
        chosen = [
            option
            for option in (self.conda, self.venv, self.sqsh)
            if option is not None
        ]
        if len(chosen) == 1:
            return self
        raise ValueError("Exactly one of 'conda', 'venv', or 'sqsh' must be set")

62 

63 

class BaseJob(BaseModel):
    """State shared by every job type: name, SLURM job ID, status, dependencies."""

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    status: JobStatus | None = Field(default=None, description="Current job status")
    dependencies: list[str] = Field(
        default_factory=list, description="Job dependencies"
    )

    def refresh(self, retries: int = 3) -> None:
        """Query ``sacct`` for this job and store the result in ``status``.

        Args:
            retries: Maximum number of ``sacct`` queries. When a query
                returns no rows the method sleeps one second and tries
                again. Values below 1 are treated as 1 so that at least one
                query is always made.

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits with a
                non-zero status (not retried).
            ValueError: If no row is returned after all attempts, if the
                row has fewer than three ``|``-separated fields, or if the
                reported state is not a ``JobStatus`` member.
        """
        # Guard against retries <= 0, which previously skipped the loop and
        # left `lines` unbound.
        attempts = max(retries, 1)
        lines: list[str] = []

        for retry in range(attempts):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,JobName,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id} status: {e}")
                raise

            lines = result.stdout.strip().split("\n")
            if lines and lines[0]:
                break

            # Empty output: retry unless this was the final attempt.
            error_msg = f"No job information found for job {self.job_id}"
            if retry < attempts - 1:
                error_msg += f" Retrying {retry + 1} of {attempts}..."
                logger.error(error_msg)
                time.sleep(1)
                continue
            raise ValueError(error_msg)

        # Parse the first line (main job entry)
        job_data = lines[0].split("|")
        if len(job_data) < 3:
            error_msg = f"Cannot parse job data for job {self.job_id}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        # sacct can decorate the state (e.g. "CANCELLED by <uid>"); only the
        # first token is the state name that JobStatus knows about.
        raw_state = job_data[2]
        state_tokens = raw_state.split()
        self.status = JobStatus(state_tokens[0] if state_tokens else raw_state)

114 

115 

class Job(BaseJob):
    """Represents a SLURM job with complete configuration.

    Adds the command line, resource request, environment setup and the log
    and working directories to the common ``BaseJob`` fields.
    """

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        default_factory=JobResource, description="Resource requirements"
    )
    # NOTE(review): JobEnvironment() with no arguments fails its own
    # "exactly one of conda/venv/sqsh" validator, so this default cannot be
    # built and callers effectively have to pass `environment` -- confirm
    # whether that is intended.
    environment: JobEnvironment = Field(
        default_factory=JobEnvironment, description="Environment setup"
    )
    # Read SLURM_LOG_DIR when the job is created (not once at import time),
    # mirroring how work_dir is resolved.
    log_dir: str = Field(
        default_factory=lambda: os.getenv("SLURM_LOG_DIR", "logs"),
        description="Directory for log files",
    )
    work_dir: str = Field(default_factory=os.getcwd, description="Working directory")

131 

132 

class ShellJob(BaseJob):
    """A job defined by the path of a shell script instead of a command."""

    path: str = Field(
        description="Shell script path to execute",
    )

135 

136 

class WorkflowTask(BaseModel):
    """One named step of a workflow, wrapping the job it runs."""

    name: str = Field(description="Task name")
    job: BaseJob = Field(description="Job configuration")
    # Names listed here are compared against task names by
    # Workflow.get_task / Workflow.get_task_dependencies.
    depends_on: list[str] = Field(
        description="Task dependencies", default_factory=list
    )
    async_execution: bool = Field(
        description="Whether to run asynchronously", default=False
    )

146 

147 

class Workflow(BaseModel):
    """A named collection of tasks that may depend on one another."""

    name: str = Field(description="Workflow name")
    tasks: list[WorkflowTask] = Field(description="List of tasks in the workflow")

    def get_task(self, name: str) -> WorkflowTask | None:
        """Return the first task whose name equals ``name``, or ``None``."""
        matches = (candidate for candidate in self.tasks if candidate.name == name)
        return next(matches, None)

    def get_task_dependencies(self, task_name: str) -> list[str]:
        """Return the ``depends_on`` list of the named task.

        An empty list is returned when no task has that name.
        """
        found = self.get_task(task_name)
        if not found:
            return []
        return found.depends_on

165 

166 

def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
) -> str:
    """Render a SLURM job script from a template.

    The template is rendered with ``job_name``, ``command``, ``log_dir``,
    ``work_dir``, ``environment_setup`` and every field of ``job.resources``;
    referencing any other variable fails because undefined names are strict.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
            It is created (including parents) if it does not exist yet.

    Returns:
        Path to the generated SLURM batch script
        (``<output_dir>/<job.name>.slurm``).

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    template = jinja2.Template(
        template_file.read_text(encoding="utf-8"),
        undefined=jinja2.StrictUndefined,
    )

    # Prepare template variables
    template_vars = {
        "job_name": job.name,
        # Arguments are joined with single spaces, without shell quoting.
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    rendered_content = template.render(template_vars)

    # Generate output file; make sure the target directory exists first so
    # a fresh output location does not fail the write.
    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    output_path = output_root / f"{job.name}.slurm"
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)

213 

214 

215def _build_environment_setup(environment: JobEnvironment) -> str: 

216 """Build environment setup script.""" 

217 setup_lines = [] 

218 

219 # Set environment variables 

220 for key, value in environment.env_vars.items(): 

221 setup_lines.append(f"export {key}={value}") 

222 

223 # Activate environments 

224 if environment.conda: 

225 setup_lines.extend(["conda deactivate", f"conda activate {environment.conda}"]) 

226 elif environment.venv: 

227 setup_lines.append(f"source {environment.venv}/bin/activate") 

228 elif environment.sqsh: 

229 setup_lines.extend( 

230 [ 

231 f': "${{IMAGE:={environment.sqsh}}}"', 

232 "declare -a CONTAINER_ARGS=(", 

233 ' --container-image "$IMAGE"', 

234 ")", 

235 ] 

236 ) 

237 

238 return "\n".join(setup_lines)