Coverage for src/dataknobs_llm/conversations/manager.py: 0%

235 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-31 16:04 -0600

1"""Conversation manager for multi-turn interactions with LLMs.""" 

2 

3import uuid 

4from typing import Optional, List, Dict, Any, AsyncIterator 

5from datetime import datetime 

6 

7from dataknobs_structures.tree import Tree 

8from dataknobs_llm.llm import AsyncLLMProvider, LLMMessage, LLMResponse, LLMStreamResponse 

9from dataknobs_llm.prompts import AsyncPromptBuilder 

10from dataknobs_llm.conversations.storage import ( 

11 ConversationNode, 

12 ConversationState, 

13 ConversationStorage, 

14 calculate_node_id, 

15 get_node_by_id, 

16 get_messages_for_llm, 

17) 

18 

19 

class ConversationManager:
    """Manages multi-turn conversations with persistence and branching.

    This class orchestrates conversations by:
    - Tracking message history with tree-based branching
    - Managing conversation state
    - Persisting to storage backend
    - Supporting multiple conversation branches

    Every message is stored as a node in a ``Tree``. Node IDs are computed by
    ``calculate_node_id`` once a node is placed in the tree (the root node's
    ID is ``""``; child IDs are dot-delimited, e.g. ``"0.1"``). The *current
    node* is where the next message is attached, so switching back to an
    earlier node and adding/completing from there creates a new branch.

    Example:
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage_backend,
        ...     system_prompt_name="helpful_assistant"
        ... )
        >>>
        >>> # Add user message
        >>> await manager.add_message(
        ...     prompt_name="user_query",
        ...     params={"question": "What is Python?"},
        ...     role="user"
        ... )
        >>>
        >>> # Get LLM response
        >>> result = await manager.complete()
        >>>
        >>> # Continue conversation
        >>> await manager.add_message(
        ...     content="Tell me more about decorators",
        ...     role="user"
        ... )
        >>> result = await manager.complete()
        >>>
        >>> # Create alternative response branch
        >>> await manager.switch_to_node("0")  # Back to first user message
        >>> result2 = await manager.complete(branch_name="alt-response")
        >>>
        >>> # Resume after interruption
        >>> manager2 = await ConversationManager.resume(
        ...     conversation_id=manager.conversation_id,
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage_backend
        ... )
    """

    def __init__(
        self,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        storage: ConversationStorage,
        state: Optional[ConversationState] = None,
        metadata: Optional[Dict[str, Any]] = None,
        middleware: Optional[List["ConversationMiddleware"]] = None,
        cache_rag_results: bool = False,
        reuse_rag_on_branch: bool = False,
    ):
        """Initialize conversation manager.

        Note: Use ConversationManager.create() or ConversationManager.resume()
        instead of calling __init__ directly.

        Args:
            llm: LLM provider for completions
            prompt_builder: Prompt builder with library
            storage: Storage backend for persistence
            state: Optional existing conversation state
            metadata: Optional metadata for new conversations (copied; the
                caller's dict is never mutated)
            middleware: Optional list of middleware to execute
            cache_rag_results: If True, store RAG metadata in node metadata
                for debugging and transparency
            reuse_rag_on_branch: If True, reuse cached RAG results when
                possible (useful for testing/branching)
        """
        self.llm = llm
        self.prompt_builder = prompt_builder
        self.storage = storage
        self.state = state
        # Copy: this dict becomes the conversation's metadata on the first
        # message and is later mutated by add_metadata(); aliasing the
        # caller's dict would leak those writes back to them.
        self._initial_metadata: Dict[str, Any] = dict(metadata) if metadata else {}
        self.middleware = middleware or []
        self.cache_rag_results = cache_rag_results
        self.reuse_rag_on_branch = reuse_rag_on_branch

    @classmethod
    async def create(
        cls,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        storage: ConversationStorage,
        system_prompt_name: Optional[str] = None,
        system_params: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        middleware: Optional[List["ConversationMiddleware"]] = None,
        cache_rag_results: bool = False,
        reuse_rag_on_branch: bool = False,
    ) -> "ConversationManager":
        """Create a new conversation.

        No state exists (and nothing is persisted) until the first message is
        added; if ``system_prompt_name`` is given, that system message is
        added immediately and becomes the root node.

        Args:
            llm: LLM provider
            prompt_builder: Prompt builder
            storage: Storage backend
            system_prompt_name: Optional system prompt to initialize with
            system_params: Optional params for system prompt
            metadata: Optional conversation metadata
            middleware: Optional list of middleware to execute
            cache_rag_results: If True, store RAG metadata in node metadata
            reuse_rag_on_branch: If True, reuse cached RAG results when possible

        Returns:
            Initialized ConversationManager

        Example:
            >>> manager = await ConversationManager.create(
            ...     llm=llm,
            ...     prompt_builder=builder,
            ...     storage=storage,
            ...     system_prompt_name="helpful_assistant",
            ...     cache_rag_results=True
            ... )
        """
        manager = cls(
            llm=llm,
            prompt_builder=prompt_builder,
            storage=storage,
            metadata=metadata,
            middleware=middleware,
            cache_rag_results=cache_rag_results,
            reuse_rag_on_branch=reuse_rag_on_branch,
        )

        if system_prompt_name:
            await manager.add_message(
                prompt_name=system_prompt_name,
                params=system_params,
                role="system",
            )

        return manager

    @classmethod
    async def resume(
        cls,
        conversation_id: str,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        storage: ConversationStorage,
        middleware: Optional[List["ConversationMiddleware"]] = None,
        cache_rag_results: bool = False,
        reuse_rag_on_branch: bool = False,
    ) -> "ConversationManager":
        """Resume an existing conversation.

        Args:
            conversation_id: Existing conversation ID
            llm: LLM provider
            prompt_builder: Prompt builder
            storage: Storage backend
            middleware: Optional list of middleware to execute
            cache_rag_results: If True, store RAG metadata in node metadata
            reuse_rag_on_branch: If True, reuse cached RAG results when possible

        Returns:
            ConversationManager with restored state

        Raises:
            ValueError: If conversation not found

        Example:
            >>> manager = await ConversationManager.resume(
            ...     conversation_id="conv-123",
            ...     llm=llm,
            ...     prompt_builder=builder,
            ...     storage=storage,
            ...     cache_rag_results=True
            ... )
        """
        state = await storage.load_conversation(conversation_id)
        if not state:
            raise ValueError(f"Conversation '{conversation_id}' not found")

        return cls(
            llm=llm,
            prompt_builder=prompt_builder,
            storage=storage,
            state=state,
            middleware=middleware,
            cache_rag_results=cache_rag_results,
            reuse_rag_on_branch=reuse_rag_on_branch,
        )

    async def add_message(
        self,
        role: str,
        content: Optional[str] = None,
        prompt_name: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        include_rag: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ConversationNode:
        """Add a message as a child of the current node and make it current.

        Either content or prompt_name must be provided. When ``prompt_name``
        is given, the rendered prompt is used as the message content (any
        ``content`` argument is replaced). If the conversation has no state
        yet, the message becomes the root node of a new conversation whose ID
        is a freshly generated UUID4.

        Args:
            role: Message role ("system", "user", or "assistant")
            content: Direct message content (if not using prompt)
            prompt_name: Name of prompt template to render (only the
                "system" and "user" roles can be rendered)
            params: Parameters for prompt rendering
            include_rag: Whether to execute RAG searches for prompts
            metadata: Optional metadata for this message node (copied; the
                caller's dict is never mutated)

        Returns:
            The created ConversationNode

        Raises:
            ValueError: If neither content nor prompt_name provided, if a
                prompt is requested for an unsupported role, or if the
                current node cannot be found in the tree

        Example:
            >>> # Add message from prompt
            >>> await manager.add_message(
            ...     role="user",
            ...     prompt_name="code_question",
            ...     params={"code": code_snippet}
            ... )
            >>>
            >>> # Add direct message
            >>> await manager.add_message(
            ...     role="user",
            ...     content="What is Python?"
            ... )
        """
        if not content and not prompt_name:
            raise ValueError("Either content or prompt_name must be provided")

        # Copy so adding "rag_metadata" never writes into the caller's dict.
        node_metadata: Dict[str, Any] = dict(metadata) if metadata else {}

        if prompt_name:
            result = await self._render_prompt(
                role, prompt_name, params or {}, include_rag
            )
            content = result.content
            # Only keep RAG metadata when caching is enabled and some was captured.
            if self.cache_rag_results and result.rag_metadata:
                node_metadata["rag_metadata"] = result.rag_metadata

        node = ConversationNode(
            message=LLMMessage(role=role, content=content),
            node_id="",  # Assigned once the node's position in the tree is known
            prompt_name=prompt_name,
            metadata=node_metadata,
        )

        if self.state is None:
            # First message: it becomes the root ("" ID) of a new conversation.
            self.state = ConversationState(
                conversation_id=str(uuid.uuid4()),
                message_tree=Tree(node),
                current_node_id="",
                metadata=self._initial_metadata,
            )
        else:
            self._attach_to_current(node)

        await self._commit()

        return self.state.get_current_node().data

    async def complete(
        self,
        branch_name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **llm_kwargs,
    ) -> LLMResponse:
        """Get LLM completion and add as child of current node.

        This method:
        1. Gets conversation history from root to current node
        2. Executes middleware (pre-LLM)
        3. Calls LLM with history
        4. Executes middleware (post-LLM)
        5. Adds assistant response as child of current node
        6. Updates current position to new node
        7. Persists to storage

        Args:
            branch_name: Optional human-readable label for this branch
            metadata: Optional metadata for the assistant message node
                (copied; "usage", "model" and "finish_reason" are added)
            **llm_kwargs: Additional arguments for LLM.complete()

        Returns:
            LLM response (after post-LLM middleware)

        Raises:
            ValueError: If conversation has no messages yet, or the current
                node cannot be found in the tree

        Example:
            >>> # Get response
            >>> result = await manager.complete()
            >>> print(result.content)
            >>>
            >>> # Create labeled branch
            >>> result = await manager.complete(branch_name="alternative-answer")
        """
        if not self.state:
            raise ValueError("Cannot complete: no messages in conversation")

        messages = await self._run_request_middleware(
            self.state.get_current_messages()
        )
        response = await self.llm.complete(messages, **llm_kwargs)
        response = await self._run_response_middleware(response)

        await self._record_assistant_response(response, branch_name, metadata)
        return response

    async def stream_complete(
        self,
        branch_name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **llm_kwargs,
    ) -> AsyncIterator[LLMStreamResponse]:
        """Stream LLM completion and add as child of current node.

        Similar to complete() but streams the response. Chunks are yielded
        as they arrive; the accumulated response is only recorded in the tree
        and persisted after the stream is fully consumed.

        Args:
            branch_name: Optional human-readable label for this branch
            metadata: Optional metadata for the assistant message node
                (copied; "usage", "model" and "finish_reason" are added)
            **llm_kwargs: Additional arguments for LLM.stream_complete()

        Yields:
            Streaming response chunks

        Raises:
            ValueError: If conversation has no messages yet, or the current
                node cannot be found in the tree

        Example:
            >>> async for chunk in manager.stream_complete():
            ...     print(chunk.delta, end="", flush=True)
        """
        if not self.state:
            raise ValueError("Cannot complete: no messages in conversation")

        messages = await self._run_request_middleware(
            self.state.get_current_messages()
        )

        # Accumulate deltas in a list and join once, rather than repeated
        # string concatenation; empty/None deltas contribute nothing.
        parts: List[str] = []
        final_chunk = None
        async for chunk in self.llm.stream_complete(messages, **llm_kwargs):
            if chunk.delta:
                parts.append(chunk.delta)
            final_chunk = chunk
            yield chunk

        # Rebuild a complete response so it can go through the same
        # post-LLM middleware and recording path as complete().
        response = LLMResponse(
            content="".join(parts),
            model=self.llm.config.model,
            finish_reason=final_chunk.finish_reason if final_chunk else "stop",
            usage=final_chunk.usage if final_chunk else None,
        )
        response = await self._run_response_middleware(response)

        await self._record_assistant_response(response, branch_name, metadata)

    async def switch_to_node(self, node_id: str) -> None:
        """Switch current position to a different node in the tree.

        This allows exploring different branches or backtracking in the conversation.

        Args:
            node_id: Target node ID (dot-delimited, e.g., "0.1" or "")

        Raises:
            ValueError: If there is no conversation state or node_id is not
                found in the tree

        Example:
            >>> # Go back to first user message
            >>> await manager.switch_to_node("0")
            >>>
            >>> # Create alternative response
            >>> result = await manager.complete(branch_name="alternative")
            >>>
            >>> # Go back to root
            >>> await manager.switch_to_node("")
        """
        if not self.state:
            raise ValueError("No conversation state")

        if get_node_by_id(self.state.message_tree, node_id) is None:
            raise ValueError(f"Node '{node_id}' not found in conversation tree")

        self.state.current_node_id = node_id
        await self._commit()

    async def execute_flow(
        self,
        flow: "ConversationFlow",
        initial_params: Optional[Dict[str, Any]] = None
    ) -> AsyncIterator[ConversationNode]:
        """Execute a conversation flow using FSM.

        This method executes a predefined conversation flow, then appends one
        assistant node per entry in the adapter's execution history (as a
        chain under the current node), persisting and yielding each node.

        Args:
            flow: ConversationFlow definition
            initial_params: Optional initial parameters for the flow (copied;
                "conversation_id" is added to the copy, not the caller's dict)

        Yields:
            ConversationNode for each state in the flow

        Raises:
            ValueError: If there is no conversation state, or flow execution
                (including recording its results) fails

        Example:
            >>> from dataknobs_llm.conversations.flow import (
            ...     ConversationFlow, FlowState,
            ...     keyword_condition
            ... )
            >>>
            >>> # Define flow
            >>> flow = ConversationFlow(
            ...     name="support",
            ...     initial_state="greeting",
            ...     states={
            ...         "greeting": FlowState(
            ...             prompt_name="support_greeting",
            ...             transitions={
            ...                 "help": "collect_issue",
            ...                 "browse": "end"
            ...             },
            ...             transition_conditions={
            ...                 "help": keyword_condition(["help", "issue"]),
            ...                 "browse": keyword_condition(["browse", "look"])
            ...             }
            ...         )
            ...     }
            ... )
            >>>
            >>> # Execute flow
            >>> async for node in manager.execute_flow(flow):
            ...     print(f"State: {node.metadata.get('state')}")
            ...     print(f"Response: {node.message.content}")
        """
        import logging

        from dataknobs_llm.conversations.flow import ConversationFlowAdapter

        if not self.state:
            raise ValueError("No conversation state")

        adapter = ConversationFlowAdapter(
            flow=flow,
            prompt_builder=self.prompt_builder,
            llm=self.llm
        )

        # Copy so the caller's initial_params is not mutated.
        data = dict(initial_params) if initial_params else {}
        data["conversation_id"] = self.state.conversation_id

        try:
            await adapter.execute(data)

            # NOTE(review): history entries are assumed to be
            # (state_name, response_text) pairs — confirm against the adapter.
            for state_name, response in adapter.execution_state.history:
                # Build the node the same way as every other append path
                # (message=LLMMessage(...)), instead of passing role/content
                # keywords directly to ConversationNode.
                node = ConversationNode(
                    message=LLMMessage(role="assistant", content=response),
                    node_id="",  # Assigned by _attach_to_current
                    metadata={
                        "state": state_name,
                        "flow_name": flow.name,
                        "flow_execution": True
                    },
                )
                self._attach_to_current(node)
                await self._commit()

                yield node

        except Exception as e:
            logging.getLogger(__name__).exception("Flow execution failed: %s", e)
            raise ValueError(f"Flow execution failed: {str(e)}") from e

    async def get_history(self) -> List[LLMMessage]:
        """Get conversation history from root to current position.

        Returns:
            List of messages in current conversation path (empty if the
            conversation has no messages yet)

        Example:
            >>> messages = await manager.get_history()
            >>> for msg in messages:
            ...     print(f"{msg.role}: {msg.content}")
        """
        if not self.state:
            return []

        return self.state.get_current_messages()

    async def get_branches(self, node_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get information about branches from a given node.

        Args:
            node_id: Node to get branches from (default: current node)

        Returns:
            List of branch info dicts (one per child, empty if the node is
            missing or has no children) with keys:
            - node_id: ID of child node
            - branch_name: Optional branch name
            - role: Message role
            - preview: First 100 chars of content
            - timestamp: When created

        Example:
            >>> branches = await manager.get_branches()
            >>> for branch in branches:
            ...     print(f"{branch['branch_name']}: {branch['preview']}")
        """
        if not self.state:
            return []

        if node_id is None:
            node_id = self.state.current_node_id

        node = get_node_by_id(self.state.message_tree, node_id)
        if node is None or not node.children:
            return []

        return [
            {
                "node_id": child.data.node_id,
                "branch_name": child.data.branch_name,
                "role": child.data.message.role,
                "preview": child.data.message.content[:100],
                "timestamp": child.data.timestamp,
            }
            for child in node.children
        ]

    async def add_metadata(self, key: str, value: Any) -> None:
        """Add metadata to conversation and persist it.

        Args:
            key: Metadata key
            value: Metadata value

        Raises:
            ValueError: If the conversation has no messages yet

        Example:
            >>> await manager.add_metadata("user_id", "alice")
            >>> await manager.add_metadata("session", "abc123")
        """
        if not self.state:
            raise ValueError("No conversation state")

        self.state.metadata[key] = value
        await self._commit()

    async def _find_cached_rag(
        self,
        prompt_name: str,
        role: str,
        params: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Search conversation history for cached RAG metadata.

        This method searches the entire conversation tree (breadth-first) for
        a node with the same prompt name and role whose cached RAG metadata
        has a matching query hash for every placeholder whose query could be
        rendered with ``params``.

        Args:
            prompt_name: Name of the prompt to find cached RAG for
            role: Role of the prompt ("system" or "user")
            params: Parameters for the prompt (used to match RAG queries)

        Returns:
            Cached RAG metadata dictionary if found, None otherwise

        Example:
            >>> cached = await manager._find_cached_rag("code_question", "user", {"topic": "decorators"})
            >>> if cached:
            ...     print(f"Found cached RAG with {len(cached)} placeholders")
        """
        if not self.state:
            return None

        # RAG configs tell us which queries this prompt would run.
        rag_configs = self.prompt_builder.library.get_prompt_rag_configs(
            prompt_name=prompt_name,
            prompt_type="system" if role == "system" else "user"
        )
        if not rag_configs:
            return None

        from jinja2 import Template

        target_hashes: Dict[str, Any] = {}
        for rag_config in rag_configs:
            placeholder = rag_config.get("placeholder", "RAG_CONTENT")
            adapter_name = rag_config.get("adapter_name", "")
            query_template = rag_config.get("query", "")
            try:
                resolved_query = Template(query_template).render(params)
                target_hashes[placeholder] = self.prompt_builder._compute_rag_query_hash(
                    adapter_name, resolved_query
                )
            except Exception:
                # Best effort: a query that cannot be rendered/hashed simply
                # cannot be matched against the cache, so skip it.
                continue

        if not target_hashes:
            return None

        from collections import deque

        queue = deque([self.state.message_tree])
        while queue:
            tree_node = queue.popleft()
            node_data = tree_node.data

            if node_data.prompt_name == prompt_name and node_data.message.role == role:
                rag_metadata = (node_data.metadata or {}).get("rag_metadata")
                if rag_metadata and all(
                    placeholder in rag_metadata
                    and rag_metadata[placeholder].get("query_hash") == target_hash
                    for placeholder, target_hash in target_hashes.items()
                ):
                    return rag_metadata

            if tree_node.children:
                queue.extend(tree_node.children)

        return None

    def get_rag_metadata(self, node_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Get RAG metadata from a conversation node.

        This method retrieves the cached RAG metadata from a specific node,
        which includes information about RAG searches executed during prompt
        rendering (queries, results, query hashes, etc.).

        Args:
            node_id: Node ID to retrieve metadata from (default: current node)

        Returns:
            RAG metadata dictionary if present, None otherwise (including
            when the conversation has no messages yet)

        Raises:
            ValueError: If node_id not found in conversation tree

        Example:
            >>> # Get RAG metadata from current node
            >>> metadata = manager.get_rag_metadata()
            >>> if metadata:
            ...     for placeholder, rag_data in metadata.items():
            ...         print(f"{placeholder}: query={rag_data['query']}")
            >>>
            >>> # Get RAG metadata from specific node
            >>> metadata = manager.get_rag_metadata(node_id="0.1")
        """
        if not self.state:
            return None

        if node_id is None:
            node_id = self.state.current_node_id

        tree_node = get_node_by_id(self.state.message_tree, node_id)
        if tree_node is None:
            raise ValueError(f"Node '{node_id}' not found in conversation tree")

        return tree_node.data.metadata.get("rag_metadata")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _render_prompt(
        self,
        role: str,
        prompt_name: str,
        params: Dict[str, Any],
        include_rag: bool,
    ) -> Any:
        """Render a system/user prompt, reusing cached RAG results when enabled.

        Returns:
            The prompt builder's render result (read for ``.content`` and
            ``.rag_metadata``).

        Raises:
            ValueError: If ``role`` is not "system" or "user".
        """
        if role == "system":
            render = self.prompt_builder.render_system_prompt
        elif role == "user":
            render = self.prompt_builder.render_user_prompt
        else:
            raise ValueError(f"Cannot render prompt for role '{role}'")

        cached_rag = None
        if self.reuse_rag_on_branch and include_rag:
            cached_rag = await self._find_cached_rag(prompt_name, role, params)

        return await render(
            prompt_name,
            params=params,
            include_rag=include_rag,
            return_rag_metadata=self.cache_rag_results,
            cached_rag=cached_rag,
        )

    async def _run_request_middleware(
        self, messages: List[LLMMessage]
    ) -> List[LLMMessage]:
        """Pass outgoing messages through middleware in registration order."""
        for mw in self.middleware:
            messages = await mw.process_request(messages, self.state)
        return messages

    async def _run_response_middleware(self, response: LLMResponse) -> LLMResponse:
        """Pass the LLM response through middleware in reverse order (onion model)."""
        for mw in reversed(self.middleware):
            response = await mw.process_response(response, self.state)
        return response

    async def _record_assistant_response(
        self,
        response: LLMResponse,
        branch_name: Optional[str],
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Append ``response`` as an assistant node under the current node and persist.

        Raises:
            ValueError: If the current node cannot be found in the tree.
        """
        # Copy so the caller's dict is neither mutated nor shared between nodes.
        assistant_metadata: Dict[str, Any] = dict(metadata) if metadata else {}
        assistant_metadata.update({
            "usage": response.usage,
            "model": response.model,
            "finish_reason": response.finish_reason,
        })

        self._attach_to_current(
            ConversationNode(
                message=LLMMessage(role="assistant", content=response.content),
                node_id="",  # Assigned by _attach_to_current
                branch_name=branch_name,
                metadata=assistant_metadata,
            )
        )
        await self._commit()

    def _attach_to_current(self, node: ConversationNode) -> str:
        """Add ``node`` as a child of the current node and make it current.

        Requires ``self.state`` to be set. Does not persist; callers follow
        up with ``_commit()``.

        Returns:
            The node ID assigned to ``node``.

        Raises:
            ValueError: If the current node ID does not resolve in the tree.
        """
        current_tree_node = self.state.get_current_node()
        if current_tree_node is None:
            raise ValueError(f"Current node '{self.state.current_node_id}' not found")

        new_tree_node = Tree(node)
        current_tree_node.add_child(new_tree_node)
        # The ID depends on the node's position, so compute it after insertion.
        node_id = calculate_node_id(new_tree_node)
        new_tree_node.data.node_id = node_id

        self.state.current_node_id = node_id
        return node_id

    async def _commit(self) -> None:
        """Stamp the state's ``updated_at`` with the current time and persist it."""
        self.state.updated_at = datetime.now()
        await self._save_state()

    async def _save_state(self) -> None:
        """Persist current state to storage (no-op when there is no state yet)."""
        if self.state:
            await self.storage.save_conversation(self.state)

    @property
    def conversation_id(self) -> Optional[str]:
        """Get conversation ID (None until the first message is added)."""
        return self.state.conversation_id if self.state else None

    @property
    def current_node_id(self) -> Optional[str]:
        """Get current node ID (None until the first message is added)."""
        return self.state.current_node_id if self.state else None