Coverage for src/srunx/workflows/tasks.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 20:31 +0900

1"""Prefect tasks for SLURM workflow management.""" 

2 

3from prefect import task 

4from srunx.client import Slurm 

5from srunx.logging import get_logger 

6from srunx.models import BaseJob, Job, ShellJob 

7 

# Module-level logger from srunx's logging helper, named after this module;
# shared by all Prefect tasks below.
logger = get_logger(__name__)

9 

10 

@task
def submit_and_monitor_job(
    job: Job | ShellJob, poll_interval: int = 30
) -> Job | ShellJob:
    """Submit a SLURM job and monitor it until completion.

    This Prefect task handles the complete lifecycle of a SLURM job: it
    submits the job through ``Slurm.run`` and then blocks in
    ``Slurm.monitor`` until that call returns.

    Args:
        job: Job configuration to submit.
        poll_interval: Status polling interval in seconds, forwarded to
            ``Slurm.monitor``.

    Returns:
        The completed job object returned by ``Slurm.monitor``.

    Raises:
        TypeError: If ``Slurm.monitor`` returns an object that is neither a
            ``Job`` nor a ``ShellJob``.
        RuntimeError: If the SLURM job fails (propagated from the client).
        subprocess.CalledProcessError: If job operations fail (propagated
            from the client).
    """
    logger.info(f"Starting SLURM job submission and monitoring for '{job.name}'")
    client = Slurm()

    # Submit the job.
    submitted_job = client.run(job)
    logger.info(f"Job '{submitted_job.name}' submitted with ID {submitted_job.job_id}")

    # Block until the client reports the job has finished.
    completed_job = client.monitor(submitted_job, poll_interval=poll_interval)
    logger.info(f"Job '{completed_job.name}' (ID: {completed_job.job_id}) completed")

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unexpected type escape despite the
    # declared return annotation.
    if not isinstance(completed_job, (Job, ShellJob)):
        raise TypeError(
            f"Slurm.monitor returned {type(completed_job).__name__}; "
            "expected Job or ShellJob"
        )

    return completed_job

45 

46 

@task
def submit_job_async(job: Job | ShellJob) -> Job | ShellJob:
    """Submit a SLURM job without waiting for completion.

    The task returns as soon as ``Slurm.run`` returns; it does not poll the
    job. (Despite the name, this is a regular synchronous function — "async"
    refers to not waiting on the SLURM job.)

    Args:
        job: Job configuration to submit.

    Returns:
        The submitted job object returned by ``Slurm.run``, carrying its
        ``job_id``.

    Raises:
        TypeError: If ``Slurm.run`` returns an object that is neither a
            ``Job`` nor a ``ShellJob``.
    """
    logger.info(f"Submitting async SLURM job '{job.name}'")
    client = Slurm()
    submitted_job = client.run(job)
    logger.info(
        f"Async job '{submitted_job.name}' submitted with ID {submitted_job.job_id}"
    )

    # Explicit check rather than ``assert`` so it still runs under
    # ``python -O`` and the return annotation is actually enforced.
    if not isinstance(submitted_job, (Job, ShellJob)):
        raise TypeError(
            f"Slurm.run returned {type(submitted_job).__name__}; "
            "expected Job or ShellJob"
        )

    return submitted_job

65 

66 

@task
def wait_for_job(job_id: int, poll_interval: int = 30) -> BaseJob:
    """Wait for an already-submitted SLURM job to complete.

    Args:
        job_id: SLURM job ID to monitor.
        poll_interval: Polling interval in seconds, forwarded to
            ``Slurm.monitor``.

    Returns:
        The completed job object returned by ``Slurm.monitor``.
    """
    logger.info(f"Waiting for job {job_id} to complete")
    client = Slurm()
    # Pass poll_interval by keyword (as submit_and_monitor_job does) so the
    # call does not depend on the positional order of Slurm.monitor's
    # parameters.
    completed_job = client.monitor(job_id, poll_interval=poll_interval)
    logger.info(f"Job {job_id} completed")
    return completed_job