Coverage for src/dataknobs_llm/conversations/manager.py: 71%
299 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
1"""Conversation manager for multi-turn interactions with LLMs.
3This module provides ConversationManager, a comprehensive system for managing
4multi-turn LLM conversations with advanced features like:
6- **Tree-based History**: Conversations stored as trees, enabling branching
7- **Persistence**: Automatic state saving to any storage backend
8- **RAG Caching**: Reuse search results across conversation branches
9- **Middleware**: Pre/post-processing pipeline for all LLM calls
10- **Cost Tracking**: Automatic API cost calculation and accumulation
11- **Flow Execution**: FSM-based conversation flows with state management
12- **Resumability**: Save and resume conversations across sessions
14Architecture:
15 ConversationManager orchestrates three core components:
17 1. **AsyncLLMProvider**: Handles LLM API calls (OpenAI, Anthropic, Ollama)
18 2. **AsyncPromptBuilder**: Renders prompts with RAG integration
19 3. **ConversationStorage**: Persists conversation state (Memory, File, S3, Postgres)
21 Conversations are stored as trees where each node represents a message.
22 Branching occurs when multiple responses are generated from the same point,
23 enabling A/B testing, alternative explorations, and retry scenarios.
25Example:
26 ```python
27 from dataknobs_llm import create_llm_provider
28 from dataknobs_llm.prompts import AsyncPromptBuilder
29 from dataknobs_llm.conversations import (
30 ConversationManager,
31 DataknobsConversationStorage
32 )
33 from dataknobs_data import database_factory
35 # Setup components
36 llm = create_llm_provider("openai", api_key="sk-...")
37 builder = AsyncPromptBuilder.create(library_path="./prompts")
38 db = database_factory.create(backend="memory")
39 storage = DataknobsConversationStorage(db)
41 # Create conversation
42 manager = await ConversationManager.create(
43 llm=llm,
44 prompt_builder=builder,
45 storage=storage,
46 system_prompt_name="helpful_assistant"
47 )
49 # Add user message and get response
50 await manager.add_message(
51 role="user",
52 content="What is Python?"
53 )
54 response = await manager.complete()
55 print(response.content)
57 # Continue conversation
58 await manager.add_message(
59 role="user",
60 content="Show me a decorator example"
61 )
62 response = await manager.complete()
64 # Create alternative response branch
65 await manager.switch_to_node("0") # Back to first user message
66 alt_response = await manager.complete(branch_name="alternative")
68 # Resume later
69 conv_id = manager.conversation_id
70 manager2 = await ConversationManager.resume(
71 conversation_id=conv_id,
72 llm=llm,
73 prompt_builder=builder,
74 storage=storage
75 )
76 ```
78See Also:
79 - ConversationStorage: Storage interface and implementations
80 - ConversationMiddleware: Middleware system for request/response processing
81 - ConversationFlow: FSM-based conversation flows
82 - AsyncPromptBuilder: Prompt rendering with RAG integration
83"""
85import uuid
86from typing import List, Dict, Any, AsyncIterator
87from datetime import datetime
89from dataknobs_structures.tree import Tree
90from dataknobs_llm.llm import AsyncLLMProvider, LLMMessage, LLMResponse, LLMStreamResponse
91from dataknobs_llm.prompts import AsyncPromptBuilder
92from dataknobs_llm.conversations.flow.flow import ConversationFlow
93from dataknobs_llm.conversations.middleware import ConversationMiddleware
94from dataknobs_llm.conversations.storage import (
95 ConversationNode,
96 ConversationState,
97 ConversationStorage,
98 calculate_node_id,
99 get_node_by_id,
100)
103class ConversationManager:
104 """Manages multi-turn conversations with persistence and branching.
106 This class orchestrates conversations by:
107 - Tracking message history with tree-based branching
108 - Managing conversation state
109 - Persisting to storage backend
110 - Supporting multiple conversation branches
112 The conversation history is stored as a tree structure where:
113 - Root node contains the initial system prompt (if any)
114 - Each message is a tree node with a dot-delimited ID (e.g., "0.1.2")
115 - Branches occur when multiple children are added to the same node
116 - Current position tracks where you are in the conversation tree
118 Attributes:
119 llm: LLM provider for completions
120 prompt_builder: Prompt builder with library
121 storage: Storage backend for persistence
122 state: Current conversation state (tree, metadata, position)
123 middleware: List of middleware to execute on requests/responses
124 cache_rag_results: Whether to store RAG metadata in nodes
125 reuse_rag_on_branch: Whether to reuse cached RAG across branches
126 conversation_id: Unique conversation identifier
127 current_node_id: Current position in conversation tree
129 Example:
130 ```python
131 # Create conversation
132 manager = await ConversationManager.create(
133 llm=llm,
134 prompt_builder=builder,
135 storage=storage_backend,
136 system_prompt_name="helpful_assistant"
137 )
139 # Add user message
140 await manager.add_message(
141 prompt_name="user_query",
142 params={"question": "What is Python?"},
143 role="user"
144 )
146 # Get LLM response
147 result = await manager.complete()
149 # Continue conversation
150 await manager.add_message(
151 content="Tell me more about decorators",
152 role="user"
153 )
154 result = await manager.complete()
156 # Create alternative response branch
157 await manager.switch_to_node("0") # Back to first user message
158 result2 = await manager.complete(branch_name="alt-response")
160 # Resume after interruption
161 manager2 = await ConversationManager.resume(
162 conversation_id=manager.conversation_id,
163 llm=llm,
164 prompt_builder=builder,
165 storage=storage_backend
166 )
167 ```
169 Note:
170 Tree-based branching enables:
172 - **A/B Testing**: Generate multiple responses from the same context
173 - **Retry Logic**: Try again from a previous point after failures
174 - **Alternative Explorations**: Explore different conversation paths
175 - **Debugging**: Compare different middleware or RAG configurations
177 Node IDs use dot notation (e.g., "0.1.2" means 3rd child of 2nd child
178 of 1st child of root). The root node has ID "".
180 State is automatically persisted after every operation. Use
181 `resume()` to continue conversations across sessions or servers.
183 See Also:
184 create: Create a new conversation
185 resume: Resume an existing conversation
186 add_message: Add user/system message
187 complete: Get LLM completion (blocking)
188 stream_complete: Get LLM completion (streaming)
189 switch_to_node: Navigate to different branch
190 get_branches: List available branches
191 get_total_cost: Calculate cumulative cost
192 ConversationStorage: Storage backend implementations
193 ConversationMiddleware: Request/response processing
194 ConversationFlow: FSM-based conversation flows
195 """
197 def __init__(
198 self,
199 llm: AsyncLLMProvider,
200 prompt_builder: AsyncPromptBuilder,
201 storage: ConversationStorage,
202 state: ConversationState | None = None,
203 metadata: Dict[str, Any] | None = None,
204 middleware: List[ConversationMiddleware] | None = None,
205 cache_rag_results: bool = False,
206 reuse_rag_on_branch: bool = False,
207 ):
208 """Initialize conversation manager.
210 Note: Use ConversationManager.create() or ConversationManager.resume()
211 instead of calling __init__ directly.
213 Args:
214 llm: LLM provider for completions
215 prompt_builder: Prompt builder with library
216 storage: Storage backend for persistence
217 state: Optional existing conversation state
218 metadata: Optional metadata for new conversations
219 middleware: Optional list of middleware to execute
220 cache_rag_results: If True, store RAG metadata in node metadata
221 for debugging and transparency
222 reuse_rag_on_branch: If True, reuse cached RAG results when
223 possible (useful for testing/branching)
224 """
225 self.llm = llm
226 self.prompt_builder = prompt_builder
227 self.storage = storage
228 self.state = state
229 self._initial_metadata = metadata or {}
230 self.middleware = middleware or []
231 self.cache_rag_results = cache_rag_results
232 self.reuse_rag_on_branch = reuse_rag_on_branch
@classmethod
async def create(
    cls,
    llm: AsyncLLMProvider,
    prompt_builder: AsyncPromptBuilder,
    storage: ConversationStorage,
    system_prompt_name: str | None = None,
    system_params: Dict[str, Any] | None = None,
    metadata: Dict[str, Any] | None = None,
    middleware: List[ConversationMiddleware] | None = None,
    cache_rag_results: bool = False,
    reuse_rag_on_branch: bool = False,
) -> "ConversationManager":
    """Build a manager for a brand-new conversation.

    When ``system_prompt_name`` is given, the prompt is rendered and added
    as the first message (role ``"system"``) before the manager is returned.
    Otherwise the manager is returned with no messages yet.

    Args:
        llm: LLM provider
        prompt_builder: Prompt builder
        storage: Storage backend
        system_prompt_name: Optional system prompt to initialize with
        system_params: Optional params for the system prompt
        metadata: Optional conversation metadata
        middleware: Optional list of middleware to execute
        cache_rag_results: If True, store RAG metadata in node metadata
        reuse_rag_on_branch: If True, reuse cached RAG results when possible

    Returns:
        Initialized ConversationManager

    Example:
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     system_prompt_name="helpful_assistant",
        ...     cache_rag_results=True
        ... )
    """
    init_kwargs = dict(
        llm=llm,
        prompt_builder=prompt_builder,
        storage=storage,
        metadata=metadata,
        middleware=middleware,
        cache_rag_results=cache_rag_results,
        reuse_rag_on_branch=reuse_rag_on_branch,
    )
    instance = cls(**init_kwargs)

    # Nothing more to do without a system prompt.
    if not system_prompt_name:
        return instance

    await instance.add_message(
        prompt_name=system_prompt_name,
        params=system_params,
        role="system",
    )
    return instance
@classmethod
async def resume(
    cls,
    conversation_id: str,
    llm: AsyncLLMProvider,
    prompt_builder: AsyncPromptBuilder,
    storage: ConversationStorage,
    middleware: List[ConversationMiddleware] | None = None,
    cache_rag_results: bool = False,
    reuse_rag_on_branch: bool = False,
) -> "ConversationManager":
    """Rebuild a manager around a conversation loaded from storage.

    Args:
        conversation_id: Existing conversation ID
        llm: LLM provider
        prompt_builder: Prompt builder
        storage: Storage backend the conversation is loaded from
        middleware: Optional list of middleware to execute
        cache_rag_results: If True, store RAG metadata in node metadata
        reuse_rag_on_branch: If True, reuse cached RAG results when possible

    Returns:
        ConversationManager with restored state

    Raises:
        ValueError: If storage returns no (or a falsy) state for the ID

    Example:
        >>> manager = await ConversationManager.resume(
        ...     conversation_id="conv-123",
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     cache_rag_results=True
        ... )
    """
    restored = await storage.load_conversation(conversation_id)

    if restored:
        # Hand the loaded state straight to the constructor.
        return cls(
            llm=llm,
            prompt_builder=prompt_builder,
            storage=storage,
            state=restored,
            middleware=middleware,
            cache_rag_results=cache_rag_results,
            reuse_rag_on_branch=reuse_rag_on_branch,
        )

    raise ValueError(f"Conversation '{conversation_id}' not found")
async def add_message(
    self,
    role: str,
    content: str | None = None,
    prompt_name: str | None = None,
    params: Dict[str, Any] | None = None,
    include_rag: bool = True,
    metadata: Dict[str, Any] | None = None,
) -> ConversationNode:
    """Add a message as a child of the current conversation node.

    Either ``content`` or ``prompt_name`` must be provided. When
    ``prompt_name`` is given the prompt is rendered through the prompt
    builder and the rendered text replaces any ``content`` passed in.

    The very first message of a conversation becomes the root node (ID
    ``""``) and creates the conversation state; later messages are attached
    under the current node, which then becomes the new current position.
    State is persisted before returning.

    Args:
        role: Message role ("system", "user", or "assistant"). Only
            "system" and "user" can be combined with ``prompt_name``.
        content: Direct message content (if not using a prompt)
        prompt_name: Name of prompt template to render
        params: Parameters for prompt rendering
        include_rag: Whether to execute RAG searches for prompts
        metadata: Optional metadata for this message node. The dict is
            copied; the caller's object is never modified.

    Returns:
        The created ConversationNode

    Raises:
        ValueError: If neither content nor prompt_name is provided, if a
            prompt is requested for a role other than "system"/"user", or
            if the current node cannot be found in the tree.

    Example:
        ```python
        # Add message from prompt
        await manager.add_message(
            role="user",
            prompt_name="code_question",
            params={"code": code_snippet}
        )

        # Add direct message
        await manager.add_message(
            role="user",
            content="What is Python?"
        )
        ```

    Note:
        When ``reuse_rag_on_branch`` is enabled (and ``include_rag`` is
        True), the conversation tree is searched for matching cached RAG
        metadata, which is handed to the prompt builder as ``cached_rag``.
        When ``cache_rag_results`` is enabled, RAG metadata returned by the
        builder is stored on the node under the ``"rag_metadata"`` key.
    """
    if not content and not prompt_name:
        raise ValueError("Either content or prompt_name must be provided")

    rag_metadata_to_store = None
    if prompt_name:
        # Reject unsupported roles up front, before any tree search or
        # rendering work is done.
        if role not in ("system", "user"):
            raise ValueError(f"Cannot render prompt for role '{role}'")

        params = params or {}

        # Look for reusable RAG results only when both switches allow it.
        cached_rag = None
        if self.reuse_rag_on_branch and include_rag:
            cached_rag = await self._find_cached_rag(prompt_name, role, params)

        render = (
            self.prompt_builder.render_system_prompt
            if role == "system"
            else self.prompt_builder.render_user_prompt
        )
        result = await render(
            prompt_name,
            params=params,
            include_rag=include_rag,
            return_rag_metadata=self.cache_rag_results,
            cached_rag=cached_rag,
        )
        content = result.content

        if self.cache_rag_results and result.rag_metadata:
            rag_metadata_to_store = result.rag_metadata

    message = LLMMessage(role=role, content=content)

    # Copy so that adding "rag_metadata" never leaks into the caller's dict.
    node_metadata = dict(metadata) if metadata else {}
    if rag_metadata_to_store:
        node_metadata["rag_metadata"] = rag_metadata_to_store

    if self.state is None:
        # First message: it becomes the root of a new conversation tree.
        root_node = ConversationNode(
            message=message,
            node_id="",
            prompt_name=prompt_name,
            metadata=node_metadata,
        )
        self.state = ConversationState(
            conversation_id=str(uuid.uuid4()),
            message_tree=Tree(root_node),
            current_node_id="",
            metadata=self._initial_metadata,
        )
    else:
        current_tree_node = self.state.get_current_node()
        if current_tree_node is None:
            raise ValueError(f"Current node '{self.state.current_node_id}' not found")

        new_tree_node = Tree(
            ConversationNode(
                message=message,
                node_id="",  # Filled in once the node has a position in the tree
                prompt_name=prompt_name,
                metadata=node_metadata,
            )
        )
        current_tree_node.add_child(new_tree_node)

        # The ID depends on the node's position, so compute it after attaching.
        node_id = calculate_node_id(new_tree_node)
        new_tree_node.data.node_id = node_id
        self.state.current_node_id = node_id

    self.state.updated_at = datetime.now()
    await self._save_state()

    return self.state.get_current_node().data
async def complete(
    self,
    branch_name: str | None = None,
    metadata: Dict[str, Any] | None = None,
    **llm_kwargs: Any,
) -> LLMResponse:
    """Get an LLM completion and add it as a child of the current node.

    Steps performed:
        1. Collect the messages from root to the current node
        2. Run every middleware's ``process_request`` in list order
        3. Call ``llm.complete`` with the resulting messages
        4. Run every middleware's ``process_response`` in reverse order
        5. Attach the assistant reply under the current node
        6. Move the current position to the new node
        7. Persist the state

    Args:
        branch_name: Optional human-readable label for this branch
        metadata: Optional metadata for the assistant message node. The
            dict is copied; the caller's object is never modified.
        **llm_kwargs: Additional arguments forwarded to ``llm.complete()``

    Returns:
        The LLM response (after post-LLM middleware has processed it)

    Raises:
        ValueError: If the conversation has no messages yet, or the
            current node cannot be found in the tree

    Example:
        ```python
        result = await manager.complete()
        print(result.content)

        # Create labeled branch
        result = await manager.complete(branch_name="alternative-answer")

        # With LLM parameters
        result = await manager.complete(temperature=0.9, max_tokens=500)
        ```

    Note:
        Middleware follows an "onion" order: request hooks run
        middleware[0] .. middleware[N], response hooks run
        middleware[N] .. middleware[0].

        The node metadata always receives ``usage``, ``model`` and
        ``finish_reason`` from the response; cost tracking is delegated to
        ``_calculate_and_track_cost``.

    See Also:
        stream_complete: Streaming version for real-time output
        add_message: Add user/system message before calling complete
        switch_to_node: Navigate to different branch before completing
    """
    if not self.state:
        raise ValueError("Cannot complete: no messages in conversation")

    messages = self.state.get_current_messages()

    # Pre-LLM middleware, forward order.
    for mw in self.middleware:
        messages = await mw.process_request(messages, self.state)

    response = await self.llm.complete(messages, **llm_kwargs)

    # Post-LLM middleware, reverse order (onion model).
    for mw in reversed(self.middleware):
        response = await mw.process_response(response, self.state)

    current_tree_node = self.state.get_current_node()
    if current_tree_node is None:
        raise ValueError(f"Current node '{self.state.current_node_id}' not found")

    assistant_message = LLMMessage(
        role="assistant",
        content=response.content,
    )

    # Copy first: the update below must not write into the caller's dict,
    # otherwise a dict reused across calls would carry stale usage data.
    assistant_metadata = dict(metadata) if metadata else {}
    assistant_metadata.update({
        "usage": response.usage,
        "model": response.model,
        "finish_reason": response.finish_reason,
    })

    self._calculate_and_track_cost(response, assistant_metadata)

    new_tree_node = Tree(
        ConversationNode(
            message=assistant_message,
            node_id="",  # Filled in once the node has a position in the tree
            branch_name=branch_name,
            metadata=assistant_metadata,
        )
    )
    current_tree_node.add_child(new_tree_node)

    node_id = calculate_node_id(new_tree_node)
    new_tree_node.data.node_id = node_id

    self.state.current_node_id = node_id
    self.state.updated_at = datetime.now()

    await self._save_state()

    return response
async def stream_complete(
    self,
    branch_name: str | None = None,
    metadata: Dict[str, Any] | None = None,
    **llm_kwargs: Any,
) -> AsyncIterator[LLMStreamResponse]:
    r"""Stream an LLM completion and add it as a child of the current node.

    Works like ``complete()`` but yields each chunk as it arrives. Once the
    provider's stream is exhausted, the concatenated text is stored as a
    new assistant node, the current position moves to it, and the state is
    persisted.

    Args:
        branch_name: Optional human-readable label for this branch
        metadata: Optional metadata for the assistant message node. The
            dict is copied; the caller's object is never modified.
        **llm_kwargs: Additional arguments forwarded to
            ``llm.stream_complete()``

    Yields:
        Streaming response chunks exactly as produced by the provider

    Raises:
        ValueError: If the conversation has no messages yet, or the
            current node cannot be found in the tree

    Example:
        ```python
        # Real-time display
        async for chunk in manager.stream_complete():
            print(chunk.delta, end="", flush=True)
        print()  # New line after streaming

        # Accumulate full response
        full_text = ""
        async for chunk in manager.stream_complete():
            full_text += chunk.delta
            if chunk.is_final:
                print(f"\nFinished. Total: {len(full_text)} chars")
        ```

    Note:
        Pre-LLM middleware runs before streaming starts; post-LLM
        middleware runs on the assembled response after the stream ends.

        The tree update happens only after the last chunk has been
        consumed. A caller that stops iterating early leaves the
        conversation tree unchanged.

        ``finish_reason`` and ``usage`` for the stored node are taken from
        the last chunk received ("stop" / None if no chunk arrived).

    See Also:
        complete: Non-streaming version for simple use cases
        add_message: Add message before streaming
    """
    if not self.state:
        raise ValueError("Cannot complete: no messages in conversation")

    messages = self.state.get_current_messages()

    # Pre-LLM middleware, forward order.
    for mw in self.middleware:
        messages = await mw.process_request(messages, self.state)

    # Collect pieces in a list and join once instead of repeated "+=".
    parts: List[str] = []
    final_chunk = None
    async for chunk in self.llm.stream_complete(messages, **llm_kwargs):
        if chunk.delta:  # tolerate chunks that carry no text
            parts.append(chunk.delta)
        final_chunk = chunk
        yield chunk

    # Assemble a regular response object for middleware and bookkeeping.
    response = LLMResponse(
        content="".join(parts),
        model=self.llm.config.model,
        finish_reason=final_chunk.finish_reason if final_chunk else "stop",
        usage=final_chunk.usage if final_chunk else None,
    )

    # Post-LLM middleware, reverse order (onion model).
    for mw in reversed(self.middleware):
        response = await mw.process_response(response, self.state)

    current_tree_node = self.state.get_current_node()
    if current_tree_node is None:
        raise ValueError(f"Current node '{self.state.current_node_id}' not found")

    assistant_message = LLMMessage(role="assistant", content=response.content)

    # Copy first so the update never writes into the caller's dict.
    assistant_metadata = dict(metadata) if metadata else {}
    assistant_metadata.update({
        "usage": response.usage,
        "model": response.model,
        "finish_reason": response.finish_reason,
    })

    self._calculate_and_track_cost(response, assistant_metadata)

    new_tree_node = Tree(
        ConversationNode(
            message=assistant_message,
            node_id="",  # Filled in once the node has a position in the tree
            branch_name=branch_name,
            metadata=assistant_metadata,
        )
    )
    current_tree_node.add_child(new_tree_node)

    node_id = calculate_node_id(new_tree_node)
    new_tree_node.data.node_id = node_id

    self.state.current_node_id = node_id
    self.state.updated_at = datetime.now()

    await self._save_state()
async def switch_to_node(self, node_id: str) -> None:
    """Move the current position to another node of the conversation tree.

    Useful for backtracking or for starting a new branch from an earlier
    point. The change is persisted before returning.

    Args:
        node_id: Target node ID (dot-delimited, e.g., "0.1"; "" is the root)

    Raises:
        ValueError: If there is no conversation state, or ``node_id`` does
            not resolve to a node in the tree

    Example:
        >>> # Go back to first user message
        >>> await manager.switch_to_node("0")
        >>>
        >>> # Create alternative response
        >>> result = await manager.complete(branch_name="alternative")
        >>>
        >>> # Go back to root
        >>> await manager.switch_to_node("")
    """
    conversation = self.state
    if not conversation:
        raise ValueError("No conversation state")

    # Refuse to point at a node that does not exist.
    if get_node_by_id(conversation.message_tree, node_id) is None:
        raise ValueError(f"Node '{node_id}' not found in conversation tree")

    conversation.current_node_id = node_id
    conversation.updated_at = datetime.now()

    await self._save_state()
async def execute_flow(
    self,
    flow: ConversationFlow,
    initial_params: Dict[str, Any] | None = None
) -> AsyncIterator[ConversationNode]:
    """Execute a conversation flow and record its responses in the tree.

    The flow is run to completion through a ``ConversationFlowAdapter``.
    Afterwards each ``(state_name, response)`` pair in the adapter's
    execution history is appended to the conversation as an assistant
    message (each one a child of the previous), persisted, and yielded.

    Args:
        flow: ConversationFlow definition
        initial_params: Optional initial parameters for the flow. The dict
            is copied before ``"conversation_id"`` is added, so the
            caller's object is not modified.

    Yields:
        ConversationNode for each recorded state of the flow. The node's
        metadata carries ``state``, ``flow_name`` and ``flow_execution``.

    Raises:
        ValueError: If there is no conversation state, or if anything
            fails while executing the flow or recording its results (the
            original exception is chained as the cause)

    Example:
        >>> from dataknobs_llm.conversations.flow import (
        ...     ConversationFlow, FlowState,
        ...     keyword_condition
        ... )
        >>>
        >>> flow = ConversationFlow(
        ...     name="support",
        ...     initial_state="greeting",
        ...     states={
        ...         "greeting": FlowState(
        ...             prompt_name="support_greeting",
        ...             transitions={
        ...                 "help": "collect_issue",
        ...                 "browse": "end"
        ...             },
        ...             transition_conditions={
        ...                 "help": keyword_condition(["help", "issue"]),
        ...                 "browse": keyword_condition(["browse", "look"])
        ...             }
        ...         )
        ...     }
        ... )
        >>>
        >>> async for node in manager.execute_flow(flow):
        ...     print(f"State: {node.metadata.get('state')}")
        ...     print(f"Response: {node.message.content}")
    """
    if not self.state:
        raise ValueError("No conversation state")

    import logging

    # Imported here rather than at module level, as in the original code.
    from dataknobs_llm.conversations.flow import ConversationFlowAdapter

    adapter = ConversationFlowAdapter(
        flow=flow,
        prompt_builder=self.prompt_builder,
        llm=self.llm
    )

    # Work on a copy so the caller's dict does not gain a "conversation_id".
    data = dict(initial_params) if initial_params else {}
    data["conversation_id"] = self.state.conversation_id

    try:
        await adapter.execute(data)

        for state_name, response in adapter.execution_state.history:
            # Build the node the same way add_message()/complete() do:
            # content lives in an LLMMessage, and the timestamp is left to
            # the node's default. NOTE(review): assumes `response` is the
            # response text — confirm against ConversationFlowAdapter.
            node = ConversationNode(
                message=LLMMessage(role="assistant", content=response),
                node_id="",  # Filled in once the node has a position in the tree
                metadata={
                    "state": state_name,
                    "flow_name": flow.name,
                    "flow_execution": True
                }
            )

            current_tree_node = get_node_by_id(
                self.state.message_tree,
                self.state.current_node_id
            )
            if current_tree_node is None:
                raise ValueError(
                    f"Current node '{self.state.current_node_id}' not found"
                )

            new_tree_node = Tree(node)
            current_tree_node.add_child(new_tree_node)
            node_id = calculate_node_id(new_tree_node)
            new_tree_node.data.node_id = node_id

            self.state.current_node_id = node_id
            self.state.updated_at = datetime.now()

            await self._save_state()

            yield node

    except Exception as e:
        logging.error("Flow execution failed: %s", e)
        raise ValueError(f"Flow execution failed: {e!s}") from e
async def get_history(self) -> List[LLMMessage]:
    """Return the messages on the path from the root to the current node.

    Returns:
        List of messages in the current conversation path, or an empty
        list when no conversation state exists yet

    Example:
        >>> messages = await manager.get_history()
        >>> for msg in messages:
        ...     print(f"{msg.role}: {msg.content}")
    """
    return self.state.get_current_messages() if self.state else []
async def get_branches(self, node_id: str | None = None) -> List[Dict[str, Any]]:
    """Describe the direct children (branches) of a node.

    Args:
        node_id: Node to inspect (default: the current node)

    Returns:
        One dict per child, in child order, with keys:
            - node_id: ID of the child node
            - branch_name: Optional branch name
            - role: Message role
            - preview: First 100 chars of content
            - timestamp: When created

        An empty list is returned when there is no conversation state,
        the node does not exist, or it has no children.

    Example:
        >>> branches = await manager.get_branches()
        >>> for branch in branches:
        ...     print(f"{branch['branch_name']}: {branch['preview']}")
    """
    if not self.state:
        return []

    lookup_id = self.state.current_node_id if node_id is None else node_id
    parent = get_node_by_id(self.state.message_tree, lookup_id)

    if parent is None or not parent.children:
        return []

    return [
        {
            "node_id": payload.node_id,
            "branch_name": payload.branch_name,
            "role": payload.message.role,
            "preview": payload.message.content[:100],
            "timestamp": payload.timestamp,
        }
        for payload in (child.data for child in parent.children)
    ]
async def add_metadata(self, key: str, value: Any) -> None:
    """Set a conversation-level metadata entry and persist the change.

    An existing entry under the same key is overwritten.

    Args:
        key: Metadata key
        value: Metadata value

    Raises:
        ValueError: If there is no conversation state yet

    Example:
        >>> await manager.add_metadata("user_id", "alice")
        >>> await manager.add_metadata("session", "abc123")
    """
    conversation = self.state
    if not conversation:
        raise ValueError("No conversation state")

    conversation.metadata[key] = value
    conversation.updated_at = datetime.now()

    await self._save_state()
986 async def _find_cached_rag(
987 self,
988 prompt_name: str,
989 role: str,
990 params: Dict[str, Any]
991 ) -> Dict[str, Any] | None:
992 """Search conversation history for cached RAG metadata.
994 This method searches the entire conversation tree for cached RAG metadata
995 that matches both the prompt name/role AND the resolved RAG query parameters.
996 Query matching is done via query hashes.
998 Args:
999 prompt_name: Name of the prompt to find cached RAG for
1000 role: Role of the prompt ("system" or "user")
1001 params: Parameters for the prompt (used to match RAG queries)
1003 Returns:
1004 Cached RAG metadata dictionary if found, None otherwise
1006 Example:
1007 >>> cached = await manager._find_cached_rag("code_question", "user", {"topic": "decorators"})
1008 >>> if cached:
1009 ... print(f"Found cached RAG with {len(cached)} placeholders")
1010 """
1011 if not self.state:
1012 return None
1014 # Get RAG configs for this prompt to determine what queries we're looking for
1015 rag_configs = self.prompt_builder.library.get_prompt_rag_configs(
1016 prompt_name=prompt_name,
1017 prompt_type="system" if role == "system" else "user"
1018 )
1020 if not rag_configs:
1021 return None
1023 # Compute the query hashes we're looking for
1024 from jinja2 import Template
1025 target_hashes_by_placeholder = {}
1026 for rag_config in rag_configs:
1027 placeholder = rag_config.get("placeholder", "RAG_CONTENT")
1028 adapter_name = rag_config.get("adapter_name", "")
1029 query_template = rag_config.get("query", "")
1031 # Render the query template with params
1032 try:
1033 template = Template(query_template)
1034 resolved_query = template.render(params)
1036 # Compute hash
1037 query_hash = self.prompt_builder._compute_rag_query_hash(adapter_name, resolved_query)
1038 target_hashes_by_placeholder[placeholder] = query_hash
1039 except Exception:
1040 # If query rendering fails, we can't match cache
1041 continue
1043 if not target_hashes_by_placeholder:
1044 return None
1046 # Search entire tree for matching cached RAG (BFS to find any match)
1047 from collections import deque
1048 queue = deque([self.state.message_tree])
1050 while queue:
1051 tree_node = queue.popleft()
1052 node_data = tree_node.data
1054 # Check if this node has the same prompt name and role
1055 if (node_data.prompt_name == prompt_name and
1056 node_data.message.role == role):
1058 # Check if RAG metadata exists
1059 rag_metadata = node_data.metadata.get("rag_metadata")
1060 if rag_metadata:
1061 # Check if query hashes match for all placeholders
1062 all_match = True
1063 for placeholder, target_hash in target_hashes_by_placeholder.items():
1064 if placeholder not in rag_metadata:
1065 all_match = False
1066 break
1067 cached_hash = rag_metadata[placeholder].get("query_hash")
1068 if cached_hash != target_hash:
1069 all_match = False
1070 break
1072 if all_match:
1073 return rag_metadata
1075 # Add children to queue (if any)
1076 if tree_node.children:
1077 queue.extend(tree_node.children)
1079 return None
1081 def get_rag_metadata(self, node_id: str | None = None) -> Dict[str, Any] | None:
1082 """Get RAG metadata from a conversation node.
1084 This method retrieves the cached RAG metadata from a specific node,
1085 which includes information about RAG searches executed during prompt
1086 rendering (queries, results, query hashes, etc.).
1088 Args:
1089 node_id: Node ID to retrieve metadata from (default: current node)
1091 Returns:
1092 RAG metadata dictionary if present, None otherwise. Structure:
1094 ```python
1095 {
1096 "PLACEHOLDER_NAME": {
1097 "query": "resolved RAG query",
1098 "query_hash": "hash of adapter+query",
1099 "results": [...], # Search results
1100 "adapter_name": "name of RAG adapter used"
1101 },
1102 ... # One entry per RAG placeholder
1103 }
1104 ```
1106 Raises:
1107 ValueError: If node_id not found in conversation tree
1109 Example:
1110 ```python
1111 # Get RAG metadata from current node
1112 metadata = manager.get_rag_metadata()
1113 if metadata:
1114 for placeholder, rag_data in metadata.items():
1115 print(f"Placeholder: {placeholder}")
1116 print(f" Query: {rag_data['query']}")
1117 print(f" Adapter: {rag_data['adapter_name']}")
1118 print(f" Results: {len(rag_data['results'])} items")
1119 print(f" Hash: {rag_data['query_hash']}")
1121 # Get RAG metadata from specific node
1122 metadata = manager.get_rag_metadata(node_id="0.1")
1124 # Check if RAG was used for a message
1125 if manager.get_rag_metadata():
1126 print("This message used RAG-enhanced prompt")
1127 else:
1128 print("This message used direct content")
1129 ```
1131 Note:
1132 RAG metadata is only available if `cache_rag_results=True` was
1133 set during ConversationManager creation. This metadata is useful
1134 for debugging RAG behavior, understanding what information was
1135 retrieved, and implementing RAG result caching across branches.
1137 See Also:
1138 add_message: Method that executes RAG and stores metadata
1139 reuse_rag_on_branch: Parameter enabling RAG cache reuse
1140 """
1141 if not self.state:
1142 return None
1144 # Default to current node
1145 if node_id is None:
1146 node_id = self.state.current_node_id
1148 # Get node
1149 tree_node = get_node_by_id(self.state.message_tree, node_id)
1150 if tree_node is None:
1151 raise ValueError(f"Node '{node_id}' not found in conversation tree")
1153 # Return RAG metadata if present
1154 return tree_node.data.metadata.get("rag_metadata")
async def _save_state(self) -> None:
    """Write the current conversation state to the storage backend.

    Does nothing when no conversation state is loaded.
    """
    state = self.state
    if not state:
        return
    await self.storage.save_conversation(state)
1161 @property
1162 def conversation_id(self) -> str | None:
1163 """Get conversation ID."""
1164 return self.state.conversation_id if self.state else None
1166 @property
1167 def current_node_id(self) -> str | None:
1168 """Get current node ID."""
1169 return self.state.current_node_id if self.state else None
1171 def get_metadata(self, key: str | None = None, default: Any = None) -> Any:
1172 """Get conversation metadata.
1174 This provides access to the conversation-level metadata stored in
1175 the ConversationState. Metadata is useful for storing client_id,
1176 user_id, session information, and other contextual data.
1178 Args:
1179 key: Specific metadata key to retrieve. If None, returns all metadata.
1180 default: Default value if key not found (only used when key is specified)
1182 Returns:
1183 Metadata value, all metadata dict, or default value
1185 Example:
1186 >>> # Get all metadata
1187 >>> metadata = manager.get_metadata()
1188 >>> print(metadata) # {'client_id': 'abc', 'user_id': '123'}
1189 >>>
1190 >>> # Get specific key
1191 >>> client_id = manager.get_metadata('client_id')
1192 >>> print(client_id) # 'abc'
1193 >>>
1194 >>> # Get with default
1195 >>> tier = manager.get_metadata('user_tier', default='free')
1196 """
1197 if not self.state:
1198 return default if key else {}
1200 if key is None:
1201 return self.state.metadata
1202 else:
1203 return self.state.metadata.get(key, default)
def set_metadata(self, key: str, value: Any) -> None:
    """Store a single value in the conversation metadata.

    Only the in-memory state is modified here; call save() afterwards to
    persist the change. When no conversation state is loaded the call is
    a silent no-op.

    Args:
        key: Metadata key to set
        value: Value to store under ``key``

    Example:
        >>> manager.set_metadata('client_id', 'client-abc')
        >>> manager.set_metadata('user_tier', 'premium')
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    state.metadata[key] = value
def update_metadata(self, updates: Dict[str, Any]) -> None:
    """Merge several key-value pairs into the conversation metadata.

    Existing keys are overwritten by the values in ``updates``. Only the
    in-memory state is modified; when no conversation state is loaded the
    call is a silent no-op.

    Args:
        updates: Dictionary of metadata key-value pairs to apply

    Example:
        >>> manager.update_metadata({
        ...     'client_id': 'client-abc',
        ...     'user_id': 'user-456',
        ...     'session_id': 'sess-789'
        ... })
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    state.metadata.update(updates)
def remove_metadata(self, key: str) -> None:
    """Delete a key from the conversation metadata.

    Missing keys are ignored, as is the case where no conversation state
    is loaded. Only the in-memory state is modified.

    Args:
        key: Metadata key to remove

    Example:
        >>> manager.remove_metadata('temporary_flag')
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return

    metadata = state.metadata
    if key in metadata:
        del metadata[key]
def get_total_cost(self) -> float:
    """Get total accumulated cost for this conversation in USD.

    Visits every node of the conversation tree (all branches, not just
    the current one) and sums the ``cost_usd`` values found in node
    metadata. Nodes without data, without metadata, or without a
    ``cost_usd`` entry contribute nothing.

    The tree is walked with an explicit stack rather than recursion, so
    very long conversations cannot hit the interpreter's recursion limit.

    Returns:
        Total cost in USD, or 0.0 if no conversation state is loaded or
        no cost data is available.

    Example:
        >>> total = manager.get_total_cost()
        >>> print(f"Total cost: ${total:.4f}")
    """
    if not self.state:
        return 0.0

    total = 0.0
    pending = [self.state.message_tree]

    while pending:
        node = pending.pop()

        if node.data and node.data.metadata:
            cost = node.data.metadata.get('cost_usd')
            if cost is not None:
                total += cost

        # Push children right-to-left so they are popped left-to-right:
        # this keeps a pre-order walk, i.e. the same float summation order
        # a recursive depth-first traversal would produce.
        children = node.children
        if children:
            pending.extend(reversed(children))

    return total
1285 def get_cost_by_branch(self, node_id: str | None = None) -> float:
1286 """Get accumulated cost for a specific conversation branch.
1288 Calculates the cost from root to a specific node (defaults to current).
1290 Args:
1291 node_id: Node ID to calculate cost to. If None, uses current node.
1293 Returns:
1294 Cost in USD for this branch, or 0.0 if no cost data
1296 Example:
1297 >>> # Get cost of current branch
1298 >>> current_cost = manager.get_cost_by_branch()
1299 >>>
1300 >>> # Get cost of specific branch
1301 >>> alt_cost = manager.get_cost_by_branch("0.1")
1302 """
1303 if not self.state:
1304 return 0.0
1306 target_node_id = node_id or self.state.current_node_id
1308 # Get messages in this branch
1310 # Walk from root to target node
1311 if not target_node_id or target_node_id == "":
1312 # Just root node
1313 return 0.0
1315 indexes = [int(i) for i in target_node_id.split(".")]
1317 total = 0.0
1318 current = self.state.message_tree
1320 for idx in indexes:
1321 if idx < len(current.children):
1322 current = current.children[idx]
1323 if current.data and current.data.metadata:
1324 cost = current.data.metadata.get('cost_usd')
1325 if cost is not None:
1326 total += cost
1328 return total
def _calculate_and_track_cost(
    self,
    response: LLMResponse,
    metadata: Dict[str, Any]
) -> None:
    """Estimate the cost of a response and record it.

    Internal helper. When the response carries usage data, the cost is
    obtained from ``CostCalculator.calculate_cost`` and written to both
    the response (``cost_usd``, ``cumulative_cost_usd``) and the given
    metadata dict (same two keys). The cumulative figure is the
    conversation total from ``get_total_cost()`` plus this response's cost.

    Any exception raised along the way is logged as a warning and
    swallowed, so a cost-calculation failure never breaks the conversation.

    Args:
        response: LLM response to calculate cost for (mutated in place)
        metadata: Metadata dict to add cost information to (mutated in place)
    """
    try:
        from dataknobs_llm.llm.utils import CostCalculator

        # Without usage data there is nothing to price
        if not response.usage:
            return

        cost = CostCalculator.calculate_cost(response, response.model)
        if cost is None:
            return

        # Running total = everything already in the tree + this response
        cumulative = self.get_total_cost() + cost

        response.cost_usd = cost
        response.cumulative_cost_usd = cumulative

        metadata['cost_usd'] = cost
        metadata['cumulative_cost_usd'] = cumulative
    except Exception as e:
        # Best-effort only: report the problem and carry on
        import logging
        logging.getLogger(__name__).warning(f"Failed to calculate cost: {e}")