Coverage for src/srunx/workflows/runner.py: 99%

75 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 20:31 +0900

1"""Workflow runner for executing YAML-defined workflows with SLURM and Prefect.""" 

2 

3from pathlib import Path 

4from typing import Any 

5 

6import yaml 

7from prefect import flow 

8from prefect.states import State 

9from srunx.logging import get_logger 

10from srunx.models import ( 

11 Job, 

12 JobEnvironment, 

13 JobResource, 

14 ShellJob, 

15 Workflow, 

16 WorkflowTask, 

17) 

18from srunx.workflows.tasks import submit_and_monitor_job, submit_job_async 

19 

20logger = get_logger(__name__) 

21 

22 

class WorkflowRunner:
    """Runner for executing workflows defined in YAML.

    Results of submitted tasks are cached on ``executed_tasks`` so that a task
    required by several dependents is only submitted once. The cache lives on
    the instance and is not cleared between ``execute_workflow`` calls.
    """

    def __init__(self) -> None:
        """Initialize workflow runner with an empty task-result cache."""
        self.executed_tasks: dict[str, State[Job | ShellJob]] = {}

    def load_from_yaml(self, yaml_path: str | Path) -> Workflow:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.

        Returns:
            Validated Workflow object.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            ValueError: If the file is empty or its top-level document is not
                a mapping.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        return self._parse_workflow_data(data)

    def _parse_workflow_data(self, data: dict) -> Workflow:
        """Build a Workflow from the top-level mapping of a workflow file.

        Args:
            data: Parsed workflow document. ``name`` defaults to
                ``"unnamed_workflow"``; a missing or null ``tasks`` entry is
                treated as an empty task list.

        Returns:
            Workflow containing one WorkflowTask per entry under ``tasks``.

        Raises:
            ValueError: If ``data`` is not a mapping (for example ``None``,
                which is what an empty YAML file parses to).
        """
        # Fail with a clear message instead of an AttributeError on ``.get``.
        if not isinstance(data, dict):
            raise ValueError(
                "Workflow definition must be a mapping, "
                f"got {type(data).__name__}"
            )

        workflow_name = data.get("name", "unnamed_workflow")
        # ``tasks:`` with no value parses to None; treat it like an empty list.
        tasks_data = data.get("tasks") or []
        tasks = [self._parse_task_data(task_data) for task_data in tasks_data]

        return Workflow(name=workflow_name, tasks=tasks)

    def _parse_task_data(self, task_data: dict) -> WorkflowTask:
        """Parse a single task from dictionary.

        Args:
            task_data: Mapping describing one task. ``name`` is required; all
                other keys are optional.

        Returns:
            WorkflowTask wrapping the validated job.

        Raises:
            KeyError: If ``name`` is missing from ``task_data``.
        """
        # Basic task properties
        name = task_data["name"]
        file = task_data.get("file")
        async_execution = task_data.get("async", False)
        # A key present with a null value yields None; normalise to a list.
        depends_on = task_data.get("depends_on") or []

        job_data: dict[str, Any] = {"name": name}
        # Only forward directories that were explicitly set.
        for key in ("log_dir", "work_dir"):
            if task_data.get(key) is not None:
                job_data[key] = task_data[key]

        if file:
            # File-based task: no resource/environment settings are read.
            # NOTE(review): this payload is still validated through ``Job``
            # below - confirm that model accepts a ``file`` field.
            job_data |= {
                "file": file,
                "resources": None,
                "environment": None,
            }
        else:
            command = task_data.get("command")

            # Resource configuration
            resources = JobResource(
                nodes=task_data.get("nodes", 1),
                gpus_per_node=task_data.get("gpus_per_node", 0),
                ntasks_per_node=task_data.get("ntasks_per_node", 1),
                cpus_per_task=task_data.get("cpus_per_task", 1),
                memory_per_node=task_data.get("memory_per_node"),
                time_limit=task_data.get("time_limit"),
            )

            # Environment configuration; ``container`` is accepted as a
            # fallback key for ``sqsh``.
            environment = JobEnvironment(
                conda=task_data.get("conda"),
                venv=task_data.get("venv"),
                sqsh=task_data.get("sqsh") or task_data.get("container"),
                env_vars=task_data.get("env_vars") or {},
            )

            job_data |= {
                "command": command,
                "resources": resources,
                "environment": environment,
            }

        # Create job
        job = Job.model_validate(job_data)

        return WorkflowTask(
            name=name,
            job=job,
            depends_on=depends_on,
            async_execution=async_execution,
        )

    def execute_workflow(self, workflow: Workflow) -> dict[str, State[Job | ShellJob]]:
        """Execute a workflow using Prefect.

        Each task's dependencies are executed (depth first) before the task
        itself. Tasks already present in ``self.executed_tasks`` are not
        submitted again.

        Args:
            workflow: Workflow to execute.

        Returns:
            Dictionary mapping each task name to the value returned by the
            submission task used for it.

        Raises:
            KeyError: If a task depends on a name that is not in the workflow.
            ValueError: If the ``depends_on`` relations contain a cycle.
        """
        task_map = {task.name: task for task in workflow.tasks}

        @flow(name=workflow.name)
        def workflow_flow() -> dict[str, State[Job | ShellJob]]:
            """Prefect flow for workflow execution."""
            # Names whose dependency chain is currently being resolved; seeing
            # one of them again means the dependency graph has a cycle.
            in_progress: set[str] = set()

            def execute_task(task_name: str) -> State[Job | ShellJob]:
                """Execute a task and its dependencies recursively."""
                if task_name in self.executed_tasks:
                    return self.executed_tasks[task_name]

                if task_name in in_progress:
                    raise ValueError(
                        f"Circular dependency detected at task '{task_name}'"
                    )

                task = task_map[task_name]

                # Execute dependencies first
                in_progress.add(task_name)
                try:
                    for dependency in task.depends_on:
                        execute_task(dependency)
                finally:
                    in_progress.discard(task_name)

                # Execute the task
                if task.async_execution:
                    job_future = submit_job_async(task.job)
                else:
                    job_future = submit_and_monitor_job(task.job)

                self.executed_tasks[task_name] = job_future
                return job_future

            # Execute all tasks
            results: dict[str, State[Job | ShellJob]] = {}
            for task in workflow.tasks:
                results[task.name] = execute_task(task.name)

            return results

        return workflow_flow()

    def execute_from_yaml(
        self, yaml_path: str | Path
    ) -> dict[str, State[Job | ShellJob]]:
        """Load and execute a workflow from YAML file.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            The dictionary produced by ``execute_workflow`` for the loaded
            workflow, keyed by task name.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        workflow = self.load_from_yaml(yaml_path)

        logger.info(
            f"Executing workflow '{workflow.name}' with {len(workflow.tasks)} tasks"
        )
        results = self.execute_workflow(workflow)

        logger.info("Workflow execution completed")
        return results

187 

188 

def run_workflow_from_file(yaml_path: str | Path) -> dict[str, State[Job | ShellJob]]:
    """Load a YAML workflow and run it with a fresh ``WorkflowRunner``.

    Args:
        yaml_path: Location of the YAML workflow file.

    Returns:
        Whatever ``WorkflowRunner.execute_from_yaml`` returns for the file:
        a dictionary keyed by task name.
    """
    return WorkflowRunner().execute_from_yaml(yaml_path)

198 runner = WorkflowRunner() 

199 return runner.execute_from_yaml(yaml_path)