Coverage for little_loops / fsm / schema.py: 36%

248 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-03-18 16:20 -0500

1"""FSM loop schema dataclasses. 

2 

3This module defines the type-safe dataclasses that represent FSM loop 

4definitions. These match the universal FSM schema described in the 

5design documentation. 

6 

7The schema supports: 

8- Multiple evaluator types (exit_code, output_numeric, etc.) 

9- Two-layer transition system (evaluate + route) 

10- Both shorthand (on_success/on_failure) and full routing syntax 

11- Context variables and captured values 

12- LLM evaluation configuration 

13""" 

14 

15from __future__ import annotations 

16 

17from dataclasses import dataclass, field 

18from typing import Any, Literal 

19 

# Default LLM model for structured evaluation; used as the default for
# LLMConfig.model and omitted from serialized output when unchanged.
DEFAULT_LLM_MODEL: str = "sonnet"

22 

23 

@dataclass
class EvaluateConfig:
    """Evaluator configuration for action result interpretation.

    The evaluator determines how to interpret an action's output and
    produce a verdict string for routing.

    Attributes:
        type: Evaluator type. One of:
            - exit_code: Map exit codes to verdicts (default for shell)
            - output_numeric: Compare numeric output to target
            - output_json: Extract and compare JSON path value
            - output_contains: Pattern matching on stdout
            - convergence: Compare current vs previous value toward target
            - diff_stall: Fail after ``max_stall`` consecutive iterations
              without a git diff change (optionally limited to ``scope``)
            - llm_structured: Use LLM with structured output (default for slash)
            - mcp_result: Accepted evaluator type whose semantics are
              implemented outside this module (presumably interprets the
              result of an ``mcp_tool`` action -- confirm against the
              evaluator implementation)
        operator: Comparison operator (eq, ne, lt, le, gt, ge)
        target: Target value for comparison
        tolerance: Acceptable distance from target (for convergence)
        pattern: Pattern string for output_contains
        negate: If True, invert the match result (output_contains)
        path: JSON path for output_json (jq-style)
        prompt: Custom prompt for llm_structured
        schema: Custom JSON schema for llm_structured response
        min_confidence: Minimum confidence threshold for llm_structured
        uncertain_suffix: If True, append _uncertain to low-confidence verdicts
        source: Override default source (current action output)
        previous: Previous value reference for convergence
        direction: Optimization direction for convergence (minimize/maximize)
        scope: Paths to limit git diff to for diff_stall evaluator
        max_stall: Consecutive no-change iterations before failure (diff_stall)
    """

    type: Literal[
        "exit_code",
        "output_numeric",
        "output_json",
        "output_contains",
        "convergence",
        "diff_stall",
        "llm_structured",
        "mcp_result",
    ]
    operator: str | None = None
    target: int | float | str | None = None
    tolerance: float | str | None = None  # str for interpolation (e.g., "${context.tolerance}")
    pattern: str | None = None
    negate: bool = False
    path: str | None = None
    prompt: str | None = None
    schema: dict[str, Any] | None = None
    min_confidence: float = 0.5
    uncertain_suffix: bool = False
    source: str | None = None
    previous: str | None = None
    direction: Literal["minimize", "maximize"] = "minimize"
    scope: list[str] | None = None  # for diff_stall: limit git diff to these paths
    max_stall: int = 1  # for diff_stall: consecutive no-change iterations before failure

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON/YAML serialization.

        Returns:
            A dict that always contains ``type``. Every other field is
            included only when it differs from its default, so the output
            stays minimal and round-trips through ``from_dict``.
        """
        result: dict[str, Any] = {"type": self.type}

        # Only include optional fields that were explicitly set. Checks use
        # ``is not None`` so falsy-but-meaningful values (target=0,
        # tolerance=0.0, pattern="") are preserved.
        if self.operator is not None:
            result["operator"] = self.operator
        if self.target is not None:
            result["target"] = self.target
        if self.tolerance is not None:
            result["tolerance"] = self.tolerance
        if self.pattern is not None:
            result["pattern"] = self.pattern
        if self.negate:
            result["negate"] = self.negate
        if self.path is not None:
            result["path"] = self.path
        if self.prompt is not None:
            result["prompt"] = self.prompt
        if self.schema is not None:
            result["schema"] = self.schema
        if self.min_confidence != 0.5:
            result["min_confidence"] = self.min_confidence
        if self.uncertain_suffix:
            result["uncertain_suffix"] = self.uncertain_suffix
        if self.source is not None:
            result["source"] = self.source
        if self.previous is not None:
            result["previous"] = self.previous
        if self.direction != "minimize":
            result["direction"] = self.direction
        if self.scope is not None:
            result["scope"] = self.scope
        if self.max_stall != 1:
            result["max_stall"] = self.max_stall

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> EvaluateConfig:
        """Create from dictionary (JSON/YAML deserialization).

        Args:
            data: Mapping with a required ``type`` key; all other keys are
                optional and fall back to the dataclass defaults.

        Returns:
            The populated EvaluateConfig.

        Raises:
            KeyError: If ``data`` has no ``type`` key.
        """
        return cls(
            type=data["type"],
            operator=data.get("operator"),
            target=data.get("target"),
            tolerance=data.get("tolerance"),
            pattern=data.get("pattern"),
            negate=data.get("negate", False),
            path=data.get("path"),
            prompt=data.get("prompt"),
            schema=data.get("schema"),
            min_confidence=data.get("min_confidence", 0.5),
            uncertain_suffix=data.get("uncertain_suffix", False),
            source=data.get("source"),
            previous=data.get("previous"),
            direction=data.get("direction", "minimize"),
            scope=data.get("scope"),
            max_stall=data.get("max_stall", 1),
        )

141 

142 

@dataclass
class RouteConfig:
    """Routing table configuration for verdict-to-state mapping.

    Maps verdict strings from evaluators to next state names.

    Attributes:
        routes: Mapping from verdict string to next state name
        default: Default state for unmatched verdicts (the "_" key)
        error: State for evaluation/execution errors (the "_error" key)
    """

    routes: dict[str, str] = field(default_factory=dict)
    default: str | None = None
    error: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON/YAML serialization.

        Returns:
            A copy of ``routes`` with ``default`` stored under "_" and
            ``error`` stored under "_error" when they are set.
        """
        # Copy so the reserved keys never leak into self.routes.
        result: dict[str, Any] = dict(self.routes)
        if self.default is not None:
            result["_"] = self.default
        if self.error is not None:
            result["_error"] = self.error
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> RouteConfig:
        """Create from dictionary (JSON/YAML deserialization).

        Args:
            data: Flat routing table. "_" supplies ``default`` and "_error"
                supplies ``error``.

        Returns:
            The populated RouteConfig.

        Note:
            Every key starting with an underscore is excluded from
            ``routes``, not only "_" and "_error"; other underscore-prefixed
            keys are dropped silently.
        """
        verdict_routes = {
            verdict: target
            for verdict, target in data.items()
            if not verdict.startswith("_")
        }
        return cls(
            routes=verdict_routes,
            default=data.get("_"),
            error=data.get("_error"),
        )

177 

178 

@dataclass
class StateConfig:
    """Configuration for a single FSM state.

    States can have actions, evaluators, and routing. Supports both
    shorthand (on_success/on_failure) and full routing table syntax.

    Attributes:
        action: Command to execute (shell, slash command, or "server/tool-name" for mcp_tool)
        action_type: How to execute the action (prompt, slash_command, shell, mcp_tool).
            If None, uses heuristic: / prefix = slash_command, else = shell.
        params: MCP tool arguments (only used with action_type: mcp_tool). Supports
            ${variable} interpolation in string values.
        evaluate: Evaluator configuration for result interpretation
        route: Full routing table (verdict -> state mapping)
        on_yes: Shorthand for yes verdict routing
        on_no: Shorthand for no verdict routing
        on_error: Shorthand for error verdict routing
        on_partial: Shorthand for partial verdict routing
        on_blocked: Shorthand for blocked verdict routing (naming parallels
            the other on_* shorthands)
        next: Unconditional transition (no evaluation)
        terminal: If True, this is an end state
        capture: Variable name to store action output
        timeout: Action-level timeout in seconds
        on_maintain: State to transition to when maintain=True and loop completes
        max_retries: Max consecutive re-entries before transitioning to on_retry_exhausted.
            A value of N allows N retries after the initial execution (N+1 total entries).
            Requires on_retry_exhausted to also be set.
        on_retry_exhausted: State to transition to when max_retries consecutive re-entries
            are exceeded. Required when max_retries is set.
        loop: Name of a loop YAML to execute as a sub-FSM. Mutually exclusive with action.
        context_passthrough: When True, pass parent context variables to child loop and
            merge child captures back into parent context.
    """

    action: str | None = None
    action_type: Literal["prompt", "slash_command", "shell", "mcp_tool"] | None = None
    params: dict[str, Any] = field(default_factory=dict)
    evaluate: EvaluateConfig | None = None
    route: RouteConfig | None = None
    on_yes: str | None = None
    on_no: str | None = None
    on_error: str | None = None
    on_partial: str | None = None
    on_blocked: str | None = None
    next: str | None = None
    terminal: bool = False
    capture: str | None = None
    timeout: int | None = None
    on_maintain: str | None = None
    max_retries: int | None = None
    on_retry_exhausted: str | None = None
    loop: str | None = None
    context_passthrough: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON/YAML serialization.

        Returns:
            A dict holding only the fields that differ from their defaults.
            Nested ``evaluate`` and ``route`` are serialized via their own
            ``to_dict`` methods.
        """
        result: dict[str, Any] = {}

        if self.action is not None:
            result["action"] = self.action
        if self.action_type is not None:
            result["action_type"] = self.action_type
        if self.params:
            result["params"] = self.params
        if self.evaluate is not None:
            result["evaluate"] = self.evaluate.to_dict()
        if self.route is not None:
            result["route"] = self.route.to_dict()
        if self.on_yes is not None:
            result["on_yes"] = self.on_yes
        if self.on_no is not None:
            result["on_no"] = self.on_no
        if self.on_error is not None:
            result["on_error"] = self.on_error
        if self.on_partial is not None:
            result["on_partial"] = self.on_partial
        if self.on_blocked is not None:
            result["on_blocked"] = self.on_blocked
        if self.next is not None:
            result["next"] = self.next
        if self.terminal:
            result["terminal"] = self.terminal
        if self.capture is not None:
            result["capture"] = self.capture
        if self.timeout is not None:
            result["timeout"] = self.timeout
        if self.on_maintain is not None:
            result["on_maintain"] = self.on_maintain
        if self.max_retries is not None:
            result["max_retries"] = self.max_retries
        if self.on_retry_exhausted is not None:
            result["on_retry_exhausted"] = self.on_retry_exhausted
        if self.loop is not None:
            result["loop"] = self.loop
        if self.context_passthrough:
            result["context_passthrough"] = self.context_passthrough

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> StateConfig:
        """Create from dictionary (JSON/YAML deserialization).

        ``on_success`` and ``on_failure`` are accepted as aliases for
        ``on_yes`` and ``on_no``; the ``on_yes``/``on_no`` keys win when both
        are present and non-empty.

        Args:
            data: Mapping of state settings. All keys are optional.

        Returns:
            The populated StateConfig.
        """
        # An explicit null (e.g. a bare ``evaluate:`` line in YAML) is treated
        # the same as an absent key instead of crashing in the nested
        # from_dict call.
        evaluate_data = data.get("evaluate")
        evaluate = (
            EvaluateConfig.from_dict(evaluate_data) if evaluate_data is not None else None
        )

        route_data = data.get("route")
        route = RouteConfig.from_dict(route_data) if route_data is not None else None

        return cls(
            action=data.get("action"),
            action_type=data.get("action_type"),
            # Normalize a null ``params:`` to an empty dict so the field
            # always matches its declared type.
            params=data.get("params") or {},
            evaluate=evaluate,
            route=route,
            on_yes=data.get("on_yes") or data.get("on_success"),
            on_no=data.get("on_no") or data.get("on_failure"),
            on_error=data.get("on_error"),
            on_partial=data.get("on_partial"),
            on_blocked=data.get("on_blocked"),
            next=data.get("next"),
            terminal=data.get("terminal", False),
            capture=data.get("capture"),
            timeout=data.get("timeout"),
            on_maintain=data.get("on_maintain"),
            max_retries=data.get("max_retries"),
            on_retry_exhausted=data.get("on_retry_exhausted"),
            loop=data.get("loop"),
            context_passthrough=data.get("context_passthrough", False),
        )

    def get_referenced_states(self) -> set[str]:
        """Get all state names referenced by this state configuration.

        Returns:
            Set of state names that this state can transition to, gathered
            from the shorthand fields and, when present, the routing table
            (its routes, default and error targets).
        """
        shorthand_targets = (
            self.on_yes,
            self.on_no,
            self.on_error,
            self.on_partial,
            self.on_blocked,
            self.next,
            self.on_maintain,
            self.on_retry_exhausted,
        )
        refs: set[str] = {target for target in shorthand_targets if target is not None}

        if self.route is not None:
            refs.update(self.route.routes.values())
            if self.route.default is not None:
                refs.add(self.route.default)
            if self.route.error is not None:
                refs.add(self.route.error)

        return refs

343 

344 

@dataclass
class LLMConfig:
    """LLM evaluation configuration.

    Settings for the llm_structured evaluator.

    Attributes:
        enabled: If False, disable LLM evaluation entirely
        model: Model identifier for LLM calls
        max_tokens: Maximum tokens for evaluation response
        timeout: Timeout for LLM calls in seconds
    """

    enabled: bool = True
    model: str = DEFAULT_LLM_MODEL
    max_tokens: int = 256
    timeout: int = 1800

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON/YAML serialization.

        Returns:
            A dict holding only the settings that differ from their
            defaults; empty when every setting is at its default.
        """
        result: dict[str, Any] = {}

        if not self.enabled:
            result["enabled"] = self.enabled
        if self.model != DEFAULT_LLM_MODEL:
            result["model"] = self.model
        if self.max_tokens != 256:
            result["max_tokens"] = self.max_tokens
        if self.timeout != 1800:
            result["timeout"] = self.timeout

        # ``result`` is already an empty dict when nothing was overridden,
        # so no extra fallback is needed.
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> LLMConfig:
        """Create from dictionary (JSON/YAML deserialization).

        Args:
            data: Mapping of LLM settings; missing keys fall back to the
                dataclass defaults.

        Returns:
            The populated LLMConfig.
        """
        return cls(
            enabled=data.get("enabled", True),
            model=data.get("model", DEFAULT_LLM_MODEL),
            max_tokens=data.get("max_tokens", 256),
            timeout=data.get("timeout", 1800),
        )

387 

388 

@dataclass
class FSMLoop:
    """Complete FSM loop definition.

    The main dataclass representing a loop configuration.

    Attributes:
        name: Unique loop identifier
        initial: Starting state name
        states: Mapping from state name to StateConfig
        description: Optional description text; only stored and serialized
            by this module
        context: User-defined shared variables
        scope: Paths this loop operates on (for concurrency control)
        max_iterations: Safety limit for loop iterations
        backoff: Seconds between iterations
        timeout: Max total runtime in seconds (loop-level)
        default_timeout: Optional timeout value; presumably the fallback for
            states that set no ``timeout`` of their own -- confirm against
            the executor
        maintain: If True, restart after completion
        llm: LLM evaluation configuration
        on_handoff: Behavior when handoff signal detected (pause/spawn/terminate)
        input_key: Defaults to "input"; presumably the context key that
            receives the loop's input -- confirm against the executor
    """

    name: str
    initial: str
    states: dict[str, StateConfig]
    description: str | None = None
    context: dict[str, Any] = field(default_factory=dict)
    scope: list[str] = field(default_factory=list)
    max_iterations: int = 50
    backoff: float | None = None
    timeout: int | None = None
    default_timeout: int | None = None
    maintain: bool = False
    llm: LLMConfig = field(default_factory=LLMConfig)
    on_handoff: Literal["pause", "spawn", "terminate"] = "pause"
    input_key: str = "input"

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON/YAML serialization.

        Returns:
            A dict that always contains ``name``, ``initial`` and ``states``.
            Other fields appear only when they differ from their defaults;
            ``llm`` appears only when its own serialization is non-empty.
        """
        result: dict[str, Any] = {
            "name": self.name,
            "initial": self.initial,
            "states": {name: state.to_dict() for name, state in self.states.items()},
        }

        if self.description is not None:
            result["description"] = self.description
        if self.context:
            result["context"] = self.context
        if self.scope:
            result["scope"] = self.scope
        if self.max_iterations != 50:
            result["max_iterations"] = self.max_iterations
        if self.backoff is not None:
            result["backoff"] = self.backoff
        if self.timeout is not None:
            result["timeout"] = self.timeout
        if self.default_timeout is not None:
            result["default_timeout"] = self.default_timeout
        if self.maintain:
            result["maintain"] = self.maintain
        if self.on_handoff != "pause":
            result["on_handoff"] = self.on_handoff
        if self.input_key != "input":
            result["input_key"] = self.input_key

        llm_dict = self.llm.to_dict()
        if llm_dict:
            result["llm"] = llm_dict

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FSMLoop:
        """Create from dictionary (JSON/YAML deserialization).

        Explicit nulls for ``states``, ``llm``, ``context`` and ``scope``
        (e.g. a bare ``context:`` line in YAML) are treated like absent keys.

        Args:
            data: Mapping with required ``name`` and ``initial`` keys; all
                other keys are optional.

        Returns:
            The populated FSMLoop.

        Raises:
            KeyError: If ``name`` or ``initial`` is missing.
        """
        states = {
            name: StateConfig.from_dict(state_data)
            for name, state_data in (data.get("states") or {}).items()
        }

        llm_data = data.get("llm")
        llm = LLMConfig.from_dict(llm_data) if llm_data is not None else LLMConfig()

        return cls(
            name=data["name"],
            initial=data["initial"],
            states=states,
            description=data.get("description"),
            context=data.get("context") or {},
            scope=data.get("scope") or [],
            max_iterations=data.get("max_iterations", 50),
            backoff=data.get("backoff"),
            timeout=data.get("timeout"),
            default_timeout=data.get("default_timeout"),
            maintain=data.get("maintain", False),
            llm=llm,
            on_handoff=data.get("on_handoff", "pause"),
            input_key=data.get("input_key", "input"),
        )

    def get_all_state_names(self) -> set[str]:
        """Get all defined state names.

        Returns:
            Set of all state names in this FSM.
        """
        return set(self.states)

    def get_terminal_states(self) -> set[str]:
        """Get all terminal state names.

        Returns:
            Set of state names where terminal=True.
        """
        return {name for name, state in self.states.items() if state.terminal}

    def get_all_referenced_states(self) -> set[str]:
        """Get all state names referenced by transitions.

        This includes the initial state and all states referenced
        in routing configurations.

        Returns:
            Set of all referenced state names.
        """
        refs: set[str] = {self.initial}
        for state in self.states.values():
            refs.update(state.get_referenced_states())
        return refs