Coverage for src/srunx/cli/main.py: 65%
249 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-17 20:31 +0900
1"""Main CLI interface for srunx."""
3import argparse
4import os
5import sys
6from pathlib import Path
8from srunx.client import Slurm
9from srunx.logging import (
10 configure_cli_logging,
11 configure_workflow_logging,
12 get_logger,
13)
14from srunx.models import Job, JobEnvironment, JobResource, ShellJob, Workflow
15from srunx.workflows.runner import WorkflowRunner
# Module-level logger shared by the parser helpers and command handlers below.
logger = get_logger(__name__)
def create_job_parser() -> argparse.ArgumentParser:
    """Build the stand-alone argument parser for submitting one SLURM job.

    The command to run is taken as positional arguments. Options are split
    into titled sections (resources, environment, execution, logging) so that
    ``--help`` output is readable. The ``--log-dir`` default is read from the
    ``SLURM_LOG_DIR`` environment variable when the parser is built, falling
    back to ``"logs"``.

    Returns:
        A fully configured :class:`argparse.ArgumentParser`.
    """
    job_parser = argparse.ArgumentParser(
        description="Submit SLURM jobs with various configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Positional: the command line to execute inside the allocation.
    job_parser.add_argument(
        "command", nargs="+", help="Command to execute in the SLURM job"
    )

    # Ungrouped job-level settings, declared as (flags, keyword options) pairs.
    general_options = (
        (
            ("--name", "--job-name"),
            dict(type=str, default="job", help="Job name (default: %(default)s)"),
        ),
        (
            ("--log-dir",),
            dict(
                type=str,
                default=os.getenv("SLURM_LOG_DIR", "logs"),
                help="Log directory (default: %(default)s)",
            ),
        ),
        (
            ("--work-dir", "--chdir"),
            dict(type=str, help="Working directory for the job"),
        ),
    )
    for flags, settings in general_options:
        job_parser.add_argument(*flags, **settings)

    # Titled sections; each becomes one argument group, in declaration order.
    sections = (
        (
            "Resource Options",
            (
                (
                    ("-N", "--nodes"),
                    dict(
                        type=int,
                        default=1,
                        help="Number of nodes (default: %(default)s)",
                    ),
                ),
                (
                    ("--gpus-per-node",),
                    dict(
                        type=int,
                        default=0,
                        help="Number of GPUs per node (default: %(default)s)",
                    ),
                ),
                (
                    ("--ntasks-per-node",),
                    dict(
                        type=int,
                        default=1,
                        help="Number of tasks per node (default: %(default)s)",
                    ),
                ),
                (
                    ("--cpus-per-task",),
                    dict(
                        type=int,
                        default=1,
                        help="Number of CPUs per task (default: %(default)s)",
                    ),
                ),
                (
                    ("--memory", "--mem"),
                    dict(type=str, help="Memory per node (e.g., '32GB', '1TB')"),
                ),
                (
                    ("--time", "--time-limit"),
                    dict(
                        type=str,
                        help="Time limit (e.g., '1:00:00', '30:00', '1-12:00:00')",
                    ),
                ),
            ),
        ),
        (
            "Environment Options",
            (
                (("--conda",), dict(type=str, help="Conda environment name")),
                (("--venv",), dict(type=str, help="Virtual environment path")),
                (("--sqsh",), dict(type=str, help="SquashFS image path")),
                (
                    ("--env",),
                    dict(
                        action="append",
                        dest="env_vars",
                        help="Environment variable KEY=VALUE (can be used multiple times)",
                    ),
                ),
            ),
        ),
        (
            "Execution Options",
            (
                (
                    ("--template",),
                    dict(type=str, help="Path to custom SLURM template file"),
                ),
                (
                    ("--wait",),
                    dict(action="store_true", help="Wait for job completion"),
                ),
                (
                    ("--poll-interval",),
                    dict(
                        type=int,
                        default=30,
                        help="Polling interval in seconds when waiting (default: %(default)s)",
                    ),
                ),
            ),
        ),
        (
            "Logging Options",
            (
                (
                    ("--log-level",),
                    dict(
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        default="INFO",
                        help="Set logging level (default: %(default)s)",
                    ),
                ),
                (
                    ("--quiet", "-q"),
                    dict(action="store_true", help="Only show warnings and errors"),
                ),
            ),
        ),
    )
    for title, options in sections:
        group = job_parser.add_argument_group(title)
        for flags, settings in options:
            group.add_argument(*flags, **settings)

    return job_parser
def create_status_parser() -> argparse.ArgumentParser:
    """Return the parser for ``srunx status``: a single integer job ID."""
    status_parser = argparse.ArgumentParser(
        description="Check SLURM job status",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    status_parser.add_argument("job_id", type=int, help="SLURM job ID to check")
    return status_parser
def create_list_parser() -> argparse.ArgumentParser:
    """Return the parser for ``srunx list`` with an optional ``--user/-u`` filter."""
    listing = argparse.ArgumentParser(
        description="List SLURM jobs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    listing.add_argument(
        "--user",
        "-u",
        type=str,
        help="List jobs for specific user (default: current user)",
    )
    return listing
def create_cancel_parser() -> argparse.ArgumentParser:
    """Return the parser for ``srunx cancel``: a single integer job ID."""
    cancel = argparse.ArgumentParser(
        description="Cancel SLURM job",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cancel.add_argument("job_id", type=int, help="SLURM job ID to cancel")
    return cancel
def create_main_parser() -> argparse.ArgumentParser:
    """Assemble the top-level ``srunx`` parser and register all subcommands.

    Global options are ``--log-level/-l`` and ``--quiet/-q``. The job
    subcommands (``submit``, ``status``, ``list``, ``cancel``) reuse the
    stand-alone parsers through ``_copy_parser_args``. ``flow`` has its own
    nested ``run`` / ``validate`` subcommands. Each leaf subcommand stores its
    handler in ``func``; ``flow`` on its own stores ``None``.

    The selected subcommand name goes to dest ``command`` (and
    ``flow_command`` for flow). Note that ``submit`` copies a positional that
    also uses dest ``command``, so after ``submit`` that attribute holds the
    job's command list rather than the subcommand name.
    """
    root = argparse.ArgumentParser(
        description="srunx - Python library for SLURM job management",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Options accepted before any subcommand.
    root.add_argument(
        "--log-level",
        "-l",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: %(default)s)",
    )
    root.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Only show warnings and errors",
    )

    commands = root.add_subparsers(dest="command", help="Available commands")

    # Subcommands that mirror a stand-alone parser: (name, help, handler, factory).
    mirrored = (
        ("submit", "Submit a SLURM job", cmd_submit, create_job_parser),
        ("status", "Check job status", cmd_status, create_status_parser),
        ("list", "List jobs", cmd_list, create_list_parser),
        ("cancel", "Cancel job", cmd_cancel, create_cancel_parser),
    )
    for name, summary, handler, build_source in mirrored:
        sub = commands.add_parser(name, help=summary)
        sub.set_defaults(func=handler)
        _copy_parser_args(build_source(), sub)

    # Workflow management: "flow" dispatches only through its own subcommands.
    flow = commands.add_parser("flow", help="Workflow management")
    flow.set_defaults(func=None)
    flow_commands = flow.add_subparsers(dest="flow_command", help="Flow commands")

    run = flow_commands.add_parser("run", help="Execute workflow")
    run.set_defaults(func=cmd_flow_run)
    run.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )
    run.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be executed without running jobs",
    )

    validate = flow_commands.add_parser("validate", help="Validate workflow")
    validate.set_defaults(func=cmd_flow_validate)
    validate.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )

    return root
def _copy_parser_args(
    source_parser: argparse.ArgumentParser, target_parser: argparse.ArgumentParser
) -> None:
    """Copy every argument of *source_parser* onto *target_parser*.

    This exposes the stand-alone parsers (``create_job_parser`` etc.) as
    subcommands. The source's ``-h/--help`` action is skipped because the
    target already has its own.

    Arguments that the source declared inside a titled argument group (e.g.
    "Resource Options") are placed into a group with the same title and
    description on the target. This keeps the subcommand's ``--help``
    sectioned exactly like the stand-alone parser. All other arguments go to
    the target's default positional/optional sections. The original argument
    order is preserved.

    Note: this relies on argparse's private ``_actions``, ``_action_groups``,
    ``_group_actions`` and ``_add_action`` members. ``Action`` objects are
    shared rather than copied, so *source_parser* should be a throw-away
    parser that is not used afterwards.

    Args:
        source_parser: Parser whose arguments are copied.
        target_parser: Parser (typically a subparser) receiving them.
    """
    default_sections = (source_parser._positionals, source_parser._optionals)

    # Map each action that lives in a custom (titled) group to that group.
    custom_group_of = {}
    for group in source_parser._action_groups:
        if any(group is section for section in default_sections):
            continue
        for action in group._group_actions:
            custom_group_of[id(action)] = group

    # Target groups are created lazily, in first-use order, so section order
    # in --help matches the source parser.
    target_groups = {}
    for action in source_parser._actions:
        if action.dest == "help":
            continue
        source_group = custom_group_of.get(id(action))
        if source_group is None:
            target_parser._add_action(action)
            continue
        target_group = target_groups.get(id(source_group))
        if target_group is None:
            target_group = target_parser.add_argument_group(
                source_group.title, source_group.description
            )
            target_groups[id(source_group)] = target_group
        target_group._add_action(action)
296def _parse_env_vars(env_var_list: list[str] | None) -> dict[str, str]:
297 """Parse environment variables from list of KEY=VALUE strings."""
298 env_vars = {}
299 if env_var_list:
300 for env_var in env_var_list:
301 if "=" in env_var:
302 key, value = env_var.split("=", 1)
303 env_vars[key] = value
304 else:
305 logger.warning(f"Invalid environment variable format: {env_var}")
306 return env_vars
def cmd_submit(args: argparse.Namespace) -> None:
    """Submit the job described by parsed CLI arguments.

    Builds ``JobResource``, ``JobEnvironment`` and ``Job`` models from *args*
    and submits with ``Slurm.run``, passing the optional ``--template``. When
    ``--wait`` is set, it blocks on ``Slurm.monitor`` and logs the final
    status. Any exception is logged and the process exits with status 1.

    Args:
        args: Namespace produced by the job parser (directly or via the
            ``submit`` subcommand).
    """
    try:
        parsed_env = _parse_env_vars(getattr(args, "env_vars", None))

        resource_spec = JobResource(
            nodes=args.nodes,
            gpus_per_node=args.gpus_per_node,
            ntasks_per_node=args.ntasks_per_node,
            cpus_per_task=args.cpus_per_task,
            memory_per_node=getattr(args, "memory", None),
            time_limit=getattr(args, "time", None),
        )
        environment_spec = JobEnvironment(
            conda=getattr(args, "conda", None),
            venv=getattr(args, "venv", None),
            sqsh=getattr(args, "sqsh", None),
            env_vars=parsed_env,
        )

        fields = dict(
            name=args.name,
            command=args.command,
            resources=resource_spec,
            environment=environment_spec,
            log_dir=args.log_dir,
        )
        # Only pass work_dir when given, so the model's own default applies otherwise.
        if args.work_dir is not None:
            fields["work_dir"] = args.work_dir

        job = Job.model_validate(fields)

        slurm = Slurm()
        submitted = slurm.run(job, getattr(args, "template", None))
        logger.info(f"Submitted job {submitted.job_id}: {submitted.name}")

        if not getattr(args, "wait", False):
            return

        logger.info(f"Waiting for job {submitted.job_id} to complete...")
        finished = slurm.monitor(submitted, poll_interval=args.poll_interval)
        if finished.status:
            final_status = finished.status.value
        else:
            final_status = "Unknown"
        logger.info(f"Job {submitted.job_id} completed with status: {final_status}")

    except Exception as e:
        logger.error(f"Error submitting job: {e}")
        sys.exit(1)
def cmd_status(args: argparse.Namespace) -> None:
    """Log the ID, name and status of the job given by ``args.job_id``.

    The status is logged as ``"Unknown"`` when the retrieved job has none. If
    retrieval fails, the error is logged and the process exits with status 1.
    """
    try:
        job = Slurm().retrieve(args.job_id)

        logger.info(f"Job ID: {job.job_id}")
        logger.info(f"Name: {job.name}")
        status_text = job.status.value if job.status else None
        logger.info("Status: Unknown" if status_text is None else f"Status: {status_text}")

    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        sys.exit(1)
def cmd_list(args: argparse.Namespace) -> None:
    """Log a fixed-width table of jobs, optionally filtered by ``--user``.

    Logs ``"No jobs found"`` when ``Slurm.list`` returns nothing. On any error
    the message is logged and the process exits with status 1.
    """
    try:
        listed = Slurm().list(getattr(args, "user", None))

        if listed:
            logger.info(f"{'Job ID':<10} {'Name':<20} {'Status':<12}")
            logger.info("-" * 45)
            for entry in listed:
                state = entry.status.value if entry.status else "Unknown"
                logger.info(f"{entry.job_id:<10} {entry.name:<20} {state:<12}")
        else:
            logger.info("No jobs found")

    except Exception as e:
        logger.error(f"Error listing jobs: {e}")
        sys.exit(1)
def cmd_cancel(args: argparse.Namespace) -> None:
    """Cancel the job ``args.job_id`` via ``Slurm.cancel`` and log the outcome.

    Any failure is logged and the process exits with status 1.
    """
    try:
        Slurm().cancel(args.job_id)
        logger.info(f"Cancelled job {args.job_id}")
    except Exception as e:
        logger.error(f"Error cancelling job: {e}")
        sys.exit(1)
def cmd_flow_run(args: argparse.Namespace) -> None:
    """Load, validate and (unless ``--dry-run``) execute a YAML workflow.

    Workflow logging is configured first from ``args.log_level``. A missing
    file, or any exception while loading, validating or running, is logged
    and the process exits with status 1. With ``--dry-run`` only the plan is
    logged. Otherwise each task's result is logged after execution: its job
    ID if it has a truthy ``job_id``, or the result object itself otherwise.
    """
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        workflow_path = Path(args.yaml_file)
        if not workflow_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            # SystemExit is not an Exception subclass, so the handler below
            # does not intercept this.
            sys.exit(1)

        workflow_runner = WorkflowRunner()
        workflow = workflow_runner.load_from_yaml(workflow_path)
        logger.info(
            f"Loaded workflow '{workflow.name}' with {len(workflow.tasks)} tasks"
        )

        _validate_workflow_dependencies(workflow)

        if args.dry_run:
            _show_workflow_plan(workflow)
            return

        logger.info("Starting workflow execution")
        outcomes = workflow_runner.execute_workflow(workflow)

        logger.info("Workflow execution completed successfully")
        logger.info("Job Results:")
        for task_name, outcome in outcomes.items():
            if hasattr(outcome, "job_id") and outcome.job_id:
                detail = f"Job ID {outcome.job_id}"
            else:
                detail = f"{outcome}"
            logger.info(f" {task_name}: {detail}")

    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        sys.exit(1)
def cmd_flow_validate(args: argparse.Namespace) -> None:
    """Check a YAML workflow file without submitting anything.

    Workflow logging is configured from ``args.log_level``. The workflow is
    loaded with ``WorkflowRunner.load_from_yaml`` and its dependencies are
    checked. A missing file or any failure is logged and the process exits
    with status 1.
    """
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        source = Path(args.yaml_file)
        if not source.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        workflow = WorkflowRunner().load_from_yaml(source)
        task_count = len(workflow.tasks)
        logger.info(f"Loaded workflow '{workflow.name}' with {task_count} tasks")

        _validate_workflow_dependencies(workflow)
        logger.info("Workflow validation successful")

    except Exception as e:
        logger.error(f"Workflow validation failed: {e}")
        sys.exit(1)
def _validate_workflow_dependencies(workflow: "Workflow") -> None:
    """Check that *workflow*'s task dependency graph is well formed.

    Two checks run in order:

    1. Every name in a task's ``depends_on`` must be the name of some task in
       ``workflow.tasks``.
    2. The dependency graph must be acyclic. A task that depends on itself
       counts as a cycle.

    The cycle search is an iterative depth-first search, so long dependency
    chains cannot hit Python's recursion limit. The error names a task that
    is actually on the cycle and shows the cycle path.

    Args:
        workflow: Workflow to check. Tasks are resolved by name through
            ``workflow.get_task``.

    Raises:
        ValueError: If a dependency names an unknown task, or a circular
            dependency exists.
    """
    task_names = {task.name for task in workflow.tasks}
    for task in workflow.tasks:
        for dependency in task.depends_on:
            if dependency not in task_names:
                raise ValueError(
                    f"Task '{task.name}' depends on unknown task '{dependency}'"
                )

    def dependencies_of(name):
        # Names were validated above; the fallback mirrors the old guard.
        found = workflow.get_task(name)
        return found.depends_on if found else ()

    on_path, done = 1, 2
    state: dict[str, int] = {}
    exhausted = object()  # sentinel for next(); dependency names are never this

    for root in workflow.tasks:
        if root.name in state:
            continue  # already explored from an earlier root (or duplicate name)
        path = [root.name]
        state[root.name] = on_path
        pending = [iter(dependencies_of(root.name))]
        while pending:
            dependency = next(pending[-1], exhausted)
            if dependency is exhausted:
                # All dependencies of the node at the end of the path are done.
                pending.pop()
                state[path.pop()] = done
                continue
            dep_state = state.get(dependency)
            if dep_state == on_path:
                # Back edge: the cycle is the path suffix starting at dependency.
                cycle = path[path.index(dependency):] + [dependency]
                raise ValueError(
                    f"Circular dependency detected involving task '{dependency}': "
                    + " -> ".join(cycle)
                )
            if dep_state is None:
                state[dependency] = on_path
                path.append(dependency)
                pending.append(iter(dependencies_of(dependency)))
def _show_workflow_plan(workflow: Workflow) -> None:
    """Log the dry-run execution plan for *workflow* as a single INFO message.

    The message starts with the workflow name and task count. It then has one
    entry per task, in ``workflow.tasks`` order, showing:

    * for ``Job`` tasks: the command, node and GPU-per-node counts, and any
      conda, sqsh or venv environment that is set;
    * for ``ShellJob`` tasks: the script path;
    * the task's dependencies, if any, and whether it runs asynchronously.

    Args:
        workflow: The loaded workflow to describe. It is not modified.
    """
    # The header lines must be f-strings. Previously a plain triple-quoted
    # string logged the literal text "{workflow.name}" instead of the values.
    lines = [
        "Workflow execution plan:",
        f"\tWorkflow: {workflow.name}",
        f"\tTasks: {len(workflow.tasks)}",
    ]

    for task in workflow.tasks:
        lines.append(f"\t\tTask: {task.name}")
        job = task.job
        if isinstance(job, Job):
            resources = job.resources
            environment = job.environment
            lines.append(f"\t\t\tCommand: {' '.join(job.command or [])}")
            lines.append(
                f"\t\t\tResources: {resources.nodes} nodes, "
                f"{resources.gpus_per_node} GPUs/node"
            )
            if environment.conda:
                lines.append(f"\t\t\tConda env: {environment.conda}")
            if environment.sqsh:
                lines.append(f"\t\t\tSqsh: {environment.sqsh}")
            if environment.venv:
                lines.append(f"\t\t\tVenv: {environment.venv}")
        elif isinstance(job, ShellJob):
            lines.append(f"\t\t\tPath: {job.path}")
        if task.depends_on:
            lines.append(f"\t\t\tDependencies: {', '.join(task.depends_on)}")
        if task.async_execution:
            lines.append("\t\t\tExecution: asynchronous")

    # Every line, including the last, is newline-terminated, as before.
    logger.info("\n".join(lines) + "\n")
def main() -> None:
    """Main entry point for the ``srunx`` CLI.

    Parses ``sys.argv`` with the top-level parser. Logging is configured from
    the global ``--log-level`` / ``--quiet`` options, then control passes to
    the handler that the chosen subcommand stored in ``args.func``.

    When no handler was selected:

    * ``srunx flow`` without ``run``/``validate`` logs an error, prints the
      top-level help and exits with status 1;
    * otherwise the arguments are re-parsed with the stand-alone job parser,
      kept for backward compatibility with the old submit-by-default
      behaviour. If that parse fails, the top-level help is printed and the
      process exits with status 1.
    """
    parser = create_main_parser()
    args = parser.parse_args()

    log_level = getattr(args, "log_level", "INFO")
    quiet = getattr(args, "quiet", False)
    configure_cli_logging(level=log_level, quiet=quiet)

    handler = getattr(args, "func", None)
    if handler is not None:
        handler(args)
        return

    if getattr(args, "command", None) == "flow":
        if getattr(args, "flow_command", None) is None:
            logger.error("Flow command requires a subcommand (run or validate)")
            parser.print_help()
            sys.exit(1)
        return

    # Legacy fallback: treat the arguments as a direct job submission.
    try:
        submit_args = create_job_parser().parse_args()
    except SystemExit:
        parser.print_help()
        sys.exit(1)

    # Run outside the try: cmd_submit reports its own failures with
    # sys.exit(1). Those must not be mistaken for argument-parsing errors,
    # which previously caused the top-level help to be printed.
    cmd_submit(submit_args)
# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()