Coverage for src/srunx/cli/main.py: 65%

249 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 20:31 +0900

1"""Main CLI interface for srunx.""" 

2 

3import argparse 

4import os 

5import sys 

6from pathlib import Path 

7 

8from srunx.client import Slurm 

9from srunx.logging import ( 

10 configure_cli_logging, 

11 configure_workflow_logging, 

12 get_logger, 

13) 

14from srunx.models import Job, JobEnvironment, JobResource, ShellJob, Workflow 

15from srunx.workflows.runner import WorkflowRunner 

16 

# Module-level logger for this file, obtained from srunx.logging.get_logger.
logger = get_logger(__name__)

18 

19 

def create_job_parser() -> argparse.ArgumentParser:
    """Build the argument parser that describes a single SLURM job submission.

    The parser takes the command to run as its only positional argument,
    three general job options registered directly on the parser, and four
    titled option groups (resources, environment, execution, logging).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    job_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Submit SLURM jobs with various configurations",
    )

    # Required arguments
    job_parser.add_argument(
        "command", help="Command to execute in the SLURM job", nargs="+"
    )

    # Job configuration: (flags, keyword arguments) registered on the parser itself.
    general_options = [
        (
            ("--name", "--job-name"),
            {"default": "job", "type": str, "help": "Job name (default: %(default)s)"},
        ),
        (
            ("--log-dir",),
            {
                # SLURM_LOG_DIR, when set, replaces the built-in "logs" default.
                "default": os.getenv("SLURM_LOG_DIR", "logs"),
                "type": str,
                "help": "Log directory (default: %(default)s)",
            },
        ),
        (
            ("--work-dir", "--chdir"),
            {"type": str, "help": "Working directory for the job"},
        ),
    ]

    # Titled groups, listed in the order they are created and filled.
    grouped_options = [
        (
            "Resource Options",
            [
                (
                    ("-N", "--nodes"),
                    {
                        "default": 1,
                        "type": int,
                        "help": "Number of nodes (default: %(default)s)",
                    },
                ),
                (
                    ("--gpus-per-node",),
                    {
                        "default": 0,
                        "type": int,
                        "help": "Number of GPUs per node (default: %(default)s)",
                    },
                ),
                (
                    ("--ntasks-per-node",),
                    {
                        "default": 1,
                        "type": int,
                        "help": "Number of tasks per node (default: %(default)s)",
                    },
                ),
                (
                    ("--cpus-per-task",),
                    {
                        "default": 1,
                        "type": int,
                        "help": "Number of CPUs per task (default: %(default)s)",
                    },
                ),
                (
                    ("--memory", "--mem"),
                    {"type": str, "help": "Memory per node (e.g., '32GB', '1TB')"},
                ),
                (
                    ("--time", "--time-limit"),
                    {
                        "type": str,
                        "help": "Time limit (e.g., '1:00:00', '30:00', '1-12:00:00')",
                    },
                ),
            ],
        ),
        (
            "Environment Options",
            [
                (("--conda",), {"type": str, "help": "Conda environment name"}),
                (("--venv",), {"type": str, "help": "Virtual environment path"}),
                (("--sqsh",), {"type": str, "help": "SquashFS image path"}),
                (
                    ("--env",),
                    {
                        "dest": "env_vars",
                        "action": "append",
                        "help": "Environment variable KEY=VALUE (can be used multiple times)",
                    },
                ),
            ],
        ),
        (
            "Execution Options",
            [
                (
                    ("--template",),
                    {"type": str, "help": "Path to custom SLURM template file"},
                ),
                (
                    ("--wait",),
                    {"action": "store_true", "help": "Wait for job completion"},
                ),
                (
                    ("--poll-interval",),
                    {
                        "default": 30,
                        "type": int,
                        "help": "Polling interval in seconds when waiting (default: %(default)s)",
                    },
                ),
            ],
        ),
        (
            "Logging Options",
            [
                (
                    ("--log-level",),
                    {
                        "default": "INFO",
                        "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
                        "help": "Set logging level (default: %(default)s)",
                    },
                ),
                (
                    ("--quiet", "-q"),
                    {"action": "store_true", "help": "Only show warnings and errors"},
                ),
            ],
        ),
    ]

    for flags, settings in general_options:
        job_parser.add_argument(*flags, **settings)

    for title, option_specs in grouped_options:
        group = job_parser.add_argument_group(title)
        for flags, settings in option_specs:
            group.add_argument(*flags, **settings)

    return job_parser

154 

155 

def create_status_parser() -> argparse.ArgumentParser:
    """Build the parser for a status query: a single integer ``job_id`` positional.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    status_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Check SLURM job status",
    )
    status_parser.add_argument("job_id", help="SLURM job ID to check", type=int)
    return status_parser

170 

171 

def create_list_parser() -> argparse.ArgumentParser:
    """Build the parser for the job listing: an optional ``--user`` / ``-u`` filter.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    list_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="List SLURM jobs",
    )
    list_parser.add_argument(
        "--user",
        "-u",
        help="List jobs for specific user (default: current user)",
        type=str,
    )
    return list_parser

187 

188 

def create_cancel_parser() -> argparse.ArgumentParser:
    """Build the parser for a cancellation: a single integer ``job_id`` positional.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cancel_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Cancel SLURM job",
    )
    cancel_parser.add_argument("job_id", help="SLURM job ID to cancel", type=int)
    return cancel_parser

203 

204 

def create_main_parser() -> argparse.ArgumentParser:
    """Build the top-level ``srunx`` parser together with all of its subcommands.

    Registers the global logging options, the ``submit``, ``status``, ``list``
    and ``cancel`` subcommands (whose arguments are copied from the matching
    standalone parsers), and the ``flow`` subcommand with its nested ``run``
    and ``validate`` commands. Each leaf command stores its handler in the
    ``func`` default.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="srunx - Python library for SLURM job management",
    )

    # Global options
    root.add_argument(
        "--log-level",
        "-l",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Set logging level (default: %(default)s)",
    )
    root.add_argument(
        "--quiet", "-q", action="store_true", help="Only show warnings and errors"
    )

    subparsers = root.add_subparsers(dest="command", help="Available commands")

    # Subcommands whose arguments come from a standalone parser:
    # (name, help text, handler, factory of the parser to copy arguments from).
    simple_commands = (
        ("submit", "Submit a SLURM job", cmd_submit, create_job_parser),
        ("status", "Check job status", cmd_status, create_status_parser),
        ("list", "List jobs", cmd_list, create_list_parser),
        ("cancel", "Cancel job", cmd_cancel, create_cancel_parser),
    )
    for command_name, summary, handler, build_template in simple_commands:
        command_parser = subparsers.add_parser(command_name, help=summary)
        command_parser.set_defaults(func=handler)
        _copy_parser_args(build_template(), command_parser)

    # Flow command; its handler is supplied by the nested subcommands below.
    flow = subparsers.add_parser("flow", help="Workflow management")
    flow.set_defaults(func=None)
    flow_commands = flow.add_subparsers(dest="flow_command", help="Flow commands")

    # flow run
    run_cmd = flow_commands.add_parser("run", help="Execute workflow")
    run_cmd.set_defaults(func=cmd_flow_run)
    run_cmd.add_argument(
        "yaml_file", help="Path to YAML workflow definition file", type=str
    )
    run_cmd.add_argument(
        "--dry-run",
        help="Show what would be executed without running jobs",
        action="store_true",
    )

    # flow validate
    validate_cmd = flow_commands.add_parser("validate", help="Validate workflow")
    validate_cmd.set_defaults(func=cmd_flow_validate)
    validate_cmd.add_argument(
        "yaml_file", help="Path to YAML workflow definition file", type=str
    )

    return root

284 

285 

def _copy_parser_args(
    source_parser: argparse.ArgumentParser, target_parser: argparse.ArgumentParser
) -> None:
    """Register every action of ``source_parser`` on ``target_parser``.

    The action whose ``dest`` is ``"help"`` is left out, since the target
    parser is expected to carry its own. Actions are shared, not cloned: the
    same action objects end up in both parsers.
    """
    transferable = (act for act in source_parser._actions if act.dest != "help")
    for act in transferable:
        target_parser._add_action(act)

294 

295 

296def _parse_env_vars(env_var_list: list[str] | None) -> dict[str, str]: 

297 """Parse environment variables from list of KEY=VALUE strings.""" 

298 env_vars = {} 

299 if env_var_list: 

300 for env_var in env_var_list: 

301 if "=" in env_var: 

302 key, value = env_var.split("=", 1) 

303 env_vars[key] = value 

304 else: 

305 logger.warning(f"Invalid environment variable format: {env_var}") 

306 return env_vars 

307 

308 

def cmd_submit(args: argparse.Namespace) -> None:
    """Submit one job built from the parsed command-line arguments.

    Assembles the resource and environment settings, validates them into a
    ``Job``, submits it through ``Slurm().run`` and, when ``--wait`` was
    given, monitors the job until ``Slurm.monitor`` returns. Any exception is
    logged and turned into exit status 1.

    Args:
        args: Namespace produced by the job submission parser.
    """
    try:
        # Parse environment variables
        extra_env = _parse_env_vars(getattr(args, "env_vars", None))

        # Create job configuration
        resource_spec = JobResource(
            nodes=args.nodes,
            gpus_per_node=args.gpus_per_node,
            ntasks_per_node=args.ntasks_per_node,
            cpus_per_task=args.cpus_per_task,
            memory_per_node=getattr(args, "memory", None),
            time_limit=getattr(args, "time", None),
        )
        environment_spec = JobEnvironment(
            conda=getattr(args, "conda", None),
            venv=getattr(args, "venv", None),
            sqsh=getattr(args, "sqsh", None),
            env_vars=extra_env,
        )

        job_fields = dict(
            name=args.name,
            command=args.command,
            resources=resource_spec,
            environment=environment_spec,
            log_dir=args.log_dir,
        )
        # The working directory is only passed on when one was given.
        if args.work_dir is not None:
            job_fields["work_dir"] = args.work_dir

        job = Job.model_validate(job_fields)

        # Submit job
        slurm = Slurm()
        submitted = slurm.run(job, getattr(args, "template", None))
        logger.info(f"Submitted job {submitted.job_id}: {submitted.name}")

        # Nothing more to do unless the caller asked to wait for completion.
        if not getattr(args, "wait", False):
            return

        logger.info(f"Waiting for job {submitted.job_id} to complete...")
        finished = slurm.monitor(submitted, poll_interval=args.poll_interval)
        if finished.status:
            final_state = finished.status.value
        else:
            final_state = "Unknown"
        logger.info(f"Job {submitted.job_id} completed with status: {final_state}")

    except Exception as exc:
        logger.error(f"Error submitting job: {exc}")
        sys.exit(1)

367 

368 

def cmd_status(args: argparse.Namespace) -> None:
    """Look up one job by ID and log its ID, name and status.

    Any exception is logged and turned into exit status 1.

    Args:
        args: Namespace carrying the ``job_id`` to retrieve.
    """
    try:
        job = Slurm().retrieve(args.job_id)

        logger.info(f"Job ID: {job.job_id}")
        logger.info(f"Name: {job.name}")
        # A job without a status is reported as "Unknown".
        state = job.status.value if job.status else "Unknown"
        logger.info(f"Status: {state}")

    except Exception as exc:
        logger.error(f"Error getting job status: {exc}")
        sys.exit(1)

385 

386 

def cmd_list(args: argparse.Namespace) -> None:
    """Log a table of the jobs returned by ``Slurm().list``.

    The optional ``user`` attribute of ``args`` is passed through to the
    client. Any exception is logged and turned into exit status 1.

    Args:
        args: Namespace that may carry a ``user`` filter.
    """
    try:
        listed_jobs = Slurm().list(getattr(args, "user", None))

        if listed_jobs:
            # Column header followed by one left-aligned row per job.
            logger.info(f"{'Job ID':<10} {'Name':<20} {'Status':<12}")
            logger.info("-" * 45)
            for entry in listed_jobs:
                state = entry.status.value if entry.status else "Unknown"
                logger.info(f"{entry.job_id:<10} {entry.name:<20} {state:<12}")
        else:
            logger.info("No jobs found")

    except Exception as exc:
        logger.error(f"Error listing jobs: {exc}")
        sys.exit(1)

406 

407 

def cmd_cancel(args: argparse.Namespace) -> None:
    """Cancel the job identified by ``args.job_id`` and log the outcome.

    Any exception is logged and turned into exit status 1.

    Args:
        args: Namespace carrying the ``job_id`` to cancel.
    """
    try:
        Slurm().cancel(args.job_id)
        logger.info(f"Cancelled job {args.job_id}")
    except Exception as exc:
        logger.error(f"Error cancelling job: {exc}")
        sys.exit(1)

418 

419 

def cmd_flow_run(args: argparse.Namespace) -> None:
    """Load, validate and execute (or only preview) a YAML workflow.

    With ``--dry-run`` the execution plan is logged and nothing is run. A
    missing workflow file exits with status 1; any other exception is logged
    and also turned into exit status 1.

    Args:
        args: Namespace carrying ``yaml_file`` and ``dry_run``.
    """
    # Configure logging for workflow execution
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        workflow_path = Path(args.yaml_file)
        if not workflow_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        runner = WorkflowRunner()

        # Load workflow for validation
        workflow = runner.load_from_yaml(workflow_path)
        task_count = len(workflow.tasks)
        logger.info(f"Loaded workflow '{workflow.name}' with {task_count} tasks")

        # Validate dependencies
        _validate_workflow_dependencies(workflow)

        if args.dry_run:
            _show_workflow_plan(workflow)
            return

        # Execute workflow
        logger.info("Starting workflow execution")
        outcome = runner.execute_workflow(workflow)

        logger.info("Workflow execution completed successfully")
        logger.info("Job Results:")
        for task_name, job in outcome.items():
            # Results without a truthy job_id are logged as-is.
            if getattr(job, "job_id", None):
                logger.info(f" {task_name}: Job ID {job.job_id}")
            else:
                logger.info(f" {task_name}: {job}")

    except Exception as exc:
        logger.error(f"Workflow execution failed: {exc}")
        sys.exit(1)

461 

462 

def cmd_flow_validate(args: argparse.Namespace) -> None:
    """Load a YAML workflow and check its task dependencies without running it.

    A missing workflow file exits with status 1; any other exception is
    logged and also turned into exit status 1.

    Args:
        args: Namespace carrying ``yaml_file``.
    """
    # Configure logging for workflow validation
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        workflow_path = Path(args.yaml_file)
        if not workflow_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        # Load workflow for validation
        workflow = WorkflowRunner().load_from_yaml(workflow_path)
        task_count = len(workflow.tasks)
        logger.info(f"Loaded workflow '{workflow.name}' with {task_count} tasks")

        # Validate dependencies
        _validate_workflow_dependencies(workflow)
        logger.info("Workflow validation successful")

    except Exception as exc:
        logger.error(f"Workflow validation failed: {exc}")
        sys.exit(1)

490 

491 

def _validate_workflow_dependencies(workflow: Workflow) -> None:
    """Check that every dependency names a known task and that none form a cycle.

    Args:
        workflow: Workflow whose ``tasks`` (each with ``name`` and
            ``depends_on``) are inspected.

    Raises:
        ValueError: If a task depends on a name that is not a task of the
            workflow, or if following ``depends_on`` links from a task leads
            back onto the path currently being explored.
    """
    known_names = {t.name for t in workflow.tasks}

    # First pass: every referenced dependency must exist.
    for task in workflow.tasks:
        for dep_name in task.depends_on:
            if dep_name in known_names:
                continue
            raise ValueError(
                f"Task '{task.name}' depends on unknown task '{dep_name}'"
            )

    # Second pass: depth-first walk over the dependency links.
    seen: set[str] = set()  # every task the walk has entered so far
    on_path: set[str] = set()  # tasks on the chain currently being explored

    def has_cycle(current: str) -> bool:
        # Reaching a task that is still on the active chain closes a loop.
        if current in on_path:
            return True
        if current in seen:
            return False

        seen.add(current)
        on_path.add(current)

        node = workflow.get_task(current)
        if node and any(has_cycle(dep_name) for dep_name in node.depends_on):
            return True

        on_path.remove(current)
        return False

    offender = next((t for t in workflow.tasks if has_cycle(t.name)), None)
    if offender is not None:
        raise ValueError(
            f"Circular dependency detected involving task '{offender.name}'"
        )

530 

531 

def _show_workflow_plan(workflow: Workflow) -> None:
    """Log a human-readable execution plan for ``workflow``.

    The plan starts with the workflow name and task count, followed by one
    entry per task: its command, resources and environment for ``Job`` tasks,
    its script path for ``ShellJob`` tasks, then its dependencies and whether
    it runs asynchronously. The whole plan is emitted as a single log record.

    Args:
        workflow: Workflow whose tasks are described.
    """
    # The header must be interpolated: as a plain string literal the
    # "{workflow.name}" placeholders would be logged verbatim.
    lines = [
        "Workflow execution plan:",
        f" Workflow: {workflow.name}",
        f" Tasks: {len(workflow.tasks)}",
    ]

    for task in workflow.tasks:
        lines.append(f"\t\tTask: {task.name}")
        job = task.job
        if isinstance(job, Job):
            lines.append(f"\t\t\tCommand: {' '.join(job.command or [])}")
            lines.append(
                f"\t\t\tResources: {job.resources.nodes} nodes, "
                f"{job.resources.gpus_per_node} GPUs/node"
            )
            if job.environment.conda:
                lines.append(f"\t\t\tConda env: {job.environment.conda}")
            if job.environment.sqsh:
                lines.append(f"\t\t\tSqsh: {job.environment.sqsh}")
            if job.environment.venv:
                lines.append(f"\t\t\tVenv: {job.environment.venv}")
        elif isinstance(job, ShellJob):
            lines.append(f"\t\t\tPath: {job.path}")
        if task.depends_on:
            lines.append(f"\t\t\tDependencies: {', '.join(task.depends_on)}")
        if task.async_execution:
            lines.append("\t\t\tExecution: asynchronous")

    # Every line, including the last, is newline-terminated.
    logger.info("\n".join(lines) + "\n")

559 

560 

def main(argv: list[str] | None = None) -> None:
    """Main entry point for the CLI.

    Parses the command line, configures logging and dispatches to the handler
    stored in ``func`` by the selected subcommand. When no handler was
    selected, ``flow`` without a subcommand is reported as an error;
    otherwise the arguments are re-parsed with the plain job parser and
    treated as a submission.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the process arguments.
    """
    parser = create_main_parser()
    args = parser.parse_args(argv)

    # Configure logging
    configure_cli_logging(
        level=getattr(args, "log_level", "INFO"),
        quiet=getattr(args, "quiet", False),
    )

    handler = getattr(args, "func", None)
    if handler is not None:
        handler(args)
        return

    # Check if this is a flow command without subcommand
    if getattr(args, "command", None) == "flow":
        if getattr(args, "flow_command", None) is None:
            logger.error("Flow command requires a subcommand (run or validate)")
            parser.print_help()
            sys.exit(1)
        return

    # No command specified: try to parse as submit command for backward
    # compatibility. Only a parse failure shows the help; cmd_submit runs
    # outside the try so its own exit on a failed submission is not
    # misreported as a usage error.
    try:
        submit_args = create_job_parser().parse_args(argv)
    except SystemExit:
        parser.print_help()
        sys.exit(1)
    cmd_submit(submit_args)

590 

591 

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()