Coverage for src/srunx/client.py: 87%

134 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 20:31 +0900

1"""SLURM client for job submission and management.""" 

2 

3import subprocess 

4import tempfile 

5import time 

6from importlib.resources import files 

7from pathlib import Path 

8 

9from srunx.logging import get_logger 

10from srunx.models import ( 

11 BaseJob, 

12 Job, 

13 JobStatus, 

14 ShellJob, 

15 render_job_script, 

16) 

17from srunx.utils import get_job_status 

18 

# Module-level logger named after this module; shared by the Slurm client and
# the convenience functions below.
logger = get_logger(__name__)

20 

21 

22class Slurm: 

23 """Client for interacting with SLURM workload manager.""" 

24 

25 def __init__(self, default_template: str | None = None): 

26 """Initialize SLURM client. 

27 

28 Args: 

29 default_template: Path to default job template. 

30 """ 

31 self.default_template = default_template or self._get_default_template() 

32 

33 def run( 

34 self, job: Job | ShellJob, template_path: str | None = None 

35 ) -> Job | ShellJob: 

36 """Submit a job to SLURM. 

37 

38 Args: 

39 job: Job configuration. 

40 template_path: Optional template path (uses default if not provided). 

41 

42 Returns: 

43 Job instance with updated job_id and status. 

44 

45 Raises: 

46 subprocess.CalledProcessError: If job submission fails. 

47 """ 

48 

49 if isinstance(job, Job): 

50 template = template_path or self.default_template 

51 

52 with tempfile.TemporaryDirectory() as temp_dir: 

53 script_path = render_job_script(template, job, temp_dir) 

54 logger.debug(f"Generated SLURM script at: {script_path}") 

55 

56 # Handle container execution 

57 sbatch_cmd = ["sbatch"] 

58 if job.environment.sqsh: 

59 sbatch_cmd.extend(["--sqsh", job.environment.sqsh]) 

60 logger.debug(f"Using sqsh: {job.environment.sqsh}") 

61 

62 sbatch_cmd.append(script_path) 

63 logger.debug(f"Executing command: {' '.join(sbatch_cmd)}") 

64 

65 try: 

66 result = subprocess.run( 

67 sbatch_cmd, 

68 capture_output=True, 

69 text=True, 

70 check=True, 

71 ) 

72 except subprocess.CalledProcessError as e: 

73 logger.error(f"Failed to submit job '{job.name}': {e}") 

74 logger.error(f"Command: {' '.join(e.cmd)}") 

75 logger.error(f"Return code: {e.returncode}") 

76 logger.error(f"Stdout: {e.stdout}") 

77 logger.error(f"Stderr: {e.stderr}") 

78 raise 

79 

80 elif isinstance(job, ShellJob): 

81 try: 

82 result = subprocess.run( 

83 ["sbatch", job.path], 

84 capture_output=True, 

85 text=True, 

86 check=True, 

87 ) 

88 except subprocess.CalledProcessError as e: 

89 logger.error(f"Failed to submit job '{job.name}': {e}") 

90 logger.error(f"Command: {' '.join(e.cmd)}") 

91 logger.error(f"Return code: {e.returncode}") 

92 logger.error(f"Stdout: {e.stdout}") 

93 logger.error(f"Stderr: {e.stderr}") 

94 raise 

95 

96 else: 

97 raise ValueError("Either 'command' or 'file' must be set") 

98 

99 time.sleep(3) 

100 job_id = int(result.stdout.split()[-1]) 

101 job.job_id = job_id 

102 job.status = JobStatus.PENDING 

103 

104 logger.info(f"Successfully submitted job '{job.name}' with ID {job_id}") 

105 return job 

106 

107 @staticmethod 

108 def retrieve(job_id: int) -> BaseJob: 

109 return get_job_status(job_id) 

110 

111 def cancel(self, job_id: int) -> None: 

112 """Cancel a SLURM job. 

113 

114 Args: 

115 job_id: SLURM job ID to cancel. 

116 

117 Raises: 

118 subprocess.CalledProcessError: If job cancellation fails. 

119 """ 

120 logger.info(f"Cancelling job {job_id}") 

121 

122 try: 

123 subprocess.run( 

124 ["scancel", str(job_id)], 

125 check=True, 

126 ) 

127 logger.info(f"Successfully cancelled job {job_id}") 

128 except subprocess.CalledProcessError as e: 

129 logger.error(f"Failed to cancel job {job_id}: {e}") 

130 raise 

131 

132 def list(self, user: str | None = None) -> list[BaseJob]: 

133 """List jobs for a user. 

134 

135 Args: 

136 user: Username (defaults to current user). 

137 

138 Returns: 

139 List of Job objects. 

140 """ 

141 cmd = [ 

142 "squeue", 

143 "--format", 

144 "%.18i %.9P %.15j %.8u %.8T %.10M %.9l %.6D %R", 

145 "--noheader", 

146 ] 

147 if user: 

148 cmd.extend(["--user", user]) 

149 

150 result = subprocess.run(cmd, capture_output=True, text=True, check=True) 

151 

152 jobs = [] 

153 for line in result.stdout.strip().split("\n"): 

154 if not line.strip(): 

155 continue 

156 

157 parts = line.split() 

158 if len(parts) >= 5: 

159 job_id = int(parts[0]) 

160 job_name = parts[2] 

161 status_str = parts[4] 

162 

163 try: 

164 status = JobStatus(status_str) 

165 except ValueError: 

166 status = JobStatus.PENDING # Default for unknown status 

167 

168 job = BaseJob( 

169 name=job_name, 

170 job_id=job_id, 

171 status=status, 

172 ) 

173 jobs.append(job) 

174 

175 return jobs 

176 

177 def monitor( 

178 self, job_obj_or_id: BaseJob | Job | ShellJob | int, poll_interval: int = 30 

179 ) -> BaseJob | Job | ShellJob: 

180 """Wait for a job to complete. 

181 

182 Args: 

183 job_obj_or_id: Job object or job ID. 

184 poll_interval: Polling interval in seconds. 

185 

186 Returns: 

187 Completed job object. 

188 

189 Raises: 

190 RuntimeError: If job fails. 

191 """ 

192 if isinstance(job_obj_or_id, int): 

193 job = self.retrieve(job_obj_or_id) 

194 else: 

195 job = job_obj_or_id 

196 

197 msg = f"Waiting for job {job.job_id} to complete (polling every {poll_interval}s)." 

198 if isinstance(job, Job): 

199 msg += f" Logging to {job.log_dir}/{job.name}_{job.job_id}.out" 

200 logger.info(msg) 

201 

202 previous_status = None 

203 

204 while True: 

205 job.refresh() 

206 

207 # Log status changes 

208 if job.status != previous_status: 

209 status_str = job.status.value if job.status else "Unknown" 

210 logger.info(f"Job {job.job_id} status: {status_str}") 

211 previous_status = job.status 

212 

213 match job.status: 

214 case JobStatus.COMPLETED: 

215 logger.info(f"Job {job.job_id} completed successfully") 

216 return job 

217 case JobStatus.FAILED: 

218 err_msg = f"SLURM job {job.job_id} failed.\n" 

219 if isinstance(job, Job): 

220 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out" 

221 if log_file.exists(): 

222 with open(log_file) as f: 

223 err_msg += f.read() 

224 err_msg += f"\nLog file: {log_file}" 

225 else: 

226 err_msg += f"Log file not found: {log_file}" 

227 raise RuntimeError(err_msg) 

228 case JobStatus.CANCELLED | JobStatus.TIMEOUT: 

229 err_msg = ( 

230 f"SLURM job {job.job_id} was {job.status.value.lower()}.\n" 

231 ) 

232 if isinstance(job, Job): 

233 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out" 

234 if log_file.exists(): 

235 with open(log_file) as f: 

236 err_msg += f.read() 

237 err_msg += f"\nLog file: {log_file}" 

238 else: 

239 err_msg += f"Log file not found: {log_file}" 

240 raise RuntimeError(err_msg) 

241 time.sleep(poll_interval) 

242 

243 def _get_default_template(self) -> str: 

244 """Get the default job template path.""" 

245 from srunx import templates 

246 

247 return str(files(templates).joinpath("base.slurm.jinja")) 

248 

249 

250# Convenience functions for backward compatibility 

def submit_job(job: Job | ShellJob, template_path: str | None = None) -> Job | ShellJob:
    """Submit ``job`` through a freshly created ``Slurm`` client.

    Kept for backward compatibility; equivalent to
    ``Slurm().run(job, template_path)``.

    Args:
        job: Job configuration to submit.
        template_path: Optional template path forwarded to ``Slurm.run``.

    Returns:
        The submitted job, as returned by ``Slurm.run``.
    """
    return Slurm().run(job, template_path)

255 

256 

def retrieve_job(job_id: int) -> BaseJob:
    """Get job status (convenience function).

    Args:
        job_id: SLURM job ID to look up.

    Returns:
        The job object returned by ``Slurm.retrieve``.
    """
    # ``Slurm.retrieve`` is a staticmethod, so no client instance is needed.
    # Skipping ``Slurm()`` avoids resolving the default job template, which a
    # status lookup never uses.
    return Slurm.retrieve(job_id)

261 

262 

def cancel_job(job_id: int) -> None:
    """Cancel the SLURM job ``job_id`` via a throwaway ``Slurm`` client.

    Kept for backward compatibility; equivalent to ``Slurm().cancel(job_id)``.

    Args:
        job_id: SLURM job ID to cancel.

    Raises:
        subprocess.CalledProcessError: Propagated from ``Slurm.cancel``.
    """
    Slurm().cancel(job_id)