Coverage for src/dataknobs_llm/conversations/storage.py: 0%
177 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 16:04 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 16:04 -0600
1"""Conversation storage with tree-based branching support.

This module provides:
- ConversationNode: Data stored in each tree node
- ConversationState: Tree-based conversation state
- ConversationStorage: Abstract storage interface
- DataknobsConversationStorage: Storage adapter for dataknobs backends
- Helper functions for node ID management and tree navigation

Schema Versioning:
    The storage format uses semantic versioning (MAJOR.MINOR.PATCH):
    - MAJOR: Incompatible changes requiring migration
    - MINOR: Backward-compatible additions
    - PATCH: Bug fixes, no schema changes

    Current schema version: 1.0.0
17"""
19from abc import ABC, abstractmethod
20from dataclasses import dataclass, field
21from datetime import datetime
22from typing import Any, Dict, List, Optional
23import logging
25from dataknobs_structures.tree import Tree
26from dataknobs_llm.llm.base import LLMMessage
28# Current schema version - increment when making schema changes
29SCHEMA_VERSION = "1.0.0"
31logger = logging.getLogger(__name__)
@dataclass
class ConversationNode:
    """Payload carried by a single node of the conversation tree.

    A node wraps exactly one LLM message (system, user, or assistant).
    Because nodes live in a tree, one message can be followed by several
    alternative continuations, which is how branching is represented.

    Attributes:
        message: The LLM message (role + content)
        node_id: Dot-delimited child positions from root (e.g., "0.1.2")
        timestamp: When this message was created (defaults to the time the
            node object is constructed)
        prompt_name: Optional name of prompt template used to generate this
        branch_name: Optional human-readable label for this branch
        metadata: Additional metadata (usage stats, model info, etc.)

    Example:
        >>> node = ConversationNode(
        ...     message=LLMMessage(role="user", content="Hello"),
        ...     node_id="0.1",
        ...     timestamp=datetime.now(),
        ...     prompt_name="greeting",
        ...     branch_name="polite-variant"
        ... )
    """

    message: LLMMessage
    node_id: str
    timestamp: datetime = field(default_factory=datetime.now)
    prompt_name: Optional[str] = None
    branch_name: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this node into a plain dictionary for storage.

        The timestamp is written as an ISO-format string, and a falsy
        message metadata value is stored as an empty dict.
        """
        msg = self.message
        message_payload = {
            "role": msg.role,
            "content": msg.content,
            "name": msg.name,
            "metadata": msg.metadata or {},
        }
        return {
            "message": message_payload,
            "node_id": self.node_id,
            "timestamp": self.timestamp.isoformat(),
            "prompt_name": self.prompt_name,
            "branch_name": self.branch_name,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationNode":
        """Rebuild a node from a dictionary shaped like ``to_dict()`` output.

        ``message``, ``node_id`` and ``timestamp`` are required keys; the
        remaining fields fall back to ``None`` / an empty dict when absent.
        """
        raw_message = data["message"]
        message = LLMMessage(
            role=raw_message["role"],
            content=raw_message["content"],
            name=raw_message.get("name"),
            metadata=raw_message.get("metadata", {}),
        )
        return cls(
            message=message,
            node_id=data["node_id"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            prompt_name=data.get("prompt_name"),
            branch_name=data.get("branch_name"),
            metadata=data.get("metadata", {}),
        )
def calculate_node_id(node: Tree) -> str:
    """Calculate dot-delimited node ID by walking up to root.

    The node ID is the path from the root to this node expressed as a
    series of child indexes. For example, "0.1.2" means: root's child 0,
    then that node's child 1, then that node's child 2. The root itself
    has no path and therefore gets the empty string.

    Args:
        node: Tree node to calculate ID for

    Returns:
        "" for the root node, otherwise the dot-delimited node ID
        (e.g., "0", "0.1", "0.1.2")

    Example:
        >>> root = Tree(data)
        >>> child = root.add_child(data2)
        >>> grandchild = child.add_child(data3)
        >>> calculate_node_id(grandchild)
        '0.0'
    """
    # Walk up to the root, recording each node's position among its
    # siblings. This collects the path leaf-first.
    indexes = []
    current = node
    while current.parent is not None:
        indexes.append(str(current.sibnum))
        current = current.parent

    # The root never enters the loop, so it joins to "" without needing a
    # special case; every other node contributes at least one index.
    return ".".join(reversed(indexes))
def get_node_by_id(tree: Tree, node_id: str) -> Optional[Tree]:
    """Retrieve tree node by its dot-delimited ID.

    Args:
        tree: Root of the tree
        node_id: Dot-delimited node ID (e.g., "0.1.2"). An empty (or
            otherwise falsy) ID refers to the root itself.

    Returns:
        Tree node with that ID, or None if the ID is malformed, contains
        a negative index, or does not correspond to an existing path.

    Example:
        >>> node = get_node_by_id(tree, "0.1.2")
        >>> # Equivalent to: tree.children[0].children[1].children[2]
    """
    if not node_id:
        return tree  # Root node

    # Split into child indexes; any non-integer component makes the ID invalid.
    try:
        indexes = [int(part) for part in node_id.split(".")]
    except ValueError:
        return None

    # Navigate down the tree one child index at a time.
    current = tree
    for idx in indexes:
        children = current.children
        # Negative values must be rejected explicitly: list indexing would
        # otherwise count from the end and silently return the wrong node.
        if not children or idx < 0 or idx >= len(children):
            return None
        current = children[idx]

    return current
def get_messages_for_llm(tree: Tree, node_id: str) -> List[LLMMessage]:
    """Return the messages along the path from the root to ``node_id``.

    This linear sequence is what gets sent to the LLM: the single branch
    of the conversation tree that ends at the current position.

    Args:
        tree: Root of conversation tree
        node_id: ID of current position

    Returns:
        Messages ordered as the node's path is returned by ``get_path()``,
        or an empty list when ``node_id`` does not resolve to a node.
        Path entries whose data is not a ConversationNode are skipped.

    Example:
        >>> messages = get_messages_for_llm(tree, "0.1.2")
        >>> # Returns: [root_msg, child_0_msg, child_1_msg, child_2_msg]
    """
    target = get_node_by_id(tree, node_id)
    if target is None:
        return []

    return [
        step.data.message
        for step in target.get_path()
        if isinstance(step.data, ConversationNode)
    ]
@dataclass
class ConversationState:
    """State of a conversation with tree-based branching support.

    This replaces the linear message history with a tree structure that
    supports multiple branches (alternative conversation paths).

    Attributes:
        conversation_id: Unique conversation identifier
        message_tree: Root of conversation tree (Tree[ConversationNode])
        current_node_id: ID of current position in tree (dot-delimited;
            the empty string refers to the root)
        metadata: Additional conversation metadata
        created_at: Conversation creation timestamp
        updated_at: Last update timestamp
        schema_version: Version of the storage schema used

    Example:
        >>> # Create conversation with system message
        >>> root_node = ConversationNode(
        ...     message=LLMMessage(role="system", content="You are helpful"),
        ...     node_id=""
        ... )
        >>> tree = Tree(root_node)
        >>> state = ConversationState(
        ...     conversation_id="conv-123",
        ...     message_tree=tree,
        ...     current_node_id="",
        ...     metadata={"user_id": "alice"}
        ... )
        >>>
        >>> # Add user message
        >>> user_node = ConversationNode(
        ...     message=LLMMessage(role="user", content="Hello"),
        ...     node_id="0"
        ... )
        >>> tree.add_child(Tree(user_node))
        >>> state.current_node_id = "0"
    """

    conversation_id: str
    message_tree: Tree  # Tree[ConversationNode]
    current_node_id: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    schema_version: str = SCHEMA_VERSION

    def get_current_node(self) -> Optional[Tree]:
        """Get the current tree node, or None if the current ID is invalid."""
        return get_node_by_id(self.message_tree, self.current_node_id)

    def get_current_messages(self) -> List[LLMMessage]:
        """Get messages from root to current position (for LLM)."""
        return get_messages_for_llm(self.message_tree, self.current_node_id)

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to dictionary for storage.

        The tree is flattened into two lists: ``nodes`` holds the
        serialized ConversationNode of every tree node, and ``edges``
        holds ``[parent_id, child_id]`` pairs. The parent ID is the
        parent's positional ID (so the root is ""), while the child ID is
        the ``node_id`` stored on the child's ConversationNode.
        Includes schema_version for backward compatibility.

        Tree nodes whose data is not a ConversationNode are omitted.
        """
        nodes = []
        edges = []

        # Traversal is requested breadth-first so that a parent is visited
        # before its children; from_dict() relies on that edge order.
        all_nodes = self.message_tree.find_nodes(lambda n: True, traversal="bfs")
        for tree_node in all_nodes:
            if not isinstance(tree_node.data, ConversationNode):
                continue

            nodes.append(tree_node.data.to_dict())

            # Add edge to parent (if not root)
            if tree_node.parent is not None:
                parent_id = calculate_node_id(tree_node.parent)
                child_id = tree_node.data.node_id
                edges.append([parent_id, child_id])

        return {
            "schema_version": self.schema_version,
            "conversation_id": self.conversation_id,
            "nodes": nodes,
            "edges": edges,
            "current_node_id": self.current_node_id,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationState":
        """Create state from dictionary.

        Reconstructs the tree from nodes and edges.
        Handles schema version migration if needed; the caller's ``data``
        dict is not modified.

        Args:
            data: Dictionary shaped like the output of ``to_dict()``

        Returns:
            The reconstructed state, stamped with the current schema version

        Raises:
            ValueError: If ``data`` contains no nodes, or the stored schema
                version is not of the form MAJOR.MINOR.PATCH
            SchemaVersionError: If the stored schema has a newer major
                version than this code supports
        """
        # A missing version field marks data written before versioning existed.
        stored_version = data.get("schema_version", "0.0.0")

        if stored_version != SCHEMA_VERSION:
            logger.info(
                f"Migrating conversation {data['conversation_id']} "
                f"from schema {stored_version} to {SCHEMA_VERSION}"
            )
            data = cls._migrate_schema(data, stored_version, SCHEMA_VERSION)

        # Create nodes indexed by ID
        nodes_by_id: Dict[str, ConversationNode] = {}
        for node_data in data["nodes"]:
            node = ConversationNode.from_dict(node_data)
            nodes_by_id[node.node_id] = node

        if not nodes_by_id:
            raise ValueError(
                f"Conversation {data['conversation_id']} has no nodes; "
                "cannot reconstruct message tree"
            )

        # Find root (node with empty ID)
        root_node = nodes_by_id.get("")
        if root_node is None:
            # Look for a parent that never appears as a child. Only IDs that
            # name a stored node qualify: to_dict() writes the root's parent
            # ID as "" even when the root's own node_id is something else,
            # and "" is not a key of nodes_by_id in this branch.
            child_ids = {edge[1] for edge in data["edges"]}
            root_candidates = [
                edge[0]
                for edge in data["edges"]
                if edge[0] not in child_ids and edge[0] in nodes_by_id
            ]
            if root_candidates:
                root_node = nodes_by_id[root_candidates[0]]
            else:
                # Fallback: first node
                root_node = next(iter(nodes_by_id.values()))

        tree = Tree(root_node)
        # Map node_id -> Tree node. The root is reachable both as "" (its
        # positional ID) and under whatever node_id it was stored with.
        tree_nodes_by_id: Dict[str, Tree] = {"": tree, root_node.node_id: tree}

        # Build tree by adding edges; parents must precede their children.
        for parent_id, child_id in data["edges"]:
            if parent_id not in tree_nodes_by_id:
                logger.warning(
                    f"Skipping edge {parent_id!r} -> {child_id!r}: "
                    "parent node is not in the tree"
                )
                continue
            parent_tree_node = tree_nodes_by_id[parent_id]
            child_node = nodes_by_id[child_id]
            child_tree_node = parent_tree_node.add_child(Tree(child_node))
            tree_nodes_by_id[child_id] = child_tree_node

        return cls(
            conversation_id=data["conversation_id"],
            message_tree=tree,
            current_node_id=data["current_node_id"],
            metadata=data["metadata"],
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            schema_version=SCHEMA_VERSION,  # Always current after migration
        )

    @staticmethod
    def _migrate_schema(
        data: Dict[str, Any],
        from_version: str,
        to_version: str
    ) -> Dict[str, Any]:
        """Migrate data from one schema version to another.

        This method applies migrations sequentially to transform data from
        an older schema version to the current version. The input dict is
        left untouched; when a change is needed a shallow copy is returned.

        Args:
            data: Data in old schema format
            from_version: Source schema version
            to_version: Target schema version

        Returns:
            Data in new schema format (``data`` itself if versions match)

        Raises:
            ValueError: If either version is not three dot-separated integers
            SchemaVersionError: If migration path is not supported
        """
        # Parse version strings. Only the major parts are consulted today,
        # but parsing all three validates the MAJOR.MINOR.PATCH format.
        from_major, _from_minor, _from_patch = map(int, from_version.split("."))
        to_major, _to_minor, _to_patch = map(int, to_version.split("."))

        # No migration needed if versions match
        if from_version == to_version:
            return data

        # Shallow copy so stamping the new version does not modify the
        # dictionary object owned by the caller.
        migrated = dict(data)

        # Future migrations will be added here as needed. Compare the parsed
        # integer parts rather than the raw strings (as strings, "1.10.0"
        # sorts before "1.9.0"), e.g.:
        #   if (from_major, _from_minor) < (1, 1) <= (to_major, _to_minor):
        #       migrated = ConversationState._migrate_1_0_to_1_1(migrated)

        # For now, version 0.0.0 (no version field) to 1.0.0 is a no-op
        # because the schema didn't change, we just added versioning
        if from_version == "0.0.0":
            logger.debug("Migrating from unversioned schema to 1.0.0 (no changes needed)")
            migrated["schema_version"] = "1.0.0"
            return migrated

        # If we get here and versions still don't match, it's unsupported
        if from_major > to_major:
            raise SchemaVersionError(
                f"Cannot downgrade from schema {from_version} to {to_version}"
            )

        logger.warning(
            f"No migration path defined from {from_version} to {to_version}. "
            "Using data as-is."
        )
        migrated["schema_version"] = to_version
        return migrated

    # Future migration methods will be added here as needed, following this
    # shape (take the data dict, return the migrated dict):
    #   @staticmethod
    #   def _migrate_1_0_to_1_1(data: Dict[str, Any]) -> Dict[str, Any]:
    #       """Migrate from schema 1.0 to 1.1."""
    #       # Add new field with default value
    #       data["new_field"] = "default_value"
    #       return data
class ConversationStorage(ABC):
    """Contract that every conversation persistence backend must fulfil.

    Concrete subclasses decide where the data actually lives (SQL, NoSQL,
    file, etc.); callers only depend on the four coroutines below.
    """

    @abstractmethod
    async def save_conversation(self, state: ConversationState) -> None:
        """Persist ``state``, inserting or overwriting as needed (upsert).

        Args:
            state: Conversation state to save
        """

    @abstractmethod
    async def load_conversation(
        self,
        conversation_id: str
    ) -> Optional[ConversationState]:
        """Fetch a previously saved conversation.

        Args:
            conversation_id: Conversation identifier

        Returns:
            Conversation state or None if not found
        """

    @abstractmethod
    async def delete_conversation(self, conversation_id: str) -> bool:
        """Remove a conversation from storage.

        Args:
            conversation_id: Conversation identifier

        Returns:
            True if deleted, False if not found
        """

    @abstractmethod
    async def list_conversations(
        self,
        filter_metadata: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[ConversationState]:
        """Enumerate stored conversations, optionally filtered and paged.

        Args:
            filter_metadata: Optional metadata filters
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of conversation states
        """
class DataknobsConversationStorage(ConversationStorage):
    """Conversation storage using dataknobs_data backends.

    Stores conversations as Records with the tree serialized as nodes + edges.
    Works with any dataknobs backend (Memory, File, S3, Postgres, etc.).

    All backend failures surface as StorageError, with the underlying
    exception attached as ``__cause__``.

    Example:
        >>> from dataknobs_data.backends import AsyncMemoryDatabase
        >>> storage = DataknobsConversationStorage(AsyncMemoryDatabase())
        >>> await storage.save_conversation(state)
    """

    def __init__(self, backend: Any):
        """Initialize storage with dataknobs backend.

        Args:
            backend: Dataknobs async database backend (AsyncMemoryDatabase, etc.)
        """
        self.backend = backend

    def _state_to_record(self, state: ConversationState) -> Any:
        """Convert ConversationState to Record.

        Args:
            state: Conversation state to convert

        Returns:
            Record object for storage, using the conversation ID as its
            storage ID

        Raises:
            StorageError: If the dataknobs_data package cannot be imported
        """
        # Import here to avoid circular dependency
        try:
            from dataknobs_data.records import Record
        except ImportError as e:
            raise StorageError(
                "dataknobs_data package not available. "
                "Install it to use DataknobsConversationStorage."
            ) from e

        return Record(
            data=state.to_dict(),
            storage_id=state.conversation_id
        )

    def _record_to_state(self, record: Any) -> ConversationState:
        """Convert Record to ConversationState.

        Args:
            record: Record object from storage

        Returns:
            Conversation state
        """
        # Flatten the record's fields back into the plain dict layout that
        # ConversationState.from_dict() consumes.
        data = {
            field_name: record_field.value
            for field_name, record_field in record.fields.items()
        }
        return ConversationState.from_dict(data)

    async def save_conversation(self, state: ConversationState) -> None:
        """Save conversation to backend (insert or update).

        Raises:
            StorageError: If conversion or the backend upsert fails
        """
        try:
            record = self._state_to_record(state)
            # Use upsert to insert or update
            await self.backend.upsert(state.conversation_id, record)
        except StorageError:
            # Already a storage-level error (e.g. missing dataknobs_data);
            # re-raise unchanged instead of wrapping it a second time.
            raise
        except Exception as e:
            raise StorageError(f"Failed to save conversation: {e}") from e

    async def load_conversation(
        self,
        conversation_id: str
    ) -> Optional[ConversationState]:
        """Load conversation from backend.

        Returns:
            The stored state, or None when the backend has no such record

        Raises:
            StorageError: If the read or the reconstruction fails
        """
        try:
            record = await self.backend.read(conversation_id)
            if record is None:
                return None

            return self._record_to_state(record)

        except Exception as e:
            raise StorageError(f"Failed to load conversation: {e}") from e

    async def delete_conversation(self, conversation_id: str) -> bool:
        """Delete conversation from backend.

        Returns:
            Whatever the backend's ``delete`` reports

        Raises:
            StorageError: If the backend delete fails
        """
        try:
            return await self.backend.delete(conversation_id)
        except Exception as e:
            raise StorageError(f"Failed to delete conversation: {e}") from e

    async def list_conversations(
        self,
        filter_metadata: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[ConversationState]:
        """List conversations from backend.

        Args:
            filter_metadata: Optional mapping; each key/value becomes an
                equality filter on ``metadata.<key>``
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of conversation states

        Raises:
            StorageError: If dataknobs_data is unavailable, or the search or
                record conversion fails
        """
        try:
            # Import Query for filtering
            try:
                from dataknobs_data.query import Query
            except ImportError as e:
                raise StorageError(
                    "dataknobs_data package not available. "
                    "Install it to use DataknobsConversationStorage."
                ) from e

            # Build query with metadata filters using fluent interface.
            # NOTE(review): the return values are discarded, so this relies
            # on Query's builder methods updating the query in place -
            # confirm against dataknobs_data.
            query = Query()
            query.limit(limit).offset(offset)

            if filter_metadata:
                for key, value in filter_metadata.items():
                    # Add filter for metadata.key = value
                    query.filter(f"metadata.{key}", "=", value)

            results = await self.backend.search(query)

            # Convert records to conversation states
            return [self._record_to_state(record) for record in results]

        except StorageError:
            # Keep the specific "package not available" error intact
            # rather than re-wrapping it as a generic listing failure.
            raise
        except Exception as e:
            raise StorageError(f"Failed to list conversations: {e}") from e
class StorageError(Exception):
    """Signals that a conversation storage operation could not be completed."""
class SchemaVersionError(Exception):
    """Signals that stored data uses an incompatible schema version."""