Coverage for src/dataknobs_llm/conversations/storage.py: 0%

177 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-31 16:04 -0600

1"""Conversation storage with tree-based branching support. 

2 

3This module provides: 

4- ConversationNode: Data stored in each tree node 

5- ConversationState: Tree-based conversation state 

6- ConversationStorage: Abstract storage interface 

7- DataknobsConversationStorage: Storage adapter for dataknobs backends 

8- Helper functions for node ID management and tree navigation 

9 

10Schema Versioning: 

11 The storage format uses semantic versioning (MAJOR.MINOR.PATCH): 

12 - MAJOR: Incompatible changes requiring migration 

13 - MINOR: Backward-compatible additions 

14 - PATCH: Bug fixes, no schema changes 

15 

16 Current schema version: 1.0.0 

17""" 

18 

19from abc import ABC, abstractmethod 

20from dataclasses import dataclass, field 

21from datetime import datetime 

22from typing import Any, Dict, List, Optional 

23import logging 

24 

25from dataknobs_structures.tree import Tree 

26from dataknobs_llm.llm.base import LLMMessage 

27 

# Version of the persisted conversation format (MAJOR.MINOR.PATCH).
# Bump it whenever the layout produced by ConversationState.to_dict changes.
SCHEMA_VERSION: str = "1.0.0"

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

32 

33 

@dataclass
class ConversationNode:
    """A single message stored at one position of a conversation tree.

    Every node wraps exactly one LLM message (system, user or assistant).
    Because nodes live in a tree, several alternative messages may share the
    same parent; that is how conversation branches are represented.

    Attributes:
        message: The wrapped LLM message (role + content)
        node_id: Dot-delimited child positions from the root (e.g. "0.1.2");
            the root itself uses ""
        timestamp: Creation time of the message (defaults to datetime.now())
        prompt_name: Name of the prompt template that produced the message,
            if any
        branch_name: Optional human-readable label for this branch
        metadata: Free-form extra data (usage stats, model info, etc.)

    Example:
        >>> node = ConversationNode(
        ...     message=LLMMessage(role="user", content="Hello"),
        ...     node_id="0.1",
        ...     timestamp=datetime.now(),
        ...     prompt_name="greeting",
        ...     branch_name="polite-variant"
        ... )
    """
    message: LLMMessage
    node_id: str
    timestamp: datetime = field(default_factory=datetime.now)
    prompt_name: Optional[str] = None
    branch_name: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the node into a plain dictionary for storage.

        The timestamp is written as an ISO-8601 string, and a falsy message
        metadata value is written as an empty dict.
        """
        msg = self.message
        serialized_message = {
            "role": msg.role,
            "content": msg.content,
            "name": msg.name,
            "metadata": msg.metadata or {},
        }
        result: Dict[str, Any] = {"message": serialized_message}
        result["node_id"] = self.node_id
        result["timestamp"] = self.timestamp.isoformat()
        result["prompt_name"] = self.prompt_name
        result["branch_name"] = self.branch_name
        result["metadata"] = self.metadata
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationNode":
        """Rebuild a node from the dictionary produced by ``to_dict``.

        ``message.role``, ``message.content``, ``node_id`` and ``timestamp``
        are required. ``name``, ``prompt_name`` and ``branch_name`` default
        to None, and both metadata dicts default to ``{}``.
        """
        raw_message = data["message"]
        message = LLMMessage(
            role=raw_message["role"],
            content=raw_message["content"],
            name=raw_message.get("name"),
            metadata=raw_message.get("metadata", {}),
        )
        return cls(
            message=message,
            node_id=data["node_id"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            prompt_name=data.get("prompt_name"),
            branch_name=data.get("branch_name"),
            metadata=data.get("metadata", {}),
        )

99 

100 

def calculate_node_id(node: Tree) -> str:
    """Calculate the dot-delimited ID of ``node`` by walking up to the root.

    The ID lists the child index taken at each level on the way down from
    the root. For example, ``"0.1.2"`` means: the root's child 0, then that
    node's child 1, then that node's child 2. The root itself has the empty
    ID ``""``, which keeps its children's IDs short (``"0"``, ``"1"``, ...).

    Args:
        node: Tree node to calculate the ID for. Each node's position among
            its siblings is read from its ``sibnum`` attribute (presumably
            its index in ``parent.children``, which is what get_node_by_id
            navigates by).

    Returns:
        ``""`` for the root, otherwise a dot-delimited ID such as ``"0"``,
        ``"0.1"`` or ``"0.1.2"``.

    Example:
        >>> root = Tree(data)
        >>> child = root.add_child(data2)
        >>> grandchild = child.add_child(data3)
        >>> calculate_node_id(root)
        ''
        >>> calculate_node_id(grandchild)
        '0.0'
    """
    # Collect sibling indexes from the node up to (but excluding) the root.
    indexes: List[str] = []
    current = node
    while current.parent is not None:
        indexes.append(str(current.sibnum))
        current = current.parent

    # Indexes were gathered leaf-to-root; emit them root-to-leaf. For the
    # root the loop never runs, so the join yields "".
    return ".".join(reversed(indexes))

137 

138 

def get_node_by_id(tree: Tree, node_id: str) -> Optional[Tree]:
    """Retrieve a tree node by its dot-delimited ID.

    Each dot-separated component is a child index, followed from the root
    downwards (the format produced by calculate_node_id).

    Args:
        tree: Root of the tree
        node_id: Dot-delimited node ID (e.g., "0.1.2"). An empty ID (or
            None) refers to the root.

    Returns:
        Tree node with that ID, or None if the ID is malformed (any
        component that is not a plain non-negative integer, e.g. "-1",
        "+1", " 1", "1_0" or "") or does not exist in the tree.

    Example:
        >>> node = get_node_by_id(tree, "0.1.2")
        >>> # Equivalent to: tree.children[0].children[1].children[2]
    """
    if not node_id:
        return tree  # Root node

    parts = node_id.split(".")
    # Accept only plain digit strings. int() alone would also accept "-1"
    # (which Python indexing silently resolves to the LAST child), as well
    # as "+1", " 1" and "1_0".
    if not all(part.isdecimal() for part in parts):
        return None  # Invalid node_id format

    # Navigate down the tree
    current = tree
    for idx in (int(part) for part in parts):
        children = current.children
        if not children or idx >= len(children):
            return None  # Invalid path
        current = children[idx]

    return current

170 

171 

def get_messages_for_llm(tree: Tree, node_id: str) -> List[LLMMessage]:
    """Return the linear message history from the root down to ``node_id``.

    This is the sequence sent to the LLM: the single path through the tree
    that ends at the current position. Path entries whose data is not a
    ConversationNode are skipped.

    Args:
        tree: Root of conversation tree
        node_id: ID of current position

    Returns:
        Messages along the path from the root to the node, or an empty list
        if ``node_id`` does not resolve to a node.

    Example:
        >>> messages = get_messages_for_llm(tree, "0.1.2")
        >>> # Returns: [root_msg, child_0_msg, child_1_msg, child_2_msg]
    """
    target = get_node_by_id(tree, node_id)
    if target is None:
        return []

    return [
        step.data.message
        for step in target.get_path()
        if isinstance(step.data, ConversationNode)
    ]

203 

204 

@dataclass
class ConversationState:
    """State of a conversation with tree-based branching support.

    This replaces the linear message history with a tree structure that
    supports multiple branches (alternative conversation paths).

    Attributes:
        conversation_id: Unique conversation identifier
        message_tree: Root of conversation tree (Tree[ConversationNode])
        current_node_id: ID of current position in tree (dot-delimited;
            "" is the root)
        metadata: Additional conversation metadata
        created_at: Conversation creation timestamp
        updated_at: Last update timestamp
        schema_version: Version of the storage schema used

    Example:
        >>> # Create conversation with system message
        >>> root_node = ConversationNode(
        ...     message=LLMMessage(role="system", content="You are helpful"),
        ...     node_id=""
        ... )
        >>> tree = Tree(root_node)
        >>> state = ConversationState(
        ...     conversation_id="conv-123",
        ...     message_tree=tree,
        ...     current_node_id="",
        ...     metadata={"user_id": "alice"}
        ... )
        >>>
        >>> # Add user message
        >>> user_node = ConversationNode(
        ...     message=LLMMessage(role="user", content="Hello"),
        ...     node_id="0"
        ... )
        >>> tree.add_child(Tree(user_node))
        >>> state.current_node_id = "0"
    """
    conversation_id: str
    message_tree: Tree  # Tree[ConversationNode]
    current_node_id: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    schema_version: str = SCHEMA_VERSION

    def get_current_node(self) -> Optional[Tree]:
        """Get the current tree node, or None if current_node_id is invalid."""
        return get_node_by_id(self.message_tree, self.current_node_id)

    def get_current_messages(self) -> List[LLMMessage]:
        """Get messages from root to current position (for LLM)."""
        return get_messages_for_llm(self.message_tree, self.current_node_id)

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a dictionary for storage.

        The tree is flattened into two lists:

        * ``nodes``: ``ConversationNode.to_dict()`` of every tree node whose
          data is a ConversationNode, in breadth-first order;
        * ``edges``: ``[parent_id, child_id]`` pairs, one per non-root node.
          ``parent_id`` is the parent's position computed with
          calculate_node_id (``""`` for the root); ``child_id`` is the
          child's stored ``node_id``.

        Includes ``schema_version`` so from_dict can migrate old data.
        """
        nodes = []
        edges = []

        all_nodes = self.message_tree.find_nodes(lambda n: True, traversal="bfs")
        for tree_node in all_nodes:
            if not isinstance(tree_node.data, ConversationNode):
                continue  # Only ConversationNode payloads are serialized
            nodes.append(tree_node.data.to_dict())

            # Add edge to parent (if not root)
            if tree_node.parent is not None:
                parent_id = calculate_node_id(tree_node.parent)
                child_id = tree_node.data.node_id
                edges.append([parent_id, child_id])

        return {
            "schema_version": self.schema_version,
            "conversation_id": self.conversation_id,
            "nodes": nodes,
            "edges": edges,
            "current_node_id": self.current_node_id,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationState":
        """Create state from a dictionary produced by to_dict.

        If the stored ``schema_version`` (missing means ``"0.0.0"``) differs
        from SCHEMA_VERSION, migrations are applied first. The tree is then
        rebuilt from ``nodes`` and ``edges``:

        * The root is the node with ID ``""``. If there is none, the first
          stored node that is a parent but never a child is used, then any
          node that is never a child, then the first node.
        * Edges may name the root either as ``""`` (what to_dict writes) or
          by the root's stored ``node_id``.
        * Edges need not be ordered parent-before-child. Siblings are
          attached in the order their edges appear. Edges that cannot be
          connected to the root are skipped with a warning.

        Raises:
            KeyError: If a required field is missing, or a reachable edge
                names a child that is not in ``nodes``.
            ValueError: If ``nodes`` is empty (or a timestamp is malformed).
            SchemaVersionError: If the stored schema cannot be migrated.
        """
        stored_version = data.get("schema_version", "0.0.0")

        # Apply migrations if needed
        if stored_version != SCHEMA_VERSION:
            logger.info(
                "Migrating conversation %s from schema %s to %s",
                data["conversation_id"], stored_version, SCHEMA_VERSION,
            )
            data = cls._migrate_schema(data, stored_version, SCHEMA_VERSION)

        # Create nodes indexed by ID (insertion order = stored order)
        nodes_by_id: Dict[str, ConversationNode] = {}
        for node_data in data["nodes"]:
            node = ConversationNode.from_dict(node_data)
            nodes_by_id[node.node_id] = node

        edges = [(parent_id, child_id) for parent_id, child_id in data["edges"]]
        root_node = cls._find_root(nodes_by_id, edges)
        tree = Tree(root_node)

        # to_dict() writes the root's children with parent_id "" (what
        # calculate_node_id returns for the root). Register the root under
        # that key AND under its own stored ID so either form resolves.
        tree_nodes_by_id: Dict[str, Tree] = {"": tree, root_node.node_id: tree}

        # Group children by parent, preserving edge order (= sibling order).
        children_by_parent: Dict[str, List[str]] = {}
        for parent_id, child_id in edges:
            children_by_parent.setdefault(parent_id, []).append(child_id)

        # Attach children outward from the root, so the result does not
        # depend on edges being listed parent-before-child.
        pending = list(tree_nodes_by_id)
        while pending:
            parent_id = pending.pop()
            parent_tree_node = tree_nodes_by_id[parent_id]
            for child_id in children_by_parent.pop(parent_id, []):
                if child_id in tree_nodes_by_id:
                    continue  # Duplicate or cyclic edge: keep first placement
                child_tree_node = parent_tree_node.add_child(
                    Tree(nodes_by_id[child_id])
                )
                tree_nodes_by_id[child_id] = child_tree_node
                pending.append(child_id)

        unattached = [
            child_id for _, child_id in edges if child_id not in tree_nodes_by_id
        ]
        if unattached:
            logger.warning(
                "Conversation %s: %d edge(s) could not be connected to the "
                "root and were skipped: %s",
                data["conversation_id"], len(unattached), unattached,
            )

        return cls(
            conversation_id=data["conversation_id"],
            message_tree=tree,
            current_node_id=data["current_node_id"],
            metadata=data["metadata"],
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            schema_version=SCHEMA_VERSION  # Always use current version after migration
        )

    @staticmethod
    def _find_root(
        nodes_by_id: Dict[str, ConversationNode],
        edges: List[Any],
    ) -> ConversationNode:
        """Pick the root node when rebuilding a tree in from_dict.

        Preference order: the node with ID ""; the first stored node that
        appears as a parent but never as a child; the first stored node that
        never appears as a child; the first stored node.

        Raises:
            ValueError: If there are no nodes at all.
        """
        if not nodes_by_id:
            raise ValueError("Cannot rebuild conversation tree: no nodes stored")

        root_node = nodes_by_id.get("")
        if root_node is not None:
            return root_node

        child_ids = {child_id for _, child_id in edges}
        parent_ids = {parent_id for parent_id, _ in edges}
        # Iterate in stored order so the choice is deterministic.
        candidates = [node_id for node_id in nodes_by_id if node_id not in child_ids]
        for node_id in candidates:
            if node_id in parent_ids:
                return nodes_by_id[node_id]
        if candidates:
            return nodes_by_id[candidates[0]]

        # Every node is someone's child (e.g. cyclic edges): use the first.
        return next(iter(nodes_by_id.values()))

    @staticmethod
    def _migrate_schema(
        data: Dict[str, Any],
        from_version: str,
        to_version: str
    ) -> Dict[str, Any]:
        """Migrate data from one schema version to another.

        Migrations are applied to a shallow copy, so the caller's dictionary
        is never modified.

        Args:
            data: Data in old schema format
            from_version: Source schema version (MAJOR.MINOR.PATCH)
            to_version: Target schema version (MAJOR.MINOR.PATCH)

        Returns:
            Data in new schema format

        Raises:
            SchemaVersionError: If a version string is not three
                dot-separated integers, or if the source major version is
                newer than the target (downgrade).
        """
        # No migration needed if versions match
        if from_version == to_version:
            return data

        def parse_major(version: Any) -> int:
            """Validate a MAJOR.MINOR.PATCH string and return its major part."""
            parts = str(version).split(".")
            if len(parts) != 3 or not all(part.isdecimal() for part in parts):
                raise SchemaVersionError(
                    f"Invalid schema version {version!r}; "
                    "expected MAJOR.MINOR.PATCH"
                )
            return int(parts[0])

        from_major = parse_major(from_version)
        to_major = parse_major(to_version)

        # Copy so migrations never mutate the caller's dictionary.
        data = dict(data)

        # Version 0.0.0 (no version field) to 1.0.0 is a no-op: the schema
        # did not change, only the version field was added.
        if from_version == "0.0.0":
            logger.debug("Migrating from unversioned schema to 1.0.0 (no changes needed)")
            data["schema_version"] = "1.0.0"
            return data

        # If we get here and versions still don't match, it's unsupported
        if from_major > to_major:
            raise SchemaVersionError(
                f"Cannot downgrade from schema {from_version} to {to_version}"
            )

        logger.warning(
            "No migration path defined from %s to %s. Using data as-is.",
            from_version, to_version,
        )
        data["schema_version"] = to_version
        return data

    # Future migrations: add one static method per step and call it from
    # _migrate_schema (it is a staticmethod, so reference the step as
    # ConversationState._migrate_x_to_y, not via cls). Compare parsed
    # (major, minor, patch) integer tuples, never raw strings: "1.10.0" sorts
    # before "1.9.0" lexicographically. Example step:
    #
    # @staticmethod
    # def _migrate_1_0_to_1_1(data: Dict[str, Any]) -> Dict[str, Any]:
    #     """Migrate from schema 1.0 to 1.1."""
    #     # Add new field with default value
    #     data["new_field"] = "default_value"
    #     return data

415 

416 

class ConversationStorage(ABC):
    """Interface for persisting and retrieving conversation state.

    Concrete subclasses choose the backend (SQL, NoSQL, files, ...); callers
    depend only on the async methods declared here.
    """

    @abstractmethod
    async def save_conversation(self, state: ConversationState) -> None:
        """Insert or update (upsert) a conversation.

        Args:
            state: Conversation state to persist
        """
        ...

    @abstractmethod
    async def load_conversation(
        self,
        conversation_id: str
    ) -> Optional[ConversationState]:
        """Fetch a single conversation.

        Args:
            conversation_id: Identifier of the conversation to fetch

        Returns:
            The stored conversation state, or None when no conversation has
            that identifier
        """
        ...

    @abstractmethod
    async def delete_conversation(self, conversation_id: str) -> bool:
        """Remove a conversation.

        Args:
            conversation_id: Identifier of the conversation to remove

        Returns:
            True when something was deleted, False when it did not exist
        """
        ...

    @abstractmethod
    async def list_conversations(
        self,
        filter_metadata: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[ConversationState]:
        """Return a page of conversations, optionally filtered by metadata.

        Args:
            filter_metadata: Metadata key/value pairs to filter on (optional)
            limit: Maximum number of conversations to return
            offset: Number of matching conversations to skip (pagination)

        Returns:
            The matching conversation states
        """
        ...

478 

479 

class DataknobsConversationStorage(ConversationStorage):
    """Conversation storage using dataknobs_data backends.

    Each conversation is stored as one Record keyed by its conversation_id.
    The record's data is ``ConversationState.to_dict()``, i.e. the tree
    serialized as nodes + edges. Intended to work with any dataknobs backend
    (Memory, File, S3, Postgres, etc.).

    Every public method raises StorageError on failure, with the underlying
    exception chained as ``__cause__``.

    Example:
        >>> from dataknobs_data.backends import AsyncMemoryDatabase
        >>> storage = DataknobsConversationStorage(AsyncMemoryDatabase())
        >>> await storage.save_conversation(state)
    """

    def __init__(self, backend: Any):
        """Initialize storage with dataknobs backend.

        Args:
            backend: Dataknobs async database backend (AsyncMemoryDatabase,
                etc.). It must provide async ``upsert``, ``read``, ``delete``
                and ``search`` methods; those are the calls made below.
        """
        self.backend = backend

    def _state_to_record(self, state: ConversationState) -> Any:
        """Convert ConversationState to Record.

        Args:
            state: Conversation state to convert

        Returns:
            Record object for storage, with storage_id = conversation_id

        Raises:
            StorageError: If the dataknobs_data package is not installed
        """
        # Import here to avoid circular dependency
        try:
            from dataknobs_data.records import Record
        except ImportError as err:
            raise StorageError(
                "dataknobs_data package not available. "
                "Install it to use DataknobsConversationStorage."
            ) from err

        return Record(
            data=state.to_dict(),
            storage_id=state.conversation_id
        )

    def _record_to_state(self, record: Any) -> ConversationState:
        """Convert Record to ConversationState.

        Args:
            record: Record object from storage

        Returns:
            Conversation state rebuilt from the record's field values
        """
        # Flatten the record's fields back into the plain dict written by
        # _state_to_record. (Loop name avoids shadowing dataclasses.field.)
        data = {
            field_name: record_field.value
            for field_name, record_field in record.fields.items()
        }
        return ConversationState.from_dict(data)

    async def save_conversation(self, state: ConversationState) -> None:
        """Save conversation to backend (upsert keyed by conversation_id)."""
        try:
            record = self._state_to_record(state)
            await self.backend.upsert(state.conversation_id, record)
        except StorageError:
            raise  # Already descriptive (e.g. missing package); don't re-wrap
        except Exception as e:
            raise StorageError(f"Failed to save conversation: {e}") from e

    async def load_conversation(
        self,
        conversation_id: str
    ) -> Optional[ConversationState]:
        """Load conversation from backend; None if the backend has no record."""
        try:
            record = await self.backend.read(conversation_id)
            if record is None:
                return None
            return self._record_to_state(record)
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to load conversation: {e}") from e

    async def delete_conversation(self, conversation_id: str) -> bool:
        """Delete conversation from backend; returns the backend's result."""
        try:
            return await self.backend.delete(conversation_id)
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to delete conversation: {e}") from e

    async def list_conversations(
        self,
        filter_metadata: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[ConversationState]:
        """List conversations from backend.

        Each ``filter_metadata`` entry becomes an equality filter on the
        ``metadata.<key>`` field.
        """
        try:
            from dataknobs_data.query import Query
        except ImportError as err:
            raise StorageError(
                "dataknobs_data package not available. "
                "Install it to use DataknobsConversationStorage."
            ) from err

        try:
            # Reassign at every step of the fluent builder so this is correct
            # whether Query mutates in place and returns itself or returns a
            # new Query.
            query = Query().limit(limit).offset(offset)
            for key, value in (filter_metadata or {}).items():
                query = query.filter(f"metadata.{key}", "=", value)

            results = await self.backend.search(query)
            return [self._record_to_state(record) for record in results]
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to list conversations: {e}") from e

610 

611 

class StorageError(Exception):
    """Raised when a conversation storage operation fails."""


class SchemaVersionError(Exception):
    """Raised when stored data has an incompatible or unsupported schema version."""