Coverage for src/dataknobs_llm/conversations/manager.py: 71%

299 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""Conversation manager for multi-turn interactions with LLMs. 

2 

3This module provides ConversationManager, a comprehensive system for managing 

4multi-turn LLM conversations with advanced features like: 

5 

6- **Tree-based History**: Conversations stored as trees, enabling branching 

7- **Persistence**: Automatic state saving to any storage backend 

8- **RAG Caching**: Reuse search results across conversation branches 

9- **Middleware**: Pre/post-processing pipeline for all LLM calls 

10- **Cost Tracking**: Automatic API cost calculation and accumulation 

11- **Flow Execution**: FSM-based conversation flows with state management 

12- **Resumability**: Save and resume conversations across sessions 

13 

14Architecture: 

15 ConversationManager orchestrates three core components: 

16 

17 1. **AsyncLLMProvider**: Handles LLM API calls (OpenAI, Anthropic, Ollama) 

18 2. **AsyncPromptBuilder**: Renders prompts with RAG integration 

19 3. **ConversationStorage**: Persists conversation state (Memory, File, S3, Postgres) 

20 

21 Conversations are stored as trees where each node represents a message. 

22 Branching occurs when multiple responses are generated from the same point, 

23 enabling A/B testing, alternative explorations, and retry scenarios. 

24 

25Example: 

26 ```python 

27 from dataknobs_llm import create_llm_provider 

28 from dataknobs_llm.prompts import AsyncPromptBuilder 

29 from dataknobs_llm.conversations import ( 

30 ConversationManager, 

31 DataknobsConversationStorage 

32 ) 

33 from dataknobs_data import database_factory 

34 

35 # Setup components 

36 llm = create_llm_provider("openai", api_key="sk-...") 

37 builder = AsyncPromptBuilder.create(library_path="./prompts") 

38 db = database_factory.create(backend="memory") 

39 storage = DataknobsConversationStorage(db) 

40 

41 # Create conversation 

42 manager = await ConversationManager.create( 

43 llm=llm, 

44 prompt_builder=builder, 

45 storage=storage, 

46 system_prompt_name="helpful_assistant" 

47 ) 

48 

49 # Add user message and get response 

50 await manager.add_message( 

51 role="user", 

52 content="What is Python?" 

53 ) 

54 response = await manager.complete() 

55 print(response.content) 

56 

57 # Continue conversation 

58 await manager.add_message( 

59 role="user", 

60 content="Show me a decorator example" 

61 ) 

62 response = await manager.complete() 

63 

64 # Create alternative response branch 

65 await manager.switch_to_node("0") # Back to first user message 

66 alt_response = await manager.complete(branch_name="alternative") 

67 

68 # Resume later 

69 conv_id = manager.conversation_id 

70 manager2 = await ConversationManager.resume( 

71 conversation_id=conv_id, 

72 llm=llm, 

73 prompt_builder=builder, 

74 storage=storage 

75 ) 

76 ``` 

77 

78See Also: 

79 - ConversationStorage: Storage interface and implementations 

80 - ConversationMiddleware: Middleware system for request/response processing 

81 - ConversationFlow: FSM-based conversation flows 

82 - AsyncPromptBuilder: Prompt rendering with RAG integration 

83""" 

84 

85import uuid 

86from typing import List, Dict, Any, AsyncIterator 

87from datetime import datetime 

88 

89from dataknobs_structures.tree import Tree 

90from dataknobs_llm.llm import AsyncLLMProvider, LLMMessage, LLMResponse, LLMStreamResponse 

91from dataknobs_llm.prompts import AsyncPromptBuilder 

92from dataknobs_llm.conversations.flow.flow import ConversationFlow 

93from dataknobs_llm.conversations.middleware import ConversationMiddleware 

94from dataknobs_llm.conversations.storage import ( 

95 ConversationNode, 

96 ConversationState, 

97 ConversationStorage, 

98 calculate_node_id, 

99 get_node_by_id, 

100) 

101 

102 

103class ConversationManager: 

104 """Manages multi-turn conversations with persistence and branching. 

105 

106 This class orchestrates conversations by: 

107 - Tracking message history with tree-based branching 

108 - Managing conversation state 

109 - Persisting to storage backend 

110 - Supporting multiple conversation branches 

111 

112 The conversation history is stored as a tree structure where: 

113 - Root node contains the initial system prompt (if any) 

114 - Each message is a tree node with a dot-delimited ID (e.g., "0.1.2") 

115 - Branches occur when multiple children are added to the same node 

116 - Current position tracks where you are in the conversation tree 

117 

118 Attributes: 

119 llm: LLM provider for completions 

120 prompt_builder: Prompt builder with library 

121 storage: Storage backend for persistence 

122 state: Current conversation state (tree, metadata, position) 

123 middleware: List of middleware to execute on requests/responses 

124 cache_rag_results: Whether to store RAG metadata in nodes 

125 reuse_rag_on_branch: Whether to reuse cached RAG across branches 

126 conversation_id: Unique conversation identifier 

127 current_node_id: Current position in conversation tree 

128 

129 Example: 

130 ```python 

131 # Create conversation 

132 manager = await ConversationManager.create( 

133 llm=llm, 

134 prompt_builder=builder, 

135 storage=storage_backend, 

136 system_prompt_name="helpful_assistant" 

137 ) 

138 

139 # Add user message 

140 await manager.add_message( 

141 prompt_name="user_query", 

142 params={"question": "What is Python?"}, 

143 role="user" 

144 ) 

145 

146 # Get LLM response 

147 result = await manager.complete() 

148 

149 # Continue conversation 

150 await manager.add_message( 

151 content="Tell me more about decorators", 

152 role="user" 

153 ) 

154 result = await manager.complete() 

155 

156 # Create alternative response branch 

157 await manager.switch_to_node("0") # Back to first user message 

158 result2 = await manager.complete(branch_name="alt-response") 

159 

160 # Resume after interruption 

161 manager2 = await ConversationManager.resume( 

162 conversation_id=manager.conversation_id, 

163 llm=llm, 

164 prompt_builder=builder, 

165 storage=storage_backend 

166 ) 

167 ``` 

168 

169 Note: 

170 Tree-based branching enables: 

171 

172 - **A/B Testing**: Generate multiple responses from the same context 

173 - **Retry Logic**: Try again from a previous point after failures 

174 - **Alternative Explorations**: Explore different conversation paths 

175 - **Debugging**: Compare different middleware or RAG configurations 

176 

177 Node IDs use dot notation (e.g., "0.1.2" means 3rd child of 2nd child 

178 of 1st child of root). The root node has ID "". 

179 

180 State is automatically persisted after every state-changing operation 

181 (adding a message, completing, switching nodes, adding metadata). Use `resume()` to continue conversations across sessions or servers. 

182 

183 See Also: 

184 create: Create a new conversation 

185 resume: Resume an existing conversation 

186 add_message: Add user/system message 

187 complete: Get LLM completion (non-streaming) 

188 stream_complete: Get LLM completion (streaming) 

189 switch_to_node: Navigate to different branch 

190 get_branches: List available branches 

191 get_total_cost: Calculate cumulative cost 

192 ConversationStorage: Storage backend implementations 

193 ConversationMiddleware: Request/response processing 

194 ConversationFlow: FSM-based conversation flows 

195 """ 

196 

def __init__(
    self,
    llm: AsyncLLMProvider,
    prompt_builder: AsyncPromptBuilder,
    storage: ConversationStorage,
    state: ConversationState | None = None,
    metadata: Dict[str, Any] | None = None,
    middleware: List[ConversationMiddleware] | None = None,
    cache_rag_results: bool = False,
    reuse_rag_on_branch: bool = False,
):
    """Record collaborators and options on the instance.

    The constructor performs no I/O; it only stores its arguments.
    Prefer the ``create()`` and ``resume()`` factories over calling it
    directly.

    Args:
        llm: Provider used for completions.
        prompt_builder: Builder used to render named prompts.
        storage: Backend that conversation state is persisted to.
        state: Previously loaded conversation state, if any.
        metadata: Conversation-level metadata applied when a new
            conversation state is created.
        middleware: Middleware run around each LLM call.
        cache_rag_results: If True, RAG metadata returned by the prompt
            builder is kept in node metadata.
        reuse_rag_on_branch: If True, cached RAG results are looked up
            before rendering a prompt.
    """
    # RAG behaviour switches.
    self.reuse_rag_on_branch = reuse_rag_on_branch
    self.cache_rag_results = cache_rag_results

    # Falsy values (None or empty) are replaced by fresh containers.
    self.middleware = middleware if middleware else []
    self._initial_metadata = metadata if metadata else {}

    # May be None here; add_message() builds the state on first use.
    self.state = state

    # Collaborators.
    self.storage = storage
    self.prompt_builder = prompt_builder
    self.llm = llm

233 

@classmethod
async def create(
    cls,
    llm: AsyncLLMProvider,
    prompt_builder: AsyncPromptBuilder,
    storage: ConversationStorage,
    system_prompt_name: str | None = None,
    system_params: Dict[str, Any] | None = None,
    metadata: Dict[str, Any] | None = None,
    middleware: List[ConversationMiddleware] | None = None,
    cache_rag_results: bool = False,
    reuse_rag_on_branch: bool = False,
) -> "ConversationManager":
    """Build a manager for a brand-new conversation.

    When ``system_prompt_name`` is given, that prompt is rendered and
    added as a ``"system"`` message through ``add_message()`` before the
    manager is returned.

    Args:
        llm: LLM provider.
        prompt_builder: Prompt builder.
        storage: Storage backend.
        system_prompt_name: Optional name of the system prompt to start with.
        system_params: Optional parameters for rendering the system prompt.
        metadata: Optional conversation metadata.
        middleware: Optional middleware list.
        cache_rag_results: If True, store RAG metadata in node metadata.
        reuse_rag_on_branch: If True, reuse cached RAG results when possible.

    Returns:
        The initialized ConversationManager.

    Example:
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     system_prompt_name="helpful_assistant",
        ...     cache_rag_results=True
        ... )
    """
    instance = cls(
        llm=llm,
        prompt_builder=prompt_builder,
        storage=storage,
        metadata=metadata,
        middleware=middleware,
        cache_rag_results=cache_rag_results,
        reuse_rag_on_branch=reuse_rag_on_branch,
    )

    # Nothing to seed the conversation with: hand back the bare manager.
    if not system_prompt_name:
        return instance

    await instance.add_message(
        role="system",
        prompt_name=system_prompt_name,
        params=system_params,
    )
    return instance

291 

@classmethod
async def resume(
    cls,
    conversation_id: str,
    llm: AsyncLLMProvider,
    prompt_builder: AsyncPromptBuilder,
    storage: ConversationStorage,
    middleware: List[ConversationMiddleware] | None = None,
    cache_rag_results: bool = False,
    reuse_rag_on_branch: bool = False,
) -> "ConversationManager":
    """Rebuild a manager around a conversation loaded from storage.

    Args:
        conversation_id: ID of the stored conversation.
        llm: LLM provider.
        prompt_builder: Prompt builder.
        storage: Storage backend the conversation is loaded from.
        middleware: Optional middleware list.
        cache_rag_results: If True, store RAG metadata in node metadata.
        reuse_rag_on_branch: If True, reuse cached RAG results when possible.

    Returns:
        A ConversationManager whose ``state`` is the loaded state.

    Raises:
        ValueError: If storage returns nothing for ``conversation_id``.

    Example:
        >>> manager = await ConversationManager.resume(
        ...     conversation_id="conv-123",
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     cache_rag_results=True
        ... )
    """
    restored = await storage.load_conversation(conversation_id)
    if not restored:
        raise ValueError(f"Conversation '{conversation_id}' not found")

    # Hand the loaded state straight to the constructor.
    return cls(
        llm=llm,
        prompt_builder=prompt_builder,
        storage=storage,
        state=restored,
        middleware=middleware,
        cache_rag_results=cache_rag_results,
        reuse_rag_on_branch=reuse_rag_on_branch,
    )

346 

async def add_message(
    self,
    role: str,
    content: str | None = None,
    prompt_name: str | None = None,
    params: Dict[str, Any] | None = None,
    include_rag: bool = True,
    metadata: Dict[str, Any] | None = None,
) -> ConversationNode:
    """Add a message as a child of the current conversation node.

    Either ``content`` or ``prompt_name`` must be provided. When
    ``prompt_name`` is given the prompt is rendered through the prompt
    builder and the rendered text becomes the message content (replacing
    any ``content`` that was also passed). The very first message of a
    conversation becomes the root node (ID ``""``) and creates the
    conversation state.

    Args:
        role: Message role. Prompt rendering supports only ``"system"``
            and ``"user"``; direct ``content`` is stored with any role.
        content: Direct message content (if not using a prompt).
        prompt_name: Name of the prompt template to render.
        params: Parameters for prompt rendering.
        include_rag: Forwarded to the prompt builder; also gates the
            cached-RAG lookup.
        metadata: Optional metadata for this message node. The dict is
            copied, so the caller's object is never modified.

    Returns:
        The ConversationNode that was created (now the current node).

    Raises:
        ValueError: If neither ``content`` nor ``prompt_name`` is given,
            if a prompt is requested for a role other than
            ``"system"``/``"user"``, or if the current node cannot be
            found in the tree.

    Example:
        ```python
        # Add message from prompt
        await manager.add_message(
            role="user",
            prompt_name="code_question",
            params={"code": code_snippet}
        )

        # Add direct message
        await manager.add_message(
            role="user",
            content="What is Python?"
        )

        # Add system prompt with custom metadata
        await manager.add_message(
            role="system",
            prompt_name="expert_coder",
            metadata={"version": "v2"}
        )
        ```

    Note:
        **RAG caching**: when ``reuse_rag_on_branch`` is set (and
        ``include_rag`` is True), ``_find_cached_rag`` is consulted and
        its result is passed to the prompt builder as ``cached_rag``.
        When ``cache_rag_results`` is set, RAG metadata returned by the
        builder is stored under the ``"rag_metadata"`` key of the node
        metadata.
    """
    if not content and not prompt_name:
        raise ValueError("Either content or prompt_name must be provided")

    rag_metadata_to_store = None
    if prompt_name:
        # Pick the renderer first so an unsupported role fails before any
        # cached-RAG lookup is attempted.
        if role == "system":
            render = self.prompt_builder.render_system_prompt
        elif role == "user":
            render = self.prompt_builder.render_user_prompt
        else:
            raise ValueError(f"Cannot render prompt for role '{role}'")

        params = params or {}

        # Check if we should try to reuse cached RAG
        cached_rag = None
        if self.reuse_rag_on_branch and include_rag:
            cached_rag = await self._find_cached_rag(prompt_name, role, params)

        result = await render(
            prompt_name,
            params=params,
            include_rag=include_rag,
            return_rag_metadata=self.cache_rag_results,
            cached_rag=cached_rag,
        )
        content = result.content

        # Store RAG metadata if caching is enabled and metadata was captured
        if self.cache_rag_results and result.rag_metadata:
            rag_metadata_to_store = result.rag_metadata

    message = LLMMessage(role=role, content=content)

    # Copy the caller's metadata: the node owns its own dict, and adding
    # "rag_metadata" below must not leak back into the argument.
    node_metadata = dict(metadata) if metadata else {}
    if rag_metadata_to_store:
        node_metadata["rag_metadata"] = rag_metadata_to_store

    if self.state is None:
        # First message: it becomes the root of a new conversation tree.
        root_node = ConversationNode(
            message=message,
            node_id="",
            prompt_name=prompt_name,
            metadata=node_metadata,
        )
        self.state = ConversationState(
            conversation_id=str(uuid.uuid4()),
            message_tree=Tree(root_node),
            current_node_id="",
            metadata=self._initial_metadata,
        )
    else:
        # Add as child of current node
        current_tree_node = self.state.get_current_node()
        if current_tree_node is None:
            raise ValueError(f"Current node '{self.state.current_node_id}' not found")

        new_tree_node = Tree(
            ConversationNode(
                message=message,
                node_id="",  # Will be calculated after adding to tree
                prompt_name=prompt_name,
                metadata=node_metadata,
            )
        )
        current_tree_node.add_child(new_tree_node)

        # The ID depends on the node's position, so it is only known once
        # the node is attached.
        node_id = calculate_node_id(new_tree_node)
        new_tree_node.data.node_id = node_id

        # Move current position to new node
        self.state.current_node_id = node_id

    self.state.updated_at = datetime.now()
    await self._save_state()

    return self.state.get_current_node().data

512 

async def complete(
    self,
    branch_name: str | None = None,
    metadata: Dict[str, Any] | None = None,
    **llm_kwargs: Any,
) -> LLMResponse:
    """Get an LLM completion and add it as a child of the current node.

    Steps:
        1. Collect the messages from the root to the current node.
        2. Run every middleware's ``process_request`` in list order.
        3. Call ``llm.complete`` with the resulting messages.
        4. Run every middleware's ``process_response`` in reverse order.
        5. Attach the assistant response as a child of the current node.
        6. Move the current position to the new node.
        7. Persist the state.

    Args:
        branch_name: Optional human-readable label for this branch.
        metadata: Optional metadata for the assistant message node. The
            dict is copied before usage/model/finish_reason are added,
            so the caller's object is never modified.
        **llm_kwargs: Additional arguments forwarded to ``llm.complete()``.

    Returns:
        The LLM response after post-processing by middleware.

    Raises:
        ValueError: If the conversation has no messages yet, or if the
            current node cannot be found in the tree.

    Example:
        ```python
        # Get response
        result = await manager.complete()
        print(result.content)

        # Create labeled branch
        result = await manager.complete(branch_name="alternative-answer")

        # With LLM parameters
        result = await manager.complete(temperature=0.9, max_tokens=500)
        ```

    Note:
        **Middleware execution order** (onion model):

        - Pre-LLM: middleware[0] -> middleware[1] -> ... -> middleware[N]
        - LLM call happens
        - Post-LLM: middleware[N] -> ... -> middleware[1] -> middleware[0]

        Cost accounting is delegated to ``_calculate_and_track_cost``,
        which receives the response and the node metadata dict.

    See Also:
        stream_complete: Streaming version for real-time output
        add_message: Add user/system message before calling complete
        switch_to_node: Navigate to different branch before completing
    """
    if not self.state:
        raise ValueError("Cannot complete: no messages in conversation")

    # Get messages from root to current position
    messages = self.state.get_current_messages()

    # Execute middleware (pre-LLM) in forward order
    for mw in self.middleware:
        messages = await mw.process_request(messages, self.state)

    response = await self.llm.complete(messages, **llm_kwargs)

    # Execute middleware (post-LLM) in reverse order (onion model)
    for mw in reversed(self.middleware):
        response = await mw.process_response(response, self.state)

    current_tree_node = self.state.get_current_node()
    if current_tree_node is None:
        raise ValueError(f"Current node '{self.state.current_node_id}' not found")

    assistant_message = LLMMessage(
        role="assistant",
        content=response.content,
    )

    # Work on a copy: a metadata dict reused across calls must not pick up
    # (or carry over) the usage/model/finish_reason of an earlier call.
    assistant_metadata = dict(metadata) if metadata else {}
    assistant_metadata.update({
        "usage": response.usage,
        "model": response.model,
        "finish_reason": response.finish_reason,
    })

    # Calculate and track cost
    self._calculate_and_track_cost(response, assistant_metadata)

    new_tree_node = Tree(
        ConversationNode(
            message=assistant_message,
            node_id="",  # Will be calculated once attached
            branch_name=branch_name,
            metadata=assistant_metadata,
        )
    )
    current_tree_node.add_child(new_tree_node)

    # The ID depends on the node's position in the tree.
    node_id = calculate_node_id(new_tree_node)
    new_tree_node.data.node_id = node_id

    # Move current position
    self.state.current_node_id = node_id
    self.state.updated_at = datetime.now()

    await self._save_state()

    return response

639 

async def stream_complete(
    self,
    branch_name: str | None = None,
    metadata: Dict[str, Any] | None = None,
    **llm_kwargs: Any,
) -> AsyncIterator[LLMStreamResponse]:
    r"""Stream an LLM completion and add it as a child of the current node.

    Works like ``complete()`` but yields each chunk from
    ``llm.stream_complete`` as it arrives. Once the provider stream is
    exhausted, the concatenated deltas are stored as a single assistant
    node, post-LLM middleware runs, and the state is persisted.

    Args:
        branch_name: Optional human-readable label for this branch.
        metadata: Optional metadata for the assistant message node. The
            dict is copied before usage/model/finish_reason are added,
            so the caller's object is never modified.
        **llm_kwargs: Additional arguments forwarded to
            ``llm.stream_complete()``.

    Yields:
        The streaming response chunks, unchanged, in arrival order.

    Raises:
        ValueError: If the conversation has no messages yet, or if the
            current node cannot be found in the tree.

    Example:
        ```python
        # Real-time display
        async for chunk in manager.stream_complete():
            print(chunk.delta, end="", flush=True)
        print()  # New line after streaming

        # With branch label
        async for chunk in manager.stream_complete(
            branch_name="creative-response",
            temperature=0.9
        ):
            print(chunk.delta, end="", flush=True)
        ```

    Note:
        Pre-LLM middleware runs before streaming starts; post-LLM
        middleware runs (in reverse order) after the stream completes.

        The conversation tree is only updated after the last chunk has
        been yielded. If the consumer stops iterating early, no
        assistant node is recorded.

        ``finish_reason`` and ``usage`` are taken from the last chunk
        received; the model name comes from ``self.llm.config.model``.

    See Also:
        complete: Non-streaming version for simple use cases
        add_message: Add message before streaming
    """
    if not self.state:
        raise ValueError("Cannot complete: no messages in conversation")

    messages = self.state.get_current_messages()

    # Execute middleware (pre-LLM) in forward order
    for mw in self.middleware:
        messages = await mw.process_request(messages, self.state)

    # Relay chunks to the caller while collecting the deltas.
    deltas = []
    final_chunk = None
    async for chunk in self.llm.stream_complete(messages, **llm_kwargs):
        deltas.append(chunk.delta)
        final_chunk = chunk
        yield chunk

    # Build a complete response for middleware and the state update.
    response = LLMResponse(
        content="".join(deltas),
        model=self.llm.config.model,
        finish_reason=final_chunk.finish_reason if final_chunk else "stop",
        usage=final_chunk.usage if final_chunk else None,
    )

    # Execute middleware (post-LLM) in reverse order (onion model)
    for mw in reversed(self.middleware):
        response = await mw.process_response(response, self.state)

    # Add assistant message as child (same as complete())
    current_tree_node = self.state.get_current_node()
    if current_tree_node is None:
        raise ValueError(f"Current node '{self.state.current_node_id}' not found")

    assistant_message = LLMMessage(role="assistant", content=response.content)

    # Work on a copy so the caller's metadata dict is left untouched.
    assistant_metadata = dict(metadata) if metadata else {}
    assistant_metadata.update({
        "usage": response.usage,
        "model": response.model,
        "finish_reason": response.finish_reason,
    })

    # Calculate and track cost
    self._calculate_and_track_cost(response, assistant_metadata)

    new_tree_node = Tree(
        ConversationNode(
            message=assistant_message,
            node_id="",  # Will be calculated once attached
            branch_name=branch_name,
            metadata=assistant_metadata,
        )
    )

    current_tree_node.add_child(new_tree_node)
    node_id = calculate_node_id(new_tree_node)
    new_tree_node.data.node_id = node_id

    self.state.current_node_id = node_id
    self.state.updated_at = datetime.now()

    await self._save_state()

762 

async def switch_to_node(self, node_id: str) -> None:
    """Make ``node_id`` the current position in the conversation tree.

    Used to backtrack or to move onto another branch. The new position
    is persisted.

    Args:
        node_id: Target node ID (dot-delimited, e.g., "0.1" or "")

    Raises:
        ValueError: If there is no conversation state, or if ``node_id``
            does not exist in the tree.

    Example:
        >>> # Go back to first user message
        >>> await manager.switch_to_node("0")
        >>>
        >>> # Create alternative response
        >>> result = await manager.complete(branch_name="alternative")
        >>>
        >>> # Go back to root
        >>> await manager.switch_to_node("")
    """
    state = self.state
    if not state:
        raise ValueError("No conversation state")

    # Refuse to point the cursor at a node that is not in the tree.
    if get_node_by_id(state.message_tree, node_id) is None:
        raise ValueError(f"Node '{node_id}' not found in conversation tree")

    state.current_node_id = node_id
    state.updated_at = datetime.now()

    await self._save_state()

798 

async def execute_flow(
    self,
    flow: ConversationFlow,
    initial_params: Dict[str, Any] | None = None
) -> AsyncIterator[ConversationNode]:
    """Execute a conversation flow and record its responses in the tree.

    The flow is run to completion through a ``ConversationFlowAdapter``.
    Afterwards each ``(state_name, response)`` pair in the adapter's
    execution history is appended to the conversation tree as an
    assistant message (each one a child of the previous), persisted,
    and yielded.

    Args:
        flow: ConversationFlow definition.
        initial_params: Optional initial parameters for the flow. The
            dict is copied before ``"conversation_id"`` is added, so the
            caller's object is never modified.

    Yields:
        The ConversationNode created for each entry in the flow history.
        The response text is available as ``node.message.content``; the
        node metadata carries ``"state"``, ``"flow_name"`` and
        ``"flow_execution"``.

    Raises:
        ValueError: If there is no conversation state, or if anything
            fails while executing the flow or recording its results
            (the original exception is chained as the cause).

    Example:
        >>> from dataknobs_llm.conversations.flow import (
        ...     ConversationFlow, FlowState,
        ...     keyword_condition
        ... )
        >>>
        >>> # Define flow
        >>> flow = ConversationFlow(
        ...     name="support",
        ...     initial_state="greeting",
        ...     states={
        ...         "greeting": FlowState(
        ...             prompt_name="support_greeting",
        ...             transitions={
        ...                 "help": "collect_issue",
        ...                 "browse": "end"
        ...             },
        ...             transition_conditions={
        ...                 "help": keyword_condition(["help", "issue"]),
        ...                 "browse": keyword_condition(["browse", "look"])
        ...             }
        ...         )
        ...     }
        ... )
        >>>
        >>> # Execute flow
        >>> async for node in manager.execute_flow(flow):
        ...     print(f"State: {node.metadata.get('state')}")
        ...     print(f"Response: {node.message.content}")
    """
    from dataknobs_llm.conversations.flow import ConversationFlowAdapter

    if not self.state:
        raise ValueError("No conversation state")

    adapter = ConversationFlowAdapter(
        flow=flow,
        prompt_builder=self.prompt_builder,
        llm=self.llm
    )

    # Copy so the caller's dict does not gain a "conversation_id" key.
    data = dict(initial_params) if initial_params else {}
    data["conversation_id"] = self.state.conversation_id

    try:
        # Execute flow (this will internally use FSM)
        await adapter.execute(data)

        # Convert flow history to conversation nodes
        for state_name, response in adapter.execution_state.history:
            current_tree_node = get_node_by_id(
                self.state.message_tree,
                self.state.current_node_id
            )
            if current_tree_node is None:
                raise ValueError(
                    f"Current node '{self.state.current_node_id}' not found"
                )

            # Build the node the same way add_message()/complete() do:
            # role and content live on the wrapped LLMMessage.
            # NOTE(review): assumes `response` is the response text, as the
            # original `content=response` implied — confirm against the adapter.
            node = ConversationNode(
                message=LLMMessage(role="assistant", content=response),
                node_id="",  # Will be calculated once attached
                metadata={
                    "state": state_name,
                    "flow_name": flow.name,
                    "flow_execution": True
                }
            )

            new_tree_node = Tree(node)
            current_tree_node.add_child(new_tree_node)
            node_id = calculate_node_id(new_tree_node)
            new_tree_node.data.node_id = node_id

            self.state.current_node_id = node_id
            self.state.updated_at = datetime.now()

            await self._save_state()

            yield node

    except Exception as e:
        import logging
        logging.error("Flow execution failed: %s", e)
        raise ValueError(f"Flow execution failed: {e!s}") from e

906 

async def get_history(self) -> List[LLMMessage]:
    """Return the messages on the path from the root to the current node.

    Returns:
        The messages of the current conversation path, or an empty list
        when no conversation state exists yet.

    Example:
        >>> messages = await manager.get_history()
        >>> for msg in messages:
        ...     print(f"{msg.role}: {msg.content}")
    """
    state = self.state
    return state.get_current_messages() if state else []

922 

async def get_branches(self, node_id: str | None = None) -> List[Dict[str, Any]]:
    """Describe the children (branches) of a node.

    Args:
        node_id: Node to get branches from (default: current node)

    Returns:
        One dict per child node, in child order, with keys:
            - node_id: ID of child node
            - branch_name: Optional branch name
            - role: Message role
            - preview: First 100 chars of content
            - timestamp: The child node's timestamp

        An empty list is returned when there is no conversation state,
        the node does not exist, or the node has no children.

    Example:
        >>> branches = await manager.get_branches()
        >>> for branch in branches:
        ...     print(f"{branch['branch_name']}: {branch['preview']}")
    """
    state = self.state
    if not state:
        return []

    # Fall back to the current position only when no ID was passed at all
    # ("" is a valid ID: the root).
    lookup_id = state.current_node_id if node_id is None else node_id

    parent = get_node_by_id(state.message_tree, lookup_id)
    if parent is None or not parent.children:
        return []

    return [
        {
            "node_id": entry.node_id,
            "branch_name": entry.branch_name,
            "role": entry.message.role,
            "preview": entry.message.content[:100],
            "timestamp": entry.timestamp,
        }
        for entry in (child.data for child in parent.children)
    ]

967 

async def add_metadata(self, key: str, value: Any) -> None:
    """Set a conversation-level metadata entry and persist the state.

    Args:
        key: Metadata key
        value: Metadata value

    Raises:
        ValueError: If there is no conversation state yet.

    Example:
        >>> await manager.add_metadata("user_id", "alice")
        >>> await manager.add_metadata("session", "abc123")
    """
    state = self.state
    if not state:
        raise ValueError("No conversation state")

    state.metadata[key] = value
    state.updated_at = datetime.now()  # record the modification time

    await self._save_state()

985 

async def _find_cached_rag(
    self,
    prompt_name: str,
    role: str,
    params: Dict[str, Any]
) -> Dict[str, Any] | None:
    """Look through the whole conversation tree for reusable RAG metadata.

    A node qualifies when it was produced from the same prompt name and
    message role, carries ``rag_metadata``, and the stored ``query_hash``
    of every placeholder we could compute a target hash for equals that
    target hash.

    Args:
        prompt_name: Name of the prompt to find cached RAG for
        role: Role of the prompt ("system" or "user")
        params: Parameters used to render each RAG query template

    Returns:
        The first matching node's RAG metadata dict (breadth-first order),
        or None when there is no state, the prompt has no RAG configs,
        no query could be rendered, or nothing matches.

    Example:
        >>> cached = await manager._find_cached_rag("code_question", "user", {"topic": "decorators"})
        >>> if cached:
        ...     print(f"Found cached RAG with {len(cached)} placeholders")
    """
    state = self.state
    if not state:
        return None

    # The prompt's RAG configs tell us which queries would be executed.
    rag_configs = self.prompt_builder.library.get_prompt_rag_configs(
        prompt_name=prompt_name,
        prompt_type="system" if role == "system" else "user"
    )
    if not rag_configs:
        return None

    from jinja2 import Template

    # placeholder -> hash of (adapter, rendered query) we want to find
    wanted_hashes = {}
    for config in rag_configs:
        placeholder = config.get("placeholder", "RAG_CONTENT")
        adapter_name = config.get("adapter_name", "")
        query_template = config.get("query", "")

        try:
            rendered_query = Template(query_template).render(params)
            wanted_hashes[placeholder] = self.prompt_builder._compute_rag_query_hash(
                adapter_name, rendered_query
            )
        except Exception:
            # A query that cannot be rendered yields no target hash for
            # its placeholder; that placeholder is simply not compared.
            continue

    if not wanted_hashes:
        return None

    # Breadth-first scan of the entire tree; the first match wins.
    from collections import deque
    pending = deque([state.message_tree])

    while pending:
        tree_node = pending.popleft()
        payload = tree_node.data

        if payload.prompt_name == prompt_name and payload.message.role == role:
            cached = payload.metadata.get("rag_metadata")
            if cached and all(
                name in cached and cached[name].get("query_hash") == wanted
                for name, wanted in wanted_hashes.items()
            ):
                return cached

        if tree_node.children:
            pending.extend(tree_node.children)

    return None

1080 

def get_rag_metadata(self, node_id: str | None = None) -> Dict[str, Any] | None:
    """Return the RAG metadata stored on a conversation node.

    The value is read from the node's ``metadata["rag_metadata"]`` entry,
    which records the RAG searches run while the node's prompt was
    rendered (queries, results, query hashes, etc.).

    Args:
        node_id: Node ID to read from (default: current node)

    Returns:
        The RAG metadata dictionary if the node has one, None otherwise
        (also None when no conversation state is loaded). Structure:

        ```python
        {
            "PLACEHOLDER_NAME": {
                "query": "resolved RAG query",
                "query_hash": "hash of adapter+query",
                "results": [...],  # Search results
                "adapter_name": "name of RAG adapter used"
            },
            ...  # One entry per RAG placeholder
        }
        ```

    Raises:
        ValueError: If node_id not found in conversation tree

    Example:
        ```python
        # Current node
        metadata = manager.get_rag_metadata()
        if metadata:
            for placeholder, rag_data in metadata.items():
                print(f"Placeholder: {placeholder}")
                print(f"  Query: {rag_data['query']}")
                print(f"  Adapter: {rag_data['adapter_name']}")
                print(f"  Results: {len(rag_data['results'])} items")
                print(f"  Hash: {rag_data['query_hash']}")

        # Specific node
        metadata = manager.get_rag_metadata(node_id="0.1")

        # Was RAG used for the current message?
        if manager.get_rag_metadata():
            print("This message used RAG-enhanced prompt")
        else:
            print("This message used direct content")
        ```

    Note:
        RAG metadata is only available if `cache_rag_results=True` was
        set during ConversationManager creation. It is useful for
        debugging RAG behavior, inspecting what was retrieved, and
        reusing RAG results across branches.

    See Also:
        add_message: Method that executes RAG and stores metadata
        reuse_rag_on_branch: Parameter enabling RAG cache reuse
    """
    state = self.state
    if not state:
        return None

    # Fall back to the conversation's current position.
    lookup_id = state.current_node_id if node_id is None else node_id

    found = get_node_by_id(state.message_tree, lookup_id)
    if found is None:
        raise ValueError(f"Node '{lookup_id}' not found in conversation tree")

    return found.data.metadata.get("rag_metadata")

1155 

async def _save_state(self) -> None:
    """Write the current conversation state to storage, if there is one."""
    state = self.state
    if not state:
        # Nothing loaded, nothing to persist.
        return
    await self.storage.save_conversation(state)

1160 

@property
def conversation_id(self) -> str | None:
    """ID of the loaded conversation, or None when no state is loaded."""
    if self.state:
        return self.state.conversation_id
    return None

1165 

@property
def current_node_id(self) -> str | None:
    """ID of the current node, or None when no state is loaded."""
    if self.state:
        return self.state.current_node_id
    return None

1170 

def get_metadata(self, key: str | None = None, default: Any = None) -> Any:
    """Get conversation metadata.

    Reads the conversation-level metadata held on the conversation state
    (e.g. client_id, user_id, session information).

    Args:
        key: Specific metadata key to retrieve. If None, returns all metadata.
        default: Value returned when ``key`` is given but not present
            (ignored when ``key`` is None).

    Returns:
        - ``key is None``: the state's metadata dict itself (not a copy),
          or a new empty dict when no state is loaded.
        - otherwise: the stored value, or ``default`` when the key is
          missing or no state is loaded.

    Example:
        >>> # Get all metadata
        >>> metadata = manager.get_metadata()
        >>> print(metadata)  # {'client_id': 'abc', 'user_id': '123'}
        >>>
        >>> # Get specific key
        >>> client_id = manager.get_metadata('client_id')
        >>> print(client_id)  # 'abc'
        >>>
        >>> # Get with default
        >>> tier = manager.get_metadata('user_tier', default='free')
    """
    # "All metadata" is selected strictly by ``key is None`` in both the
    # state and no-state cases, so a falsy key such as "" is still treated
    # as a key lookup and yields ``default`` when nothing is loaded.
    if key is None:
        return self.state.metadata if self.state else {}

    if not self.state:
        return default

    return self.state.metadata.get(key, default)

1204 

def set_metadata(self, key: str, value: Any) -> None:
    """Assign one conversation metadata entry in memory.

    Nothing is written to storage here; per the usage below, call
    ``save()`` afterwards to persist the change. When no conversation
    state is loaded the call is a no-op.

    Args:
        key: Metadata key to set
        value: Metadata value

    Example:
        >>> manager.set_metadata('client_id', 'client-abc')
        >>> manager.set_metadata('user_tier', 'premium')
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    state.metadata[key] = value

1222 

def update_metadata(self, updates: Dict[str, Any]) -> None:
    """Merge several metadata entries into the conversation metadata.

    The change is in-memory only. When no conversation state is loaded
    the call is a no-op.

    Args:
        updates: Dictionary of metadata key-value pairs to update

    Example:
        >>> manager.update_metadata({
        ...     'client_id': 'client-abc',
        ...     'user_id': 'user-456',
        ...     'session_id': 'sess-789'
        ... })
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    state.metadata.update(updates)

1239 

def remove_metadata(self, key: str) -> None:
    """Delete a metadata entry if it exists.

    Missing keys and a missing conversation state are both ignored.
    The change is in-memory only.

    Args:
        key: Metadata key to remove

    Example:
        >>> manager.remove_metadata('temporary_flag')
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    # pop with a default never raises for an absent key.
    state.metadata.pop(key, None)

1252 

def get_total_cost(self) -> float:
    """Get total accumulated cost for this conversation in USD.

    Sums the ``cost_usd`` metadata entry of every node in the message
    tree (all branches, root included). Nodes without data, without
    metadata, or without a ``cost_usd`` entry contribute nothing.

    Returns:
        Total cost in USD, or 0.0 if no state is loaded or no cost data
        is present.

    Example:
        >>> total = manager.get_total_cost()
        >>> print(f"Total cost: ${total:.4f}")
    """
    if not self.state:
        return 0.0

    total = 0.0

    # Walk the tree with an explicit stack rather than recursion: tree
    # depth grows with the length of a conversation path, and a recursive
    # walk would hit the interpreter's recursion limit on long ones.
    pending = [self.state.message_tree]
    while pending:
        node = pending.pop()

        if node.data and node.data.metadata:
            cost = node.data.metadata.get('cost_usd')
            if cost is not None:
                total += cost

        # Push children reversed so they are popped left-to-right,
        # keeping the same pre-order summation as a recursive walk.
        pending.extend(reversed(node.children or ()))

    return total

1284 

def get_cost_by_branch(self, node_id: str | None = None) -> float:
    """Get accumulated cost for a specific conversation branch.

    Follows the dot-separated child indexes in the node ID down from the
    root and sums the ``cost_usd`` metadata entry of each node reached.
    The root node itself is not included in the sum.

    Args:
        node_id: Node ID to calculate cost to. If None (or empty), uses
            the current node.

    Returns:
        Cost in USD for this branch, or 0.0 if no state is loaded, the
        target is the root, or no cost data is present. If the ID points
        outside the tree, the cost accumulated up to the deepest node
        that does exist on that path is returned.

    Raises:
        ValueError: If a component of the node ID is not an integer.

    Example:
        >>> # Get cost of current branch
        >>> current_cost = manager.get_cost_by_branch()
        >>>
        >>> # Get cost of specific branch
        >>> alt_cost = manager.get_cost_by_branch("0.1")
    """
    if not self.state:
        return 0.0

    target_node_id = node_id or self.state.current_node_id

    # An empty/None ID denotes the root, which carries no branch cost here.
    if not target_node_id:
        return 0.0

    indexes = [int(i) for i in target_node_id.split(".")]

    total = 0.0
    current = self.state.message_tree

    for idx in indexes:
        # Stop at the first index that does not name an existing child.
        # Continuing would apply the remaining indexes to the wrong node
        # and add costs from an unrelated path; negative indexes would
        # silently wrap around.
        if not 0 <= idx < len(current.children):
            break

        current = current.children[idx]
        if current.data and current.data.metadata:
            cost = current.data.metadata.get('cost_usd')
            if cost is not None:
                total += cost

    return total

1329 

def _calculate_and_track_cost(
    self,
    response: LLMResponse,
    metadata: Dict[str, Any]
) -> None:
    """Estimate the cost of a response and record it.

    Internal helper. When the response has usage data and
    ``CostCalculator.calculate_cost`` yields a value, the per-call cost
    and the running total (existing tree total plus this call) are
    written to both the response object and the given metadata dict.
    Any exception is logged as a warning and swallowed, so a cost
    calculation failure does not fail the conversation.

    Args:
        response: LLM response to calculate cost for
        metadata: Metadata dict to add cost information to
    """
    try:
        from dataknobs_llm.llm.utils import CostCalculator

        # No usage data means there is nothing to price.
        if not response.usage:
            return

        cost = CostCalculator.calculate_cost(response, response.model)
        if cost is None:
            return

        # Per-call cost on the response
        response.cost_usd = cost

        # Running total: everything already in the tree plus this call
        cumulative = self.get_total_cost() + cost
        response.cumulative_cost_usd = cumulative

        # Mirror both figures into the node metadata
        metadata['cost_usd'] = cost
        metadata['cumulative_cost_usd'] = cumulative
    except Exception as e:
        # Best-effort by design: report and carry on.
        import logging
        logging.getLogger(__name__).warning(f"Failed to calculate cost: {e}")