Coverage for src / dataknobs_llm / conversations / manager.py: 39%
314 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
1"""Conversation manager for multi-turn interactions with LLMs.
3This module provides ConversationManager, a comprehensive system for managing
4multi-turn LLM conversations with advanced features like:
6- **Tree-based History**: Conversations stored as trees, enabling branching
7- **Persistence**: Automatic state saving to any storage backend
8- **RAG Caching**: Reuse search results across conversation branches
9- **Middleware**: Pre/post-processing pipeline for all LLM calls
10- **Cost Tracking**: Automatic API cost calculation and accumulation
11- **Flow Execution**: FSM-based conversation flows with state management
12- **Resumability**: Save and resume conversations across sessions
14Architecture:
15 ConversationManager orchestrates three core components:
17 1. **AsyncLLMProvider**: Handles LLM API calls (OpenAI, Anthropic, Ollama)
18 2. **AsyncPromptBuilder**: Renders prompts with RAG integration
19 3. **ConversationStorage**: Persists conversation state (Memory, File, S3, Postgres)
21 Conversations are stored as trees where each node represents a message.
22 Branching occurs when multiple responses are generated from the same point,
23 enabling A/B testing, alternative explorations, and retry scenarios.
25Example:
26 ```python
27 from dataknobs_llm import create_llm_provider
28 from dataknobs_llm.prompts import AsyncPromptBuilder
29 from dataknobs_llm.conversations import (
30 ConversationManager,
31 DataknobsConversationStorage
32 )
33 from dataknobs_data import database_factory
35 # Setup components
36 llm = create_llm_provider("openai", api_key="sk-...")
37 builder = AsyncPromptBuilder.create(library_path="./prompts")
38 db = database_factory.create(backend="memory")
39 storage = DataknobsConversationStorage(db)
41 # Create conversation
42 manager = await ConversationManager.create(
43 llm=llm,
44 prompt_builder=builder,
45 storage=storage,
46 system_prompt_name="helpful_assistant"
47 )
49 # Add user message and get response
50 await manager.add_message(
51 role="user",
52 content="What is Python?"
53 )
54 response = await manager.complete()
55 print(response.content)
57 # Continue conversation
58 await manager.add_message(
59 role="user",
60 content="Show me a decorator example"
61 )
62 response = await manager.complete()
64 # Create alternative response branch
65 await manager.switch_to_node("0") # Back to first user message
66 alt_response = await manager.complete(branch_name="alternative")
68 # Resume later
69 conv_id = manager.conversation_id
70 manager2 = await ConversationManager.resume(
71 conversation_id=conv_id,
72 llm=llm,
73 prompt_builder=builder,
74 storage=storage
75 )
76 ```
78See Also:
79 - ConversationStorage: Storage interface and implementations
80 - ConversationMiddleware: Middleware system for request/response processing
81 - ConversationFlow: FSM-based conversation flows
82 - AsyncPromptBuilder: Prompt rendering with RAG integration
83"""
85import uuid
86from typing import List, Dict, Any, AsyncIterator
87from datetime import datetime
89from dataknobs_structures.tree import Tree
90from dataknobs_llm.llm import AsyncLLMProvider, LLMMessage, LLMResponse, LLMStreamResponse
91from dataknobs_llm.prompts import AsyncPromptBuilder
92from dataknobs_llm.conversations.flow.flow import ConversationFlow
93from dataknobs_llm.conversations.middleware import ConversationMiddleware
94from dataknobs_llm.conversations.storage import (
95 ConversationNode,
96 ConversationState,
97 ConversationStorage,
98 calculate_node_id,
99 get_node_by_id,
100)
103class ConversationManager:
104 """Manages multi-turn conversations with persistence and branching.
106 This class orchestrates conversations by:
107 - Tracking message history with tree-based branching
108 - Managing conversation state
109 - Persisting to storage backend
110 - Supporting multiple conversation branches
112 The conversation history is stored as a tree structure where:
113 - Root node contains the initial system prompt (if any)
114 - Each message is a tree node with a dot-delimited ID (e.g., "0.1.2")
115 - Branches occur when multiple children are added to the same node
116 - Current position tracks where you are in the conversation tree
118 Attributes:
119 llm: LLM provider for completions
120 prompt_builder: Prompt builder with library
121 storage: Storage backend for persistence
122 state: Current conversation state (tree, metadata, position)
123 middleware: List of middleware to execute on requests/responses
124 cache_rag_results: Whether to store RAG metadata in nodes
125 reuse_rag_on_branch: Whether to reuse cached RAG across branches
126 conversation_id: Unique conversation identifier
127 current_node_id: Current position in conversation tree
129 Example:
130 ```python
131 # Create conversation
132 manager = await ConversationManager.create(
133 llm=llm,
134 prompt_builder=builder,
135 storage=storage_backend,
136 system_prompt_name="helpful_assistant"
137 )
139 # Add user message
140 await manager.add_message(
141 prompt_name="user_query",
142 params={"question": "What is Python?"},
143 role="user"
144 )
146 # Get LLM response
147 result = await manager.complete()
149 # Continue conversation
150 await manager.add_message(
151 content="Tell me more about decorators",
152 role="user"
153 )
154 result = await manager.complete()
156 # Create alternative response branch
157 await manager.switch_to_node("0") # Back to first user message
158 result2 = await manager.complete(branch_name="alt-response")
160 # Resume after interruption
161 manager2 = await ConversationManager.resume(
162 conversation_id=manager.conversation_id,
163 llm=llm,
164 prompt_builder=builder,
165 storage=storage_backend
166 )
167 ```
169 Note:
170 Tree-based branching enables:
172 - **A/B Testing**: Generate multiple responses from the same context
173 - **Retry Logic**: Try again from a previous point after failures
174 - **Alternative Explorations**: Explore different conversation paths
175 - **Debugging**: Compare different middleware or RAG configurations
177 Node IDs use dot notation (e.g., "0.1.2" means 3rd child of 2nd child
178 of 1st child of root). The root node has ID "".
180 State is automatically persisted after every operation. Use
181 `resume()` to continue conversations across sessions or servers.
183 See Also:
184 create: Create a new conversation
185 resume: Resume an existing conversation
186 add_message: Add user/system message
187 complete: Get LLM completion (blocking)
188 stream_complete: Get LLM completion (streaming)
189 switch_to_node: Navigate to different branch
190 get_branches: List available branches
191 get_total_cost: Calculate cumulative cost
192 ConversationStorage: Storage backend implementations
193 ConversationMiddleware: Request/response processing
194 ConversationFlow: FSM-based conversation flows
195 """
197 def __init__(
198 self,
199 llm: AsyncLLMProvider,
200 prompt_builder: AsyncPromptBuilder,
201 storage: ConversationStorage,
202 state: ConversationState | None = None,
203 metadata: Dict[str, Any] | None = None,
204 middleware: List[ConversationMiddleware] | None = None,
205 cache_rag_results: bool = False,
206 reuse_rag_on_branch: bool = False,
207 ):
208 """Initialize conversation manager.
210 Note: Use ConversationManager.create() or ConversationManager.resume()
211 instead of calling __init__ directly.
213 Args:
214 llm: LLM provider for completions
215 prompt_builder: Prompt builder with library
216 storage: Storage backend for persistence
217 state: Optional existing conversation state
218 metadata: Optional metadata for new conversations
219 middleware: Optional list of middleware to execute
220 cache_rag_results: If True, store RAG metadata in node metadata
221 for debugging and transparency
222 reuse_rag_on_branch: If True, reuse cached RAG results when
223 possible (useful for testing/branching)
224 """
225 self.llm = llm
226 self.prompt_builder = prompt_builder
227 self.storage = storage
228 self.state = state
229 self._initial_metadata = metadata or {}
230 self.middleware = middleware or []
231 self.cache_rag_results = cache_rag_results
232 self.reuse_rag_on_branch = reuse_rag_on_branch
234 @classmethod
235 async def create(
236 cls,
237 llm: AsyncLLMProvider,
238 prompt_builder: AsyncPromptBuilder,
239 storage: ConversationStorage,
240 system_prompt_name: str | None = None,
241 system_params: Dict[str, Any] | None = None,
242 metadata: Dict[str, Any] | None = None,
243 middleware: List[ConversationMiddleware] | None = None,
244 cache_rag_results: bool = False,
245 reuse_rag_on_branch: bool = False,
246 ) -> "ConversationManager":
247 """Create a new conversation.
249 Args:
250 llm: LLM provider
251 prompt_builder: Prompt builder
252 storage: Storage backend
253 system_prompt_name: Optional system prompt to initialize with
254 system_params: Optional params for system prompt
255 metadata: Optional conversation metadata
256 middleware: Optional list of middleware to execute
257 cache_rag_results: If True, store RAG metadata in node metadata
258 reuse_rag_on_branch: If True, reuse cached RAG results when possible
260 Returns:
261 Initialized ConversationManager
263 Example:
264 >>> manager = await ConversationManager.create(
265 ... llm=llm,
266 ... prompt_builder=builder,
267 ... storage=storage,
268 ... system_prompt_name="helpful_assistant",
269 ... cache_rag_results=True
270 ... )
271 """
272 manager = cls(
273 llm=llm,
274 prompt_builder=prompt_builder,
275 storage=storage,
276 metadata=metadata,
277 middleware=middleware,
278 cache_rag_results=cache_rag_results,
279 reuse_rag_on_branch=reuse_rag_on_branch,
280 )
282 # Initialize with system prompt if provided
283 if system_prompt_name:
284 await manager.add_message(
285 prompt_name=system_prompt_name,
286 params=system_params,
287 role="system",
288 )
290 return manager
292 @classmethod
293 async def resume(
294 cls,
295 conversation_id: str,
296 llm: AsyncLLMProvider,
297 prompt_builder: AsyncPromptBuilder,
298 storage: ConversationStorage,
299 middleware: List[ConversationMiddleware] | None = None,
300 cache_rag_results: bool = False,
301 reuse_rag_on_branch: bool = False,
302 ) -> "ConversationManager":
303 """Resume an existing conversation.
305 Args:
306 conversation_id: Existing conversation ID
307 llm: LLM provider
308 prompt_builder: Prompt builder
309 storage: Storage backend
310 middleware: Optional list of middleware to execute
311 cache_rag_results: If True, store RAG metadata in node metadata
312 reuse_rag_on_branch: If True, reuse cached RAG results when possible
314 Returns:
315 ConversationManager with restored state
317 Raises:
318 ValueError: If conversation not found
320 Example:
321 >>> manager = await ConversationManager.resume(
322 ... conversation_id="conv-123",
323 ... llm=llm,
324 ... prompt_builder=builder,
325 ... storage=storage,
326 ... cache_rag_results=True
327 ... )
328 """
329 # Load state from storage
330 state = await storage.load_conversation(conversation_id)
331 if not state:
332 raise ValueError(f"Conversation '{conversation_id}' not found")
334 # Create manager with existing state
335 manager = cls(
336 llm=llm,
337 prompt_builder=prompt_builder,
338 storage=storage,
339 state=state,
340 middleware=middleware,
341 cache_rag_results=cache_rag_results,
342 reuse_rag_on_branch=reuse_rag_on_branch,
343 )
345 return manager
async def add_message(
    self,
    role: str,
    content: str | None = None,
    prompt_name: str | None = None,
    params: Dict[str, Any] | None = None,
    include_rag: bool = True,
    rag_configs: List[Dict[str, Any]] | None = None,
    metadata: Dict[str, Any] | None = None,
) -> ConversationNode:
    """Add a message as a child of the current conversation node.

    Either ``content`` or ``prompt_name`` must be provided. When
    ``prompt_name`` is given it takes precedence: the prompt is rendered
    through the prompt builder and the rendered text becomes the message
    content. When only ``content`` is given together with ``rag_configs``
    (and ``include_rag`` is true), the inline content is rendered through
    the prompt builder's inline rendering methods for "system" and "user"
    roles; for any other role the content is used as-is.

    If the manager has no state yet, the message becomes the root node
    (ID ``""``) of a new conversation with a freshly generated UUID.
    Otherwise it is appended under the current node and becomes the new
    current node. State is saved before returning.

    Args:
        role: Message role ("system", "user", or "assistant").
        content: Direct message content (if not using a prompt).
        prompt_name: Name of the prompt template to render. Only
            "system" and "user" roles can be rendered from a prompt.
        params: Parameters for prompt rendering.
        include_rag: Whether to execute RAG searches for prompts.
        rag_configs: RAG configurations for inline content (only used
            when content is provided without prompt_name).
        metadata: Optional metadata for this message node. The dict is
            shallow-copied; the caller's object is never modified.

    Returns:
        The created ConversationNode.

    Raises:
        ValueError: If neither content nor prompt_name is provided, if a
            prompt is requested for a role other than "system"/"user",
            or if the current node cannot be found in the tree.

    Example:
        ```python
        # Add message from prompt
        await manager.add_message(
            role="user",
            prompt_name="code_question",
            params={"code": code_snippet}
        )

        # Add direct message
        await manager.add_message(
            role="user",
            content="What is Python?"
        )

        # Add inline message with RAG enhancement
        await manager.add_message(
            role="system",
            content="You are a helpful assistant. Use the context below.",
            rag_configs=[{
                "adapter_name": "docs",
                "query": "assistant guidelines",
                "placeholder": "CONTEXT",
                "k": 3
            }]
        )

        # Add system prompt with custom metadata
        await manager.add_message(
            role="system",
            prompt_name="expert_coder",
            metadata={"version": "v2"}
        )
        ```

    Note:
        **RAG Caching Behavior**:

        When ``reuse_rag_on_branch`` is set and ``include_rag`` is true,
        ``_find_cached_rag`` is consulted before rendering a named prompt
        and its result is handed to the prompt builder as ``cached_rag``.
        When ``cache_rag_results`` is set, any RAG metadata returned by
        the builder is stored under the ``"rag_metadata"`` key of the
        node metadata.

    See Also:
        complete: Get LLM response after adding message
    """
    if not content and not prompt_name:
        raise ValueError("Either content or prompt_name must be provided")

    # Render prompt if needed
    rag_metadata_to_store = None
    if prompt_name:
        params = params or {}

        # Check if we should try to reuse cached RAG
        cached_rag = None
        if self.reuse_rag_on_branch and include_rag:
            cached_rag = await self._find_cached_rag(prompt_name, role, params)

        if role == "system":
            result = await self.prompt_builder.render_system_prompt(
                prompt_name,
                params=params,
                include_rag=include_rag,
                return_rag_metadata=self.cache_rag_results,
                cached_rag=cached_rag,
            )
        elif role == "user":
            result = await self.prompt_builder.render_user_prompt(
                prompt_name,
                params=params,
                include_rag=include_rag,
                return_rag_metadata=self.cache_rag_results,
                cached_rag=cached_rag,
            )
        else:
            raise ValueError(f"Cannot render prompt for role '{role}'")

        content = result.content

        # Store RAG metadata if caching is enabled and metadata was captured
        if self.cache_rag_results and result.rag_metadata:
            rag_metadata_to_store = result.rag_metadata

    elif content and include_rag and rag_configs:
        # Render inline content with RAG enhancement
        params = params or {}
        if role == "system":
            result = await self.prompt_builder.render_inline_system_prompt(
                content,
                params=params,
                rag_configs=rag_configs,
                include_rag=True,
                return_rag_metadata=self.cache_rag_results,
            )
        elif role == "user":
            result = await self.prompt_builder.render_inline_user_prompt(
                content,
                params=params,
                rag_configs=rag_configs,
                include_rag=True,
                return_rag_metadata=self.cache_rag_results,
            )
        else:
            # For other roles (e.g. assistant), just use content as-is
            result = None

        if result:
            content = result.content
            if self.cache_rag_results and result.rag_metadata:
                rag_metadata_to_store = result.rag_metadata

    # Create message
    message = LLMMessage(role=role, content=content)

    # Prepare node metadata. Work on a shallow copy so that adding the
    # "rag_metadata" key never leaks into the dict the caller passed in.
    node_metadata = dict(metadata) if metadata else {}
    if rag_metadata_to_store:
        node_metadata["rag_metadata"] = rag_metadata_to_store

    if self.state is None:
        # First message: it becomes the root of a brand-new tree.
        conversation_id = str(uuid.uuid4())
        root_node = ConversationNode(
            message=message,
            node_id="",
            prompt_name=prompt_name,
            metadata=node_metadata,
        )
        tree = Tree(root_node)
        self.state = ConversationState(
            conversation_id=conversation_id,
            message_tree=tree,
            current_node_id="",
            metadata=self._initial_metadata,
        )
    else:
        # Add as child of current node
        current_tree_node = self.state.get_current_node()
        if current_tree_node is None:
            raise ValueError(f"Current node '{self.state.current_node_id}' not found")

        new_tree_node = Tree(
            ConversationNode(
                message=message,
                node_id="",  # Will be calculated after adding to tree
                prompt_name=prompt_name,
                metadata=node_metadata,
            )
        )
        current_tree_node.add_child(new_tree_node)

        # The ID depends on the node's position, so compute it after attaching.
        node_id = calculate_node_id(new_tree_node)
        new_tree_node.data.node_id = node_id

        # Move current position to new node
        self.state.current_node_id = node_id

    # Update timestamp
    self.state.updated_at = datetime.now()

    # Persist
    await self._save_state()

    return self.state.get_current_node().data
async def complete(
    self,
    branch_name: str | None = None,
    metadata: Dict[str, Any] | None = None,
    llm_config_overrides: Dict[str, Any] | None = None,
    **llm_kwargs: Any,
) -> LLMResponse:
    """Get LLM completion and add as child of current node.

    This method:
    1. Gets conversation history from root to current node
    2. Executes middleware (pre-LLM)
    3. Calls LLM with history
    4. Executes middleware (post-LLM)
    5. Adds assistant response as child of current node
    6. Updates current position to new node
    7. Persists to storage

    Args:
        branch_name: Optional human-readable label for this branch
        metadata: Optional metadata for the assistant message node. The
            dict is shallow-copied; the caller's object is not modified.
        llm_config_overrides: Optional dict to override LLM config fields
            for this request only. Supported fields: model, temperature,
            max_tokens, top_p, stop_sequences, seed. A copy is recorded
            in the node metadata under "config_overrides_applied".
        **llm_kwargs: Additional arguments for LLM.complete()

    Returns:
        The LLM response, after post-LLM middleware has processed it.

    Raises:
        ValueError: If conversation has no messages yet, or if the
            current node cannot be found in the tree.

    Example:
        ```python
        # Get response
        result = await manager.complete()
        print(result.content)
        print(f"Cost: ${result.cost_usd:.4f}")

        # Create labeled branch
        result = await manager.complete(branch_name="alternative-answer")

        # With LLM parameters
        result = await manager.complete(temperature=0.9, max_tokens=500)

        # With config overrides (switch model per-request)
        result = await manager.complete(
            llm_config_overrides={"model": "gpt-4-turbo", "temperature": 0.9}
        )
        ```

    Note:
        **Middleware Execution Order** (Onion Model):

        - Pre-LLM: middleware[0] → middleware[1] → ... → middleware[N]
        - LLM call happens
        - Post-LLM: middleware[N] → ... → middleware[1] → middleware[0]

        This "onion" pattern ensures that middleware wraps around the LLM
        call symmetrically. For example, if middleware[0] starts a timer
        in `process_request()`, it will be the last to run in
        `process_response()` and can log the total elapsed time.

        **Cost Tracking**:

        Cost is computed by `_calculate_and_track_cost`, which receives
        the response and the node metadata dict.

    See Also:
        stream_complete: Streaming version for real-time output
        add_message: Add user/system message before calling complete
        switch_to_node: Navigate to different branch before completing
    """
    if not self.state:
        raise ValueError("Cannot complete: no messages in conversation")

    # Get messages from root to current position
    messages = self.state.get_current_messages()

    # Execute middleware (pre-LLM) in forward order
    for mw in self.middleware:
        messages = await mw.process_request(messages, self.state)

    # Call LLM with config overrides if provided
    response = await self.llm.complete(
        messages,
        config_overrides=llm_config_overrides,
        **llm_kwargs
    )

    # Execute middleware (post-LLM) in reverse order (onion model)
    for mw in reversed(self.middleware):
        response = await mw.process_response(response, self.state)

    # Add assistant message as child
    current_tree_node = self.state.get_current_node()
    if current_tree_node is None:
        raise ValueError(f"Current node '{self.state.current_node_id}' not found")

    assistant_message = LLMMessage(
        role="assistant",
        content=response.content,
    )

    # Shallow-copy so usage/model/cost keys are not written into the
    # dict object owned by the caller (which may be reused across calls).
    assistant_metadata = dict(metadata) if metadata else {}
    assistant_metadata.update({
        "usage": response.usage,
        "model": response.model,
        "finish_reason": response.finish_reason,
    })

    # Track config overrides if they were applied; store a copy so later
    # edits to the caller's dict do not alter the recorded history.
    if llm_config_overrides:
        assistant_metadata["config_overrides_applied"] = dict(llm_config_overrides)

    # Calculate and track cost
    self._calculate_and_track_cost(response, assistant_metadata)

    new_tree_node = Tree(
        ConversationNode(
            message=assistant_message,
            node_id="",  # Will be calculated
            branch_name=branch_name,
            metadata=assistant_metadata,
        )
    )

    # Add to tree, then derive the position-based node ID
    current_tree_node.add_child(new_tree_node)
    node_id = calculate_node_id(new_tree_node)
    new_tree_node.data.node_id = node_id

    # Move current position
    self.state.current_node_id = node_id
    self.state.updated_at = datetime.now()

    # Persist
    await self._save_state()

    return response
async def stream_complete(
    self,
    branch_name: str | None = None,
    metadata: Dict[str, Any] | None = None,
    llm_config_overrides: Dict[str, Any] | None = None,
    **llm_kwargs: Any,
) -> AsyncIterator[LLMStreamResponse]:
    r"""Stream LLM completion and add as child of current node.

    Similar to complete() but streams the response incrementally for
    real-time display. Chunks are yielded as they arrive; once the
    stream is exhausted, the concatenated text is added to the
    conversation tree as an assistant node and state is persisted.

    Args:
        branch_name: Optional human-readable label for this branch
        metadata: Optional metadata for the assistant message node. The
            dict is shallow-copied; the caller's object is not modified.
        llm_config_overrides: Optional dict to override LLM config fields
            for this request only. Supported fields: model, temperature,
            max_tokens, top_p, stop_sequences, seed. A copy is recorded
            in the node metadata under "config_overrides_applied".
        **llm_kwargs: Additional arguments for LLM.stream_complete()

    Yields:
        Streaming response chunks exactly as produced by the provider.

    Raises:
        ValueError: If conversation has no messages yet, or if the
            current node cannot be found in the tree.

    Example:
        ```python
        # Real-time display
        async for chunk in manager.stream_complete():
            print(chunk.delta, end="", flush=True)
        print()  # New line after streaming

        # Accumulate full response
        full_text = ""
        async for chunk in manager.stream_complete():
            full_text += chunk.delta
            if chunk.is_final:
                print(f"\nFinished. Total: {len(full_text)} chars")
                print(f"Cost: ${chunk.usage.get('cost_usd', 0):.4f}")

        # With config overrides (switch model per-request)
        async for chunk in manager.stream_complete(
            llm_config_overrides={"model": "gpt-4-turbo", "temperature": 0.9}
        ):
            print(chunk.delta, end="", flush=True)
        ```

    Note:
        The middleware execution order is the same as `complete()`:
        pre-LLM middleware runs before streaming starts, post-LLM
        middleware runs after the stream completes (on the assembled
        response, not on individual chunks).

        The tree is only updated if the consumer iterates the stream
        to completion.

    See Also:
        complete: Non-streaming version for simple use cases
        add_message: Add message before streaming
    """
    if not self.state:
        raise ValueError("Cannot complete: no messages in conversation")

    # Get messages
    messages = self.state.get_current_messages()

    # Execute middleware (pre-LLM) in forward order
    for mw in self.middleware:
        messages = await mw.process_request(messages, self.state)

    # Stream LLM response, forwarding each chunk and keeping its text
    pieces = []
    final_chunk = None
    async for chunk in self.llm.stream_complete(
        messages,
        config_overrides=llm_config_overrides,
        **llm_kwargs
    ):
        pieces.append(chunk.delta)
        final_chunk = chunk
        yield chunk

    # A per-request "model" override decides which model actually
    # answered, so record it in preference to the provider's default.
    model_used = (llm_config_overrides or {}).get("model") or self.llm.config.model

    # Create complete response for state update
    response = LLMResponse(
        content="".join(pieces),
        model=model_used,
        finish_reason=final_chunk.finish_reason if final_chunk else "stop",
        usage=final_chunk.usage if final_chunk else None,
    )

    # Execute middleware (post-LLM) in reverse order (onion model)
    for mw in reversed(self.middleware):
        response = await mw.process_response(response, self.state)

    # Add assistant message as child (same as complete())
    current_tree_node = self.state.get_current_node()
    if current_tree_node is None:
        raise ValueError(f"Current node '{self.state.current_node_id}' not found")

    assistant_message = LLMMessage(role="assistant", content=response.content)

    # Shallow-copy so the caller's dict is not modified by the update.
    assistant_metadata = dict(metadata) if metadata else {}
    assistant_metadata.update({
        "usage": response.usage,
        "model": response.model,
        "finish_reason": response.finish_reason,
    })

    # Track config overrides if they were applied (stored as a copy)
    if llm_config_overrides:
        assistant_metadata["config_overrides_applied"] = dict(llm_config_overrides)

    # Calculate and track cost
    self._calculate_and_track_cost(response, assistant_metadata)

    new_tree_node = Tree(
        ConversationNode(
            message=assistant_message,
            node_id="",  # Will be calculated
            branch_name=branch_name,
            metadata=assistant_metadata,
        )
    )

    current_tree_node.add_child(new_tree_node)
    node_id = calculate_node_id(new_tree_node)
    new_tree_node.data.node_id = node_id

    self.state.current_node_id = node_id
    self.state.updated_at = datetime.now()

    await self._save_state()
async def switch_to_node(self, node_id: str) -> None:
    """Move the current position to another node of the tree.

    Use this to backtrack or to jump onto a different branch. The new
    position and an updated timestamp are saved before returning.

    Args:
        node_id: Target node ID (dot-delimited, e.g., "0.1" or "")

    Raises:
        ValueError: If there is no conversation state, or if node_id is
            not found in the tree.

    Example:
        >>> # Go back to first user message
        >>> await manager.switch_to_node("0")
        >>>
        >>> # Create alternative response
        >>> result = await manager.complete(branch_name="alternative")
        >>>
        >>> # Go back to root
        >>> await manager.switch_to_node("")
    """
    state = self.state
    if not state:
        raise ValueError("No conversation state")

    # Refuse to point at a node that does not exist.
    if get_node_by_id(state.message_tree, node_id) is None:
        raise ValueError(f"Node '{node_id}' not found in conversation tree")

    state.current_node_id = node_id
    state.updated_at = datetime.now()

    await self._save_state()
async def execute_flow(
    self,
    flow: ConversationFlow,
    initial_params: Dict[str, Any] | None = None
) -> AsyncIterator[ConversationNode]:
    """Execute a conversation flow using FSM.

    The flow is run to completion through a ConversationFlowAdapter.
    Afterwards, each ``(state_name, response)`` entry of the adapter's
    execution history is appended to the conversation tree as an
    assistant node (chained one under the other), persisted, and
    yielded.

    Args:
        flow: ConversationFlow definition
        initial_params: Optional initial parameters for the flow. The
            dict is shallow-copied before "conversation_id" is added, so
            the caller's object is not modified.

    Yields:
        ConversationNode for each history entry of the flow. The node's
        metadata carries "state", "flow_name" and "flow_execution".

    Raises:
        ValueError: If there is no conversation state, or if flow
            execution fails for any reason (the original exception is
            chained as the cause).

    Example:
        >>> from dataknobs_llm.conversations.flow import (
        ...     ConversationFlow, FlowState,
        ...     keyword_condition
        ... )
        >>>
        >>> # Define flow
        >>> flow = ConversationFlow(
        ...     name="support",
        ...     initial_state="greeting",
        ...     states={
        ...         "greeting": FlowState(
        ...             prompt_name="support_greeting",
        ...             transitions={
        ...                 "help": "collect_issue",
        ...                 "browse": "end"
        ...             },
        ...             transition_conditions={
        ...                 "help": keyword_condition(["help", "issue"]),
        ...                 "browse": keyword_condition(["browse", "look"])
        ...             }
        ...         )
        ...     }
        ... )
        >>>
        >>> # Execute flow
        >>> async for node in manager.execute_flow(flow):
        ...     print(f"State: {node.metadata.get('state')}")
        ...     print(f"Response: {node.message.content}")
    """
    # Fail fast before importing or constructing anything.
    if not self.state:
        raise ValueError("No conversation state")

    import logging

    from dataknobs_llm.conversations.flow import ConversationFlowAdapter

    # Create adapter
    adapter = ConversationFlowAdapter(
        flow=flow,
        prompt_builder=self.prompt_builder,
        llm=self.llm
    )

    # Copy so that injecting the conversation ID does not touch the
    # caller's dict.
    data = dict(initial_params) if initial_params else {}
    data["conversation_id"] = self.state.conversation_id

    try:
        # Execute flow (this will internally use FSM)
        await adapter.execute(data)

        # Convert flow history to conversation nodes
        for state_name, response in adapter.execution_state.history:
            current_tree_node = get_node_by_id(
                self.state.message_tree,
                self.state.current_node_id
            )
            if current_tree_node is None:
                raise ValueError(
                    f"Current node '{self.state.current_node_id}' not found"
                )

            # Build the node the same way add_message()/complete() do:
            # role and content live on the wrapped LLMMessage, which is
            # what get_branches() and the history readers expect.
            node = ConversationNode(
                message=LLMMessage(role="assistant", content=response),
                node_id="",  # Will be calculated after adding to tree
                metadata={
                    "state": state_name,
                    "flow_name": flow.name,
                    "flow_execution": True
                }
            )

            # Add to conversation tree
            new_tree_node = Tree(node)
            current_tree_node.add_child(new_tree_node)
            node_id = calculate_node_id(new_tree_node)
            new_tree_node.data.node_id = node_id

            self.state.current_node_id = node_id
            self.state.updated_at = datetime.now()

            await self._save_state()

            yield node

    except Exception as e:
        logging.error("Flow execution failed: %s", e)
        raise ValueError(f"Flow execution failed: {e!s}") from e
async def get_history(self) -> List[LLMMessage]:
    """Return the messages on the path from the root to the current node.

    Returns:
        The messages of the current conversation path, or an empty list
        when the conversation has no state yet.

    Example:
        >>> messages = await manager.get_history()
        >>> for msg in messages:
        ...     print(f"{msg.role}: {msg.content}")
    """
    return self.state.get_current_messages() if self.state else []
async def get_branches(self, node_id: str | None = None) -> List[Dict[str, Any]]:
    """Describe the direct children (branches) of a node.

    Args:
        node_id: Node to get branches from (default: current node)

    Returns:
        One dict per child node, in child order, with keys:
            - node_id: ID of child node
            - branch_name: Optional branch name
            - role: Message role
            - preview: First 100 chars of content
            - timestamp: When created

        An empty list is returned when there is no conversation state,
        the node is not found, or the node has no children.

    Example:
        >>> branches = await manager.get_branches()
        >>> for branch in branches:
        ...     print(f"{branch['branch_name']}: {branch['preview']}")
    """
    if not self.state:
        return []

    # Fall back to the current position when no node is named.
    lookup_id = self.state.current_node_id if node_id is None else node_id

    parent = get_node_by_id(self.state.message_tree, lookup_id)
    if parent is None or not parent.children:
        return []

    return [
        {
            "node_id": entry.node_id,
            "branch_name": entry.branch_name,
            "role": entry.message.role,
            "preview": entry.message.content[:100],
            "timestamp": entry.timestamp,
        }
        for entry in (child.data for child in parent.children)
    ]
async def add_metadata(self, key: str, value: Any) -> None:
    """Set one conversation-level metadata entry and persist the state.

    Args:
        key: Metadata key
        value: Metadata value

    Raises:
        ValueError: If there is no conversation state yet.

    Example:
        >>> await manager.add_metadata("user_id", "alice")
        >>> await manager.add_metadata("session", "abc123")
    """
    state = self.state
    if not state:
        raise ValueError("No conversation state")

    state.metadata[key] = value
    state.updated_at = datetime.now()

    await self._save_state()
1058 async def _find_cached_rag(
1059 self,
1060 prompt_name: str,
1061 role: str,
1062 params: Dict[str, Any]
1063 ) -> Dict[str, Any] | None:
1064 """Search conversation history for cached RAG metadata.
1066 This method searches the entire conversation tree for cached RAG metadata
1067 that matches both the prompt name/role AND the resolved RAG query parameters.
1068 Query matching is done via query hashes.
1070 Args:
1071 prompt_name: Name of the prompt to find cached RAG for
1072 role: Role of the prompt ("system" or "user")
1073 params: Parameters for the prompt (used to match RAG queries)
1075 Returns:
1076 Cached RAG metadata dictionary if found, None otherwise
1078 Example:
1079 >>> cached = await manager._find_cached_rag("code_question", "user", {"topic": "decorators"})
1080 >>> if cached:
1081 ... print(f"Found cached RAG with {len(cached)} placeholders")
1082 """
1083 if not self.state:
1084 return None
1086 # Get RAG configs for this prompt to determine what queries we're looking for
1087 rag_configs = self.prompt_builder.library.get_prompt_rag_configs(
1088 prompt_name=prompt_name,
1089 prompt_type="system" if role == "system" else "user"
1090 )
1092 if not rag_configs:
1093 return None
1095 # Compute the query hashes we're looking for
1096 from jinja2 import Template
1097 target_hashes_by_placeholder = {}
1098 for rag_config in rag_configs:
1099 placeholder = rag_config.get("placeholder", "RAG_CONTENT")
1100 adapter_name = rag_config.get("adapter_name", "")
1101 query_template = rag_config.get("query", "")
1103 # Render the query template with params
1104 try:
1105 template = Template(query_template)
1106 resolved_query = template.render(params)
1108 # Compute hash
1109 query_hash = self.prompt_builder._compute_rag_query_hash(adapter_name, resolved_query)
1110 target_hashes_by_placeholder[placeholder] = query_hash
1111 except Exception:
1112 # If query rendering fails, we can't match cache
1113 continue
1115 if not target_hashes_by_placeholder:
1116 return None
1118 # Search entire tree for matching cached RAG (BFS to find any match)
1119 from collections import deque
1120 queue = deque([self.state.message_tree])
1122 while queue:
1123 tree_node = queue.popleft()
1124 node_data = tree_node.data
1126 # Check if this node has the same prompt name and role
1127 if (node_data.prompt_name == prompt_name and
1128 node_data.message.role == role):
1130 # Check if RAG metadata exists
1131 rag_metadata = node_data.metadata.get("rag_metadata")
1132 if rag_metadata:
1133 # Check if query hashes match for all placeholders
1134 all_match = True
1135 for placeholder, target_hash in target_hashes_by_placeholder.items():
1136 if placeholder not in rag_metadata:
1137 all_match = False
1138 break
1139 cached_hash = rag_metadata[placeholder].get("query_hash")
1140 if cached_hash != target_hash:
1141 all_match = False
1142 break
1144 if all_match:
1145 return rag_metadata
1147 # Add children to queue (if any)
1148 if tree_node.children:
1149 queue.extend(tree_node.children)
1151 return None
1153 def get_rag_metadata(self, node_id: str | None = None) -> Dict[str, Any] | None:
1154 """Get RAG metadata from a conversation node.
1156 This method retrieves the cached RAG metadata from a specific node,
1157 which includes information about RAG searches executed during prompt
1158 rendering (queries, results, query hashes, etc.).
1160 Args:
1161 node_id: Node ID to retrieve metadata from (default: current node)
1163 Returns:
1164 RAG metadata dictionary if present, None otherwise. Structure:
1166 ```python
1167 {
1168 "PLACEHOLDER_NAME": {
1169 "query": "resolved RAG query",
1170 "query_hash": "hash of adapter+query",
1171 "results": [...], # Search results
1172 "adapter_name": "name of RAG adapter used"
1173 },
1174 ... # One entry per RAG placeholder
1175 }
1176 ```
1178 Raises:
1179 ValueError: If node_id not found in conversation tree
1181 Example:
1182 ```python
1183 # Get RAG metadata from current node
1184 metadata = manager.get_rag_metadata()
1185 if metadata:
1186 for placeholder, rag_data in metadata.items():
1187 print(f"Placeholder: {placeholder}")
1188 print(f" Query: {rag_data['query']}")
1189 print(f" Adapter: {rag_data['adapter_name']}")
1190 print(f" Results: {len(rag_data['results'])} items")
1191 print(f" Hash: {rag_data['query_hash']}")
1193 # Get RAG metadata from specific node
1194 metadata = manager.get_rag_metadata(node_id="0.1")
1196 # Check if RAG was used for a message
1197 if manager.get_rag_metadata():
1198 print("This message used RAG-enhanced prompt")
1199 else:
1200 print("This message used direct content")
1201 ```
1203 Note:
1204 RAG metadata is only available if `cache_rag_results=True` was
1205 set during ConversationManager creation. This metadata is useful
1206 for debugging RAG behavior, understanding what information was
1207 retrieved, and implementing RAG result caching across branches.
1209 See Also:
1210 add_message: Method that executes RAG and stores metadata
1211 reuse_rag_on_branch: Parameter enabling RAG cache reuse
1212 """
1213 if not self.state:
1214 return None
1216 # Default to current node
1217 if node_id is None:
1218 node_id = self.state.current_node_id
1220 # Get node
1221 tree_node = get_node_by_id(self.state.message_tree, node_id)
1222 if tree_node is None:
1223 raise ValueError(f"Node '{node_id}' not found in conversation tree")
1225 # Return RAG metadata if present
1226 return tree_node.data.metadata.get("rag_metadata")
1228 async def _save_state(self) -> None:
1229 """Persist current state to storage."""
1230 if self.state:
1231 await self.storage.save_conversation(self.state)
1233 @property
1234 def conversation_id(self) -> str | None:
1235 """Get conversation ID."""
1236 return self.state.conversation_id if self.state else None
1238 @property
1239 def current_node_id(self) -> str | None:
1240 """Get current node ID."""
1241 return self.state.current_node_id if self.state else None
1243 def get_metadata(self, key: str | None = None, default: Any = None) -> Any:
1244 """Get conversation metadata.
1246 This provides access to the conversation-level metadata stored in
1247 the ConversationState. Metadata is useful for storing client_id,
1248 user_id, session information, and other contextual data.
1250 Args:
1251 key: Specific metadata key to retrieve. If None, returns all metadata.
1252 default: Default value if key not found (only used when key is specified)
1254 Returns:
1255 Metadata value, all metadata dict, or default value
1257 Example:
1258 >>> # Get all metadata
1259 >>> metadata = manager.get_metadata()
1260 >>> print(metadata) # {'client_id': 'abc', 'user_id': '123'}
1261 >>>
1262 >>> # Get specific key
1263 >>> client_id = manager.get_metadata('client_id')
1264 >>> print(client_id) # 'abc'
1265 >>>
1266 >>> # Get with default
1267 >>> tier = manager.get_metadata('user_tier', default='free')
1268 """
1269 if not self.state:
1270 return default if key else {}
1272 if key is None:
1273 return self.state.metadata
1274 else:
1275 return self.state.metadata.get(key, default)
def set_metadata(self, key: str, value: Any) -> None:
    """Assign a single conversation metadata entry in memory.

    Only the in-memory state is changed; call save() afterwards to
    persist it. Does nothing when no conversation state is loaded.

    Args:
        key: Metadata key to set
        value: Metadata value

    Example:
        >>> manager.set_metadata('client_id', 'client-abc')
        >>> manager.set_metadata('user_tier', 'premium')
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    state.metadata[key] = value
def update_metadata(self, updates: Dict[str, Any]) -> None:
    """Merge several metadata entries into the conversation at once.

    Only the in-memory state is changed. Does nothing when no
    conversation state is loaded.

    Args:
        updates: Dictionary of metadata key-value pairs to update

    Example:
        >>> manager.update_metadata({
        ...     'client_id': 'client-abc',
        ...     'user_id': 'user-456',
        ...     'session_id': 'sess-789'
        ... })
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    state.metadata.update(updates)
def remove_metadata(self, key: str) -> None:
    """Drop a metadata key from the conversation, if it is present.

    Missing keys and a missing conversation state are both ignored.

    Args:
        key: Metadata key to remove

    Example:
        >>> manager.remove_metadata('temporary_flag')
        >>> await manager.save()
    """
    state = self.state
    if not state:
        return
    # pop with a default tolerates keys that were never set
    state.metadata.pop(key, None)
def get_total_cost(self) -> float:
    """Get total accumulated cost for this conversation in USD.

    Sums the ``cost_usd`` metadata entry of every node in the message
    tree (all branches, not just the current path). Nodes without
    metadata or without a ``cost_usd`` entry contribute nothing.

    Returns:
        Total cost in USD, or 0.0 if no conversation state is loaded
        or no cost data is available

    Example:
        >>> total = manager.get_total_cost()
        >>> print(f"Total cost: ${total:.4f}")
    """
    if not self.state:
        return 0.0

    total = 0.0

    # Iterative pre-order walk with an explicit stack. Tree depth grows
    # with the number of turns, so a recursive walk would raise
    # RecursionError on long conversations.
    stack = [self.state.message_tree]
    while stack:
        node = stack.pop()

        if node.data and node.data.metadata:
            cost = node.data.metadata.get('cost_usd')
            if cost is not None:
                total += cost

        children = node.children
        if children:
            # Reversed so children are popped (and summed) left to right,
            # the same order a recursive pre-order walk would use
            stack.extend(reversed(children))

    return total
1357 def get_cost_by_branch(self, node_id: str | None = None) -> float:
1358 """Get accumulated cost for a specific conversation branch.
1360 Calculates the cost from root to a specific node (defaults to current).
1362 Args:
1363 node_id: Node ID to calculate cost to. If None, uses current node.
1365 Returns:
1366 Cost in USD for this branch, or 0.0 if no cost data
1368 Example:
1369 >>> # Get cost of current branch
1370 >>> current_cost = manager.get_cost_by_branch()
1371 >>>
1372 >>> # Get cost of specific branch
1373 >>> alt_cost = manager.get_cost_by_branch("0.1")
1374 """
1375 if not self.state:
1376 return 0.0
1378 target_node_id = node_id or self.state.current_node_id
1380 # Get messages in this branch
1382 # Walk from root to target node
1383 if not target_node_id or target_node_id == "":
1384 # Just root node
1385 return 0.0
1387 indexes = [int(i) for i in target_node_id.split(".")]
1389 total = 0.0
1390 current = self.state.message_tree
1392 for idx in indexes:
1393 if idx < len(current.children):
1394 current = current.children[idx]
1395 if current.data and current.data.metadata:
1396 cost = current.data.metadata.get('cost_usd')
1397 if cost is not None:
1398 total += cost
1400 return total
def _calculate_and_track_cost(
    self,
    response: LLMResponse,
    metadata: Dict[str, Any]
) -> None:
    """Estimate the cost of a response and record it.

    Internal helper built on the CostCalculator utility. When the
    response carries usage data and a cost can be computed, the cost
    and the running conversation total are written both onto the
    response and into ``metadata``. Any failure is logged as a warning
    and otherwise ignored so the conversation is never interrupted.

    Args:
        response: LLM response to calculate cost for
        metadata: Metadata dict to add cost information to
    """
    try:
        from dataknobs_llm.llm.utils import CostCalculator

        # Without usage data there is nothing to price
        if not response.usage:
            return

        cost = CostCalculator.calculate_cost(response, response.model)
        if cost is None:
            return

        # Running total = everything already in the tree + this response
        cumulative = self.get_total_cost() + cost

        response.cost_usd = cost
        response.cumulative_cost_usd = cumulative

        metadata['cost_usd'] = cost
        metadata['cumulative_cost_usd'] = cumulative
    except Exception as e:
        # Cost tracking is best-effort; never fail the conversation over it
        import logging
        logging.getLogger(__name__).warning(f"Failed to calculate cost: {e}")