Coverage for src/srunx/client.py: 87%
134 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
1"""SLURM client for job submission and management."""
3import subprocess
4import tempfile
5import time
6from importlib.resources import files
7from pathlib import Path
9from srunx.logging import get_logger
10from srunx.models import (
11 BaseJob,
12 Job,
13 JobStatus,
14 ShellJob,
15 render_job_script,
16)
17from srunx.utils import get_job_status
19logger = get_logger(__name__)
22class Slurm:
23 """Client for interacting with SLURM workload manager."""
25 def __init__(self, default_template: str | None = None):
26 """Initialize SLURM client.
28 Args:
29 default_template: Path to default job template.
30 """
31 self.default_template = default_template or self._get_default_template()
33 def run(
34 self, job: Job | ShellJob, template_path: str | None = None
35 ) -> Job | ShellJob:
36 """Submit a job to SLURM.
38 Args:
39 job: Job configuration.
40 template_path: Optional template path (uses default if not provided).
42 Returns:
43 Job instance with updated job_id and status.
45 Raises:
46 subprocess.CalledProcessError: If job submission fails.
47 """
49 if isinstance(job, Job):
50 template = template_path or self.default_template
52 with tempfile.TemporaryDirectory() as temp_dir:
53 script_path = render_job_script(template, job, temp_dir)
54 logger.debug(f"Generated SLURM script at: {script_path}")
56 # Handle container execution
57 sbatch_cmd = ["sbatch"]
58 if job.environment.sqsh:
59 sbatch_cmd.extend(["--sqsh", job.environment.sqsh])
60 logger.debug(f"Using sqsh: {job.environment.sqsh}")
62 sbatch_cmd.append(script_path)
63 logger.debug(f"Executing command: {' '.join(sbatch_cmd)}")
65 try:
66 result = subprocess.run(
67 sbatch_cmd,
68 capture_output=True,
69 text=True,
70 check=True,
71 )
72 except subprocess.CalledProcessError as e:
73 logger.error(f"Failed to submit job '{job.name}': {e}")
74 logger.error(f"Command: {' '.join(e.cmd)}")
75 logger.error(f"Return code: {e.returncode}")
76 logger.error(f"Stdout: {e.stdout}")
77 logger.error(f"Stderr: {e.stderr}")
78 raise
80 elif isinstance(job, ShellJob):
81 try:
82 result = subprocess.run(
83 ["sbatch", job.path],
84 capture_output=True,
85 text=True,
86 check=True,
87 )
88 except subprocess.CalledProcessError as e:
89 logger.error(f"Failed to submit job '{job.name}': {e}")
90 logger.error(f"Command: {' '.join(e.cmd)}")
91 logger.error(f"Return code: {e.returncode}")
92 logger.error(f"Stdout: {e.stdout}")
93 logger.error(f"Stderr: {e.stderr}")
94 raise
96 else:
97 raise ValueError("Either 'command' or 'file' must be set")
99 time.sleep(3)
100 job_id = int(result.stdout.split()[-1])
101 job.job_id = job_id
102 job.status = JobStatus.PENDING
104 logger.info(f"Successfully submitted job '{job.name}' with ID {job_id}")
105 return job
107 @staticmethod
108 def retrieve(job_id: int) -> BaseJob:
109 return get_job_status(job_id)
111 def cancel(self, job_id: int) -> None:
112 """Cancel a SLURM job.
114 Args:
115 job_id: SLURM job ID to cancel.
117 Raises:
118 subprocess.CalledProcessError: If job cancellation fails.
119 """
120 logger.info(f"Cancelling job {job_id}")
122 try:
123 subprocess.run(
124 ["scancel", str(job_id)],
125 check=True,
126 )
127 logger.info(f"Successfully cancelled job {job_id}")
128 except subprocess.CalledProcessError as e:
129 logger.error(f"Failed to cancel job {job_id}: {e}")
130 raise
132 def list(self, user: str | None = None) -> list[BaseJob]:
133 """List jobs for a user.
135 Args:
136 user: Username (defaults to current user).
138 Returns:
139 List of Job objects.
140 """
141 cmd = [
142 "squeue",
143 "--format",
144 "%.18i %.9P %.15j %.8u %.8T %.10M %.9l %.6D %R",
145 "--noheader",
146 ]
147 if user:
148 cmd.extend(["--user", user])
150 result = subprocess.run(cmd, capture_output=True, text=True, check=True)
152 jobs = []
153 for line in result.stdout.strip().split("\n"):
154 if not line.strip():
155 continue
157 parts = line.split()
158 if len(parts) >= 5:
159 job_id = int(parts[0])
160 job_name = parts[2]
161 status_str = parts[4]
163 try:
164 status = JobStatus(status_str)
165 except ValueError:
166 status = JobStatus.PENDING # Default for unknown status
168 job = BaseJob(
169 name=job_name,
170 job_id=job_id,
171 status=status,
172 )
173 jobs.append(job)
175 return jobs
177 def monitor(
178 self, job_obj_or_id: BaseJob | Job | ShellJob | int, poll_interval: int = 30
179 ) -> BaseJob | Job | ShellJob:
180 """Wait for a job to complete.
182 Args:
183 job_obj_or_id: Job object or job ID.
184 poll_interval: Polling interval in seconds.
186 Returns:
187 Completed job object.
189 Raises:
190 RuntimeError: If job fails.
191 """
192 if isinstance(job_obj_or_id, int):
193 job = self.retrieve(job_obj_or_id)
194 else:
195 job = job_obj_or_id
197 msg = f"Waiting for job {job.job_id} to complete (polling every {poll_interval}s)."
198 if isinstance(job, Job):
199 msg += f" Logging to {job.log_dir}/{job.name}_{job.job_id}.out"
200 logger.info(msg)
202 previous_status = None
204 while True:
205 job.refresh()
207 # Log status changes
208 if job.status != previous_status:
209 status_str = job.status.value if job.status else "Unknown"
210 logger.info(f"Job {job.job_id} status: {status_str}")
211 previous_status = job.status
213 match job.status:
214 case JobStatus.COMPLETED:
215 logger.info(f"Job {job.job_id} completed successfully")
216 return job
217 case JobStatus.FAILED:
218 err_msg = f"SLURM job {job.job_id} failed.\n"
219 if isinstance(job, Job):
220 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
221 if log_file.exists():
222 with open(log_file) as f:
223 err_msg += f.read()
224 err_msg += f"\nLog file: {log_file}"
225 else:
226 err_msg += f"Log file not found: {log_file}"
227 raise RuntimeError(err_msg)
228 case JobStatus.CANCELLED | JobStatus.TIMEOUT:
229 err_msg = (
230 f"SLURM job {job.job_id} was {job.status.value.lower()}.\n"
231 )
232 if isinstance(job, Job):
233 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
234 if log_file.exists():
235 with open(log_file) as f:
236 err_msg += f.read()
237 err_msg += f"\nLog file: {log_file}"
238 else:
239 err_msg += f"Log file not found: {log_file}"
240 raise RuntimeError(err_msg)
241 time.sleep(poll_interval)
243 def _get_default_template(self) -> str:
244 """Get the default job template path."""
245 from srunx import templates
247 return str(files(templates).joinpath("base.slurm.jinja"))
250# Convenience functions for backward compatibility
def submit_job(job: Job | ShellJob, template_path: str | None = None) -> Job | ShellJob:
    """Submit ``job`` through a freshly created ``Slurm`` client.

    Args:
        job: Job configuration to submit.
        template_path: Optional template path, forwarded to ``Slurm.run``.

    Returns:
        The job as returned by ``Slurm.run``.
    """
    return Slurm().run(job, template_path)
def retrieve_job(job_id: int) -> BaseJob:
    """Look up the job with ``job_id`` through a freshly created ``Slurm`` client.

    Args:
        job_id: SLURM job ID.

    Returns:
        The job object produced by ``Slurm.retrieve``.
    """
    return Slurm().retrieve(job_id)
def cancel_job(job_id: int) -> None:
    """Cancel the job with ``job_id`` through a freshly created ``Slurm`` client.

    Args:
        job_id: SLURM job ID to cancel.
    """
    Slurm().cancel(job_id)