Coverage for src/srunx/cli/workflow.py: 97%
89 statements
coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
1"""CLI interface for workflow management."""
3import argparse
4import sys
5from pathlib import Path
7from srunx.logging import configure_workflow_logging, get_logger
8from srunx.models import Job, ShellJob, Workflow
9from srunx.workflows.runner import WorkflowRunner
# Module-level logger used by cmd_run_workflow and _show_workflow_plan.
logger = get_logger(__name__)
def create_workflow_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the workflow runner.

    The parser takes one positional ``yaml_file`` argument plus the
    ``--validate-only``, ``--dry-run`` and ``--log-level`` options. Its epilog
    shows an example workflow definition and is printed verbatim because the
    parser uses ``RawDescriptionHelpFormatter``.

    Returns:
        A freshly configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Execute YAML-defined workflows using SLURM and Prefect",
        epilog="""
Example YAML workflow:
  name: ml_pipeline
  tasks:
    - name: preprocess
      command: ["python", "preprocess.py"]
      nodes: 1

    - name: train
      command: ["python", "train.py"]
      depends_on: [preprocess]
      gpus_per_node: 1
      conda: ml_env

    - name: evaluate
      command: ["python", "evaluate.py"]
      depends_on: [train]

    - name: notify
      command: ["python", "notify.py"]
      depends_on: [train, evaluate]
      async: true
        """,
    )

    parser.add_argument(
        "yaml_file", type=str, help="Path to YAML workflow definition file"
    )

    # Boolean mode switches (store_true, so both default to False).
    mode_switches = (
        ("--validate-only", "Only validate the workflow file without executing"),
        ("--dry-run", "Show what would be executed without running jobs"),
    )
    for flag, help_text in mode_switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Set logging level (default: %(default)s)",
    )

    return parser
def cmd_run_workflow(args: argparse.Namespace) -> None:
    """Handle the workflow command: validate, dry-run, or execute a YAML workflow.

    Steps:
      1. Configure workflow logging at ``args.log_level``.
      2. Load the workflow from ``args.yaml_file`` via ``WorkflowRunner``.
      3. Check task dependencies (unknown task names and cycles).
      4. Then do one of the following:
         - stop after validation if ``args.validate_only`` is set;
         - log the execution plan and stop if ``args.dry_run`` is set;
         - otherwise execute the workflow and log each task's result.

    Args:
        args: Namespace produced by ``create_workflow_parser().parse_args()``.
            Reads ``yaml_file``, ``validate_only``, ``dry_run`` and ``log_level``.

    Exits the process with status 1 in two cases, logging an error first:
        - the path is not an existing regular file;
        - any step raises an ``Exception``.
    """
    # Configure logging for workflow execution
    configure_workflow_logging(level=args.log_level)

    try:
        yaml_file = Path(args.yaml_file)
        # is_file() rather than exists(): a directory would pass exists() and
        # then fail inside the loader with a much less helpful error.
        if not yaml_file.is_file():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            # SystemExit derives from BaseException, so the handler below
            # does not intercept it.
            sys.exit(1)

        runner = WorkflowRunner()

        # Load workflow for validation
        workflow = runner.load_from_yaml(yaml_file)
        logger.info(
            f"Loaded workflow '{workflow.name}' with {len(workflow.tasks)} tasks"
        )

        # Validate dependencies (raises ValueError on unknown names or cycles)
        _validate_workflow_dependencies(workflow)

        if args.validate_only:
            logger.info("Workflow validation successful")
            return

        if args.dry_run:
            _show_workflow_plan(workflow)
            return

        # Execute workflow
        logger.info("Starting workflow execution")
        results = runner.execute_workflow(workflow)

        logger.info("Workflow execution completed successfully")
        logger.info("Job Results:")
        for task_name, job in results.items():
            # Some results may not carry a job id; fall back to the object's
            # string form in that case.
            job_id = getattr(job, "job_id", None)
            if job_id:
                logger.info(f" {task_name}: Job ID {job_id}")
            else:
                logger.info(f" {task_name}: {job}")

    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        sys.exit(1)
def _validate_workflow_dependencies(workflow: "Workflow") -> None:
    """Validate the task dependency graph of *workflow*.

    Two checks are made, in this order:

    1. Every name in a task's ``depends_on`` must be the name of a task in
       ``workflow.tasks``.
    2. The dependency graph must be acyclic.

    Cycle detection is an iterative depth-first search, so very long
    dependency chains cannot hit Python's recursion limit.

    Args:
        workflow: The workflow to check. Only ``workflow.tasks`` (objects with
            ``name`` and ``depends_on``) and ``workflow.get_task(name)`` are
            used.

    Raises:
        ValueError: If a task depends on an unknown task, or if a cycle is
            found. The cycle message names a task that lies on the cycle and
            shows the cycle as ``a -> b -> a``, where an arrow means "depends on".
    """
    task_names = {task.name for task in workflow.tasks}

    for task in workflow.tasks:
        for dependency in task.depends_on:
            if dependency not in task_names:
                raise ValueError(
                    f"Task '{task.name}' depends on unknown task '{dependency}'"
                )

    def dependencies_of(name: str) -> list[str]:
        # A falsy lookup result is treated as "no dependencies", matching the
        # guard the recursive version had around get_task().
        found = workflow.get_task(name)
        return list(found.depends_on) if found else []

    # Tasks whose DFS has completed without finding a cycle.
    finished: set[str] = set()

    for root in workflow.tasks:
        if root.name in finished:
            continue

        # `path` is the current DFS path from `root`; `on_path` mirrors it for
        # O(1) membership tests; `pending[i]` iterates the not-yet-explored
        # dependencies of `path[i]`.
        path = [root.name]
        on_path = {root.name}
        pending = [iter(dependencies_of(root.name))]

        while pending:
            for dependency in pending[-1]:
                if dependency in on_path:
                    # Back edge: the cycle is the path suffix starting at
                    # `dependency`, closed by returning to it.
                    cycle = path[path.index(dependency):] + [dependency]
                    raise ValueError(
                        f"Circular dependency detected involving task "
                        f"'{dependency}': {' -> '.join(cycle)}"
                    )
                if dependency not in finished:
                    path.append(dependency)
                    on_path.add(dependency)
                    pending.append(iter(dependencies_of(dependency)))
                    break
            else:
                # All dependencies of path[-1] explored without a cycle.
                done = path.pop()
                on_path.discard(done)
                finished.add(done)
                pending.pop()
def _show_workflow_plan(workflow: Workflow) -> None:
    """Log the execution plan of *workflow* without running anything.

    A single ``logger.info`` call emits a multi-line message. It starts with a
    header giving the workflow name and task count. Each task then gets its
    own section listing, where applicable:

    - its command, node/GPU resources and conda/sqsh environment (``Job``
      tasks), or its script path (``ShellJob`` tasks);
    - its dependencies;
    - whether it runs asynchronously.
    """
    detail = "\t\t\t"
    lines = [
        "Workflow execution plan:",
        f"    Workflow: {workflow.name}",
        f"    Tasks: {len(workflow.tasks)}",
    ]

    for task in workflow.tasks:
        lines.append(f"\t\tTask: {task.name}")

        job = task.job
        if isinstance(job, Job):
            command_text = " ".join(job.command or [])
            resources = job.resources
            lines.append(f"{detail}Command: {command_text}")
            lines.append(
                f"{detail}Resources: {resources.nodes} nodes, "
                f"{resources.gpus_per_node} GPUs/node"
            )
            env = job.environment
            if env.conda:
                lines.append(f"{detail}Conda env: {env.conda}")
            if env.sqsh:
                lines.append(f"{detail}Sqsh: {env.sqsh}")
        elif isinstance(job, ShellJob):
            lines.append(f"{detail}Path: {job.path}")

        if task.depends_on:
            lines.append(f"{detail}Dependencies: {', '.join(task.depends_on)}")
        if task.async_execution:
            lines.append(f"{detail}Execution: asynchronous")

    # Every line, including the last, ends with a newline.
    logger.info("\n".join(lines) + "\n")
def main() -> None:
    """Console entry point: parse ``sys.argv`` and hand off to ``cmd_run_workflow``."""
    cmd_run_workflow(create_workflow_parser().parse_args())
# Run the CLI when this file is executed directly as a script.
if __name__ == "__main__":
    main()