Coverage for src/srunx/workflows/tasks.py: 100%
30 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
1"""Prefect tasks for SLURM workflow management."""
3from prefect import task
4from srunx.client import Slurm
5from srunx.logging import get_logger
6from srunx.models import BaseJob, Job, ShellJob
8logger = get_logger(__name__)
@task
def submit_and_monitor_job(
    job: Job | ShellJob,
    poll_interval: int = 30,
) -> Job | ShellJob:
    """Run a SLURM job from submission through to completion.

    A fresh ``Slurm`` client submits the job with ``run`` and then hands
    the submitted job to ``monitor``. A log line is written before
    submission, after submission, and after monitoring returns.

    Args:
        job: Job configuration to submit.
        poll_interval: Status polling interval in seconds, forwarded to
            ``Slurm.monitor``.

    Returns:
        The job object returned by ``Slurm.monitor``.

    Raises:
        AssertionError: If ``Slurm.monitor`` returns something that is
            neither a ``Job`` nor a ``ShellJob``.

    Note:
        Exceptions raised by ``Slurm.run`` / ``Slurm.monitor`` are not
        caught here. Earlier documentation listed ``RuntimeError`` (job
        failure) and ``subprocess.CalledProcessError`` -- TODO confirm
        against the ``Slurm`` client.
    """
    logger.info(f"Starting SLURM job submission and monitoring for '{job.name}'")
    slurm = Slurm()

    # Step 1: hand the job to SLURM.
    submitted_job = slurm.run(job)
    logger.info(f"Job '{submitted_job.name}' submitted with ID {submitted_job.job_id}")

    # Step 2: block until the client reports the job as finished.
    completed_job = slurm.monitor(submitted_job, poll_interval=poll_interval)
    logger.info(f"Job '{completed_job.name}' (ID: {completed_job.job_id}) completed")

    # Narrow the monitor result to the concrete job types this task returns.
    assert isinstance(completed_job, (Job, ShellJob))
    return completed_job
@task
def submit_job_async(
    job: Job | ShellJob,
) -> Job | ShellJob:
    """Submit a SLURM job and return immediately, without monitoring it.

    Args:
        job: Job configuration to submit.

    Returns:
        The job object returned by ``Slurm.run``; its ``job_id`` is
        included in the submission log line.

    Raises:
        AssertionError: If ``Slurm.run`` returns something that is
            neither a ``Job`` nor a ``ShellJob``.
    """
    logger.info(f"Submitting async SLURM job '{job.name}'")

    slurm = Slurm()
    submitted_job = slurm.run(job)
    logger.info(
        f"Async job '{submitted_job.name}' submitted with ID {submitted_job.job_id}"
    )

    # Narrow the client's result to the concrete job types this task returns.
    assert isinstance(submitted_job, (Job, ShellJob))
    return submitted_job
@task
def wait_for_job(job_id: int, poll_interval: int = 30) -> BaseJob:
    """Wait for an already-submitted SLURM job to complete.

    Args:
        job_id: SLURM job ID to monitor.
        poll_interval: Polling interval in seconds, forwarded to
            ``Slurm.monitor``.

    Returns:
        The job object returned by ``Slurm.monitor``.
    """
    logger.info(f"Waiting for job {job_id} to complete")
    client = Slurm()
    # Pass poll_interval by keyword, as submit_and_monitor_job does, so the
    # call does not depend on the positional order of monitor()'s parameters.
    completed_job = client.monitor(job_id, poll_interval=poll_interval)
    logger.info(f"Job {job_id} completed")
    return completed_job