Coverage for src/srunx/cli/workflow.py: 97%

89 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 20:31 +0900

1"""CLI interface for workflow management.""" 

2 

3import argparse 

4import sys 

5from pathlib import Path 

6 

7from srunx.logging import configure_workflow_logging, get_logger 

8from srunx.models import Job, ShellJob, Workflow 

9from srunx.workflows.runner import WorkflowRunner 

10 

# Module-level logger for this CLI, obtained from srunx.logging's get_logger
# and named after this module via __name__.
logger = get_logger(__name__)

12 

13 

def create_workflow_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the workflow command.

    The parser takes one positional argument, the path to a YAML workflow
    definition, plus the optional ``--validate-only``, ``--dry-run`` and
    ``--log-level`` switches. The help epilog shows an example workflow;
    ``RawDescriptionHelpFormatter`` keeps that example's layout intact.

    Returns:
        The configured :class:`argparse.ArgumentParser`.
    """
    example_epilog = """
Example YAML workflow:
  name: ml_pipeline
  tasks:
    - name: preprocess
      command: ["python", "preprocess.py"]
      nodes: 1

    - name: train
      command: ["python", "train.py"]
      depends_on: [preprocess]
      gpus_per_node: 1
      conda: ml_env

    - name: evaluate
      command: ["python", "evaluate.py"]
      depends_on: [train]

    - name: notify
      command: ["python", "notify.py"]
      depends_on: [train, evaluate]
      async: true
"""
    workflow_parser = argparse.ArgumentParser(
        description="Execute YAML-defined workflows using SLURM and Prefect",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=example_epilog,
    )

    workflow_parser.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )

    # Both mode switches are plain on/off flags, registered in this order.
    mode_flags = (
        ("--validate-only", "Only validate the workflow file without executing"),
        ("--dry-run", "Show what would be executed without running jobs"),
    )
    for flag, help_text in mode_flags:
        workflow_parser.add_argument(flag, action="store_true", help=help_text)

    workflow_parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: %(default)s)",
    )
    return workflow_parser

70 

71 

def cmd_run_workflow(args: argparse.Namespace) -> None:
    """Handle the workflow command described by the parsed CLI ``args``.

    Loads the YAML workflow named by ``args.yaml_file`` via
    :class:`WorkflowRunner` and checks its task dependencies. It then does
    one of three things:

    - stops after validation when ``args.validate_only`` is set;
    - logs the execution plan when ``args.dry_run`` is set;
    - otherwise executes the workflow and logs each task's result.

    A missing workflow file exits with status 1. Any ``Exception`` raised
    while loading, validating or executing is logged and also exits with
    status 1.

    Args:
        args: Namespace from ``create_workflow_parser()``; reads
            ``yaml_file``, ``validate_only``, ``dry_run`` and ``log_level``.
    """
    # Configure logging first so every later message, including failures,
    # goes through the workflow logging setup.
    configure_workflow_logging(level=args.log_level)

    try:
        workflow_path = Path(args.yaml_file)
        if not workflow_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            # SystemExit is not an Exception subclass, so it is not caught
            # by the handler below.
            sys.exit(1)

        workflow_runner = WorkflowRunner()
        loaded = workflow_runner.load_from_yaml(workflow_path)
        logger.info(
            f"Loaded workflow '{loaded.name}' with {len(loaded.tasks)} tasks"
        )

        _validate_workflow_dependencies(loaded)

        if args.validate_only:
            logger.info("Workflow validation successful")
        elif args.dry_run:
            _show_workflow_plan(loaded)
        else:
            logger.info("Starting workflow execution")
            job_results = workflow_runner.execute_workflow(loaded)

            logger.info("Workflow execution completed successfully")
            logger.info("Job Results:")
            for name, result in job_results.items():
                # Results without a (truthy) job_id are printed as-is.
                job_id = getattr(result, "job_id", None)
                if job_id:
                    logger.info(f"  {name}: Job ID {job_id}")
                else:
                    logger.info(f"  {name}: {result}")
    except Exception as exc:
        logger.error(f"Workflow execution failed: {exc}")
        sys.exit(1)

117 

118 

def _validate_workflow_dependencies(workflow: "Workflow") -> None:
    """Validate the task dependency graph of *workflow*.

    Two checks are performed, in order:

    1. Every name listed in a task's ``depends_on`` refers to a task in
       ``workflow.tasks``.
    2. The dependency graph contains no cycle.

    Cycle detection is an iterative depth-first search. Very long
    dependency chains therefore cannot exhaust Python's recursion limit.
    When a cycle is found, the error names a task that is actually on the
    cycle and spells the cycle out.

    Args:
        workflow: Workflow whose ``tasks`` (each exposing ``name`` and
            ``depends_on``) are checked. ``workflow.get_task(name)`` is used
            to look up a task's dependencies during the search.

    Raises:
        ValueError: If a dependency names an unknown task, or if a
            circular dependency exists. The cycle message has the form
            ``... involving task 'b': b -> c -> b``, where ``->`` reads
            "depends on".
    """
    task_names = {task.name for task in workflow.tasks}

    for task in workflow.tasks:
        for dependency in task.depends_on:
            if dependency not in task_names:
                raise ValueError(
                    f"Task '{task.name}' depends on unknown task '{dependency}'"
                )

    def dependencies_of(task_name):
        """Return the dependency names of *task_name* ([] if not found)."""
        found = workflow.get_task(task_name)
        return list(found.depends_on) if found else []

    in_progress = 1  # on the current DFS path
    done = 2  # fully explored, known to be cycle-free
    state = {}
    exhausted = object()  # sentinel from next() once a task has no more deps

    for root in workflow.tasks:
        if root.name in state:
            # Already explored from an earlier root; every finished node is done.
            continue

        # path[i] is the task whose remaining dependencies are pending[i].
        path = [root.name]
        pending = [iter(dependencies_of(root.name))]
        state[root.name] = in_progress

        while pending:
            dependency = next(pending[-1], exhausted)
            if dependency is exhausted:
                pending.pop()
                state[path.pop()] = done
                continue

            status = state.get(dependency)
            if status == in_progress:
                # Back-edge: the cycle is the path suffix starting at the
                # revisited task, closed by that task again.
                cycle = path[path.index(dependency):] + [dependency]
                raise ValueError(
                    f"Circular dependency detected involving task "
                    f"'{dependency}': {' -> '.join(cycle)}"
                )
            if status == done:
                continue

            state[dependency] = in_progress
            path.append(dependency)
            pending.append(iter(dependencies_of(dependency)))

157 

158 

def _show_workflow_plan(workflow: Workflow) -> None:
    """Log a human-readable execution plan for *workflow* without running it.

    The plan starts with the workflow name and task count. It then lists
    each task with:

    - for a ``Job``: its command, node/GPU resources and any conda or sqsh
      environment;
    - for a ``ShellJob``: its script path;
    - in both cases: its dependencies and whether it runs asynchronously.

    The whole plan is emitted as a single ``logger.info`` call.
    """
    plan_lines = [
        "Workflow execution plan:",
        f"  Workflow: {workflow.name}",
        f"  Tasks: {len(workflow.tasks)}",
    ]

    for task in workflow.tasks:
        plan_lines.append(f"\t\tTask: {task.name}")
        job = task.job
        if isinstance(job, Job):
            plan_lines.append(f"\t\t\tCommand: {' '.join(job.command or [])}")
            plan_lines.append(
                f"\t\t\tResources: {job.resources.nodes} nodes, "
                f"{job.resources.gpus_per_node} GPUs/node"
            )
            env = job.environment
            if env.conda:
                plan_lines.append(f"\t\t\tConda env: {env.conda}")
            if env.sqsh:
                plan_lines.append(f"\t\t\tSqsh: {env.sqsh}")
        elif isinstance(job, ShellJob):
            plan_lines.append(f"\t\t\tPath: {job.path}")

        if task.depends_on:
            plan_lines.append(f"\t\t\tDependencies: {', '.join(task.depends_on)}")
        if task.async_execution:
            plan_lines.append("\t\t\tExecution: asynchronous")

    # Every line, including the last, is newline-terminated.
    logger.info("\n".join(plan_lines) + "\n")

185 

186 

def main() -> None:
    """Entry point for the workflow CLI: parse ``sys.argv`` and run the command."""
    cmd_run_workflow(create_workflow_parser().parse_args())

193 

194 

if __name__ == "__main__":
    # Run the CLI when this module is executed directly as a script.
    main()