Coverage for src/dataknobs_llm/prompts/base/base_prompt_library.py: 83%

142 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

"""Shared base class for prompt libraries.

Defines BasePromptLibrary, a concrete AbstractPromptLibrary implementation
that layers optional caching on top of the abstract interface and bundles
the parsing and cache-management helpers common to every library backend.
"""

import logging
from typing import Any, Dict, List, Union

from .abstract_prompt_library import AbstractPromptLibrary
from .types import (
    MessageIndex,
    PromptTemplateDict,
    RAGConfig,
    ValidationConfig,
    ValidationLevel,
)

logger = logging.getLogger(__name__)

14 

15 

class BasePromptLibrary(AbstractPromptLibrary):
    """Base implementation with caching and common functionality.

    This class provides:
    - Optional caching of loaded prompts, message indexes and RAG configs
    - Helper methods for cache management
    - Shared metadata handling
    - Shared parsers for prompt templates, validation configs and RAG configs
    - Default implementations of the retrieval/listing methods that raise
      ``NotImplementedError``

    Subclasses should override the retrieval/listing methods to provide the
    actual prompt loading logic, and may use the ``_get_cached_*`` /
    ``_cache_*`` helpers to participate in caching.
    """

    def __init__(
        self,
        enable_cache: bool = True,
        metadata: Dict[str, Any] | None = None
    ) -> None:
        """Initialize the base prompt library.

        Args:
            enable_cache: Whether to cache loaded prompts (default: True).
                When False, every ``_get_cached_*`` helper returns None and
                every ``_cache_*`` helper is a no-op.
            metadata: Optional metadata dictionary. It is shallow-copied, so
                later mutation of the caller's dict does not change this
                library's metadata.
        """
        self._enable_cache = enable_cache
        # Shallow copy so the library does not alias the caller's mutable dict.
        self._metadata: Dict[str, Any] = dict(metadata) if metadata else {}

        # Caches for loaded prompts and indexes, keyed by name.
        self._system_prompt_cache: Dict[str, PromptTemplateDict] = {}
        self._user_prompt_cache: Dict[str, PromptTemplateDict] = {}
        self._message_index_cache: Dict[str, MessageIndex] = {}
        # Standalone RAG configs, keyed by config name.
        self._rag_config_cache: Dict[str, RAGConfig] = {}
        # RAG configs attached to a prompt, keyed by (prompt_name, prompt_type).
        self._prompt_rag_cache: Dict[tuple[str, str], List[RAGConfig]] = {}

    # ===== Cache Management =====

    def clear_cache(self) -> None:
        """Clear all cached prompts, message indexes and RAG configs."""
        self._system_prompt_cache.clear()
        self._user_prompt_cache.clear()
        self._message_index_cache.clear()
        self._rag_config_cache.clear()
        self._prompt_rag_cache.clear()
        logger.debug("Cleared cache for %s", self.__class__.__name__)

    def reload(self) -> None:
        """Reload the library by clearing the cache.

        When caching is disabled, the caches are never populated by the
        ``_cache_*`` helpers, so only the log message is emitted.
        Subclasses can override to perform additional reload logic.
        """
        if self._enable_cache:
            self.clear_cache()
        logger.info("Reloaded %s", self.__class__.__name__)

    # ===== Metadata =====

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about this prompt library.

        Returns:
            A new dictionary containing ``"class"`` (the concrete class name),
            ``"cache_enabled"``, and every user-supplied metadata entry. User
            entries are merged last, so they override the two built-in keys
            on a name clash.
        """
        return {
            "class": self.__class__.__name__,
            "cache_enabled": self._enable_cache,
            **self._metadata
        }

    # ===== Cache Helpers =====

    def _get_cached_system_prompt(self, name: str) -> PromptTemplateDict | None:
        """Get system prompt from cache if caching is enabled.

        Args:
            name: System prompt identifier

        Returns:
            Cached PromptTemplateDict if found, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._system_prompt_cache.get(name)

    def _cache_system_prompt(self, name: str, template: PromptTemplateDict) -> None:
        """Cache a system prompt if caching is enabled.

        Args:
            name: System prompt identifier
            template: PromptTemplateDict to cache
        """
        if self._enable_cache:
            self._system_prompt_cache[name] = template

    def _get_cached_user_prompt(self, name: str) -> PromptTemplateDict | None:
        """Get user prompt from cache if caching is enabled.

        Args:
            name: User prompt identifier

        Returns:
            Cached PromptTemplateDict if found, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._user_prompt_cache.get(name)

    def _cache_user_prompt(self, name: str, template: PromptTemplateDict) -> None:
        """Cache a user prompt if caching is enabled.

        Args:
            name: User prompt identifier
            template: PromptTemplateDict to cache
        """
        if self._enable_cache:
            self._user_prompt_cache[name] = template

    def _get_cached_message_index(self, name: str) -> MessageIndex | None:
        """Get message index from cache if caching is enabled.

        Args:
            name: Message index identifier

        Returns:
            Cached MessageIndex if found, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._message_index_cache.get(name)

    def _cache_message_index(self, name: str, index: MessageIndex) -> None:
        """Cache a message index if caching is enabled.

        Args:
            name: Message index identifier
            index: MessageIndex to cache
        """
        if self._enable_cache:
            self._message_index_cache[name] = index

    def _get_cached_rag_config(self, name: str) -> RAGConfig | None:
        """Get standalone RAG config from cache if caching is enabled.

        Args:
            name: RAG config identifier

        Returns:
            Cached RAGConfig if found, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._rag_config_cache.get(name)

    def _cache_rag_config(self, name: str, config: RAGConfig) -> None:
        """Cache a standalone RAG config if caching is enabled.

        Args:
            name: RAG config identifier
            config: RAGConfig to cache
        """
        if self._enable_cache:
            self._rag_config_cache[name] = config

    def _get_cached_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str
    ) -> List[RAGConfig] | None:
        """Get prompt RAG configs from cache if caching is enabled.

        Args:
            prompt_name: Prompt identifier
            prompt_type: Type of prompt ("user" or "system")

        Returns:
            Cached list of RAGConfig if found, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._prompt_rag_cache.get((prompt_name, prompt_type))

    def _cache_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str,
        configs: List[RAGConfig]
    ) -> None:
        """Cache prompt RAG configurations if caching is enabled.

        Args:
            prompt_name: Prompt identifier
            prompt_type: Type of prompt ("user" or "system")
            configs: List of RAGConfig to cache
        """
        if self._enable_cache:
            self._prompt_rag_cache[(prompt_name, prompt_type)] = configs

    # ===== Common Parsing Methods =====

    def _parse_validation_config(self, data: Union[Dict, ValidationConfig]) -> ValidationConfig:
        """Parse validation configuration from dict or ValidationConfig.

        This method is shared by all library implementations for consistent
        validation config parsing.

        Args:
            data: Validation data (dict or ValidationConfig instance). A dict
                may contain ``"level"`` (a case-insensitive level name, a
                ValidationLevel, or None), ``"required_params"`` and
                ``"optional_params"``; missing param lists default to ``[]``.

        Returns:
            ValidationConfig instance (``data`` itself if it already is one)

        Raises:
            ValueError: If ``data`` is neither a dict nor a ValidationConfig,
                if ``level`` has an unsupported type, or if ``level`` is a
                string that does not name a ValidationLevel.
        """
        if isinstance(data, ValidationConfig):
            return data

        if not isinstance(data, dict):
            raise ValueError(
                f"Invalid validation config: expected dict or ValidationConfig, "
                f"got {type(data)}"
            )

        # Parse level. The str check comes first so string-valued enum members
        # (if ValidationLevel subclasses str) are normalised through lookup.
        level = None
        level_data = data.get("level")
        if isinstance(level_data, str):
            level = ValidationLevel(level_data.lower())
        elif isinstance(level_data, ValidationLevel):
            level = level_data
        elif level_data is not None:
            # Previously an unsupported type was silently treated as "no
            # level"; fail loudly instead, like an unknown level name does.
            raise ValueError(
                f"Invalid validation level: expected str or ValidationLevel, "
                f"got {type(level_data)}"
            )

        # Parse params
        required_params = data.get("required_params", [])
        optional_params = data.get("optional_params", [])

        return ValidationConfig(
            level=level,
            required_params=required_params,
            optional_params=optional_params
        )

    def _parse_rag_config(self, data: Dict[str, Any]) -> RAGConfig:
        """Parse RAG configuration from dict.

        This method is shared by all library implementations for consistent
        RAG config parsing. ``adapter_name`` and ``query`` are always present
        in the result (defaulting to ``""``); ``k``, ``filters``,
        ``placeholder``, ``header`` and ``item_template`` are copied only when
        present in ``data``.

        Args:
            data: RAG config data dictionary

        Returns:
            RAGConfig dictionary

        Raises:
            ValueError: If ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise ValueError(
                f"Invalid RAG config: expected dict, got {type(data)}"
            )

        rag_config: RAGConfig = {
            "adapter_name": data.get("adapter_name", ""),
            "query": data.get("query", ""),
        }

        # Add optional fields (explicit literal keys keep TypedDict checking).
        if "k" in data:
            rag_config["k"] = data["k"]

        if "filters" in data:
            rag_config["filters"] = data["filters"]

        if "placeholder" in data:
            rag_config["placeholder"] = data["placeholder"]

        if "header" in data:
            rag_config["header"] = data["header"]

        if "item_template" in data:
            rag_config["item_template"] = data["item_template"]

        return rag_config

    def _apply_optional_template_fields(
        self,
        template: PromptTemplateDict,
        data: Dict[str, Any]
    ) -> PromptTemplateDict:
        """Copy optional prompt-template fields from raw ``data`` into ``template``.

        Shared by both dict forms accepted by ``_parse_prompt_template``.
        ``validation`` and ``rag_configs`` are parsed through their dedicated
        parsers; every other field is copied as-is.

        Args:
            template: Partially built template (already holding ``template``
                or ``extends``); mutated in place.
            data: Raw template data dictionary.

        Returns:
            The same ``template`` object, for convenience.
        """
        if "defaults" in data:
            template["defaults"] = data["defaults"]

        if "validation" in data:
            template["validation"] = self._parse_validation_config(data["validation"])

        if "metadata" in data:
            template["metadata"] = data["metadata"]

        # Template mode field
        if "template_mode" in data:
            template["template_mode"] = data["template_mode"]

        # Composition fields. When the caller seeded ``extends`` already, this
        # re-assigns the same value and leaves the key's position unchanged.
        if "sections" in data:
            template["sections"] = data["sections"]

        if "extends" in data:
            template["extends"] = data["extends"]

        # RAG configuration fields
        if "rag_config_refs" in data:
            template["rag_config_refs"] = data["rag_config_refs"]

        if "rag_configs" in data:
            template["rag_configs"] = [
                self._parse_rag_config(rag_data)
                for rag_data in data["rag_configs"]
            ]

        return template

    def _parse_prompt_template(self, data: Any) -> PromptTemplateDict:
        """Parse prompt template from various formats.

        This method is shared by all library implementations for consistent
        template parsing. Supports:
        - String templates (converted to {"template": string})
        - Dict with "template" key (also used when "extends" is present too)
        - Dict with "extends" key but no "template" (template inherited)
        - Empty dict (treated as {"template": ""})

        Args:
            data: Prompt template data (string or dict)

        Returns:
            PromptTemplateDict dictionary

        Raises:
            ValueError: If data format is invalid, or if a nested validation
                or RAG config is invalid.
        """
        # If just a string, treat as template
        if isinstance(data, str):
            return {"template": data}

        invalid_format_message = (
            f"Invalid prompt template data: expected dict with 'template' or 'extends' key, "
            f"or string, got {type(data)}"
        )

        if not isinstance(data, dict):
            raise ValueError(invalid_format_message)

        # If empty dict, treat as empty template
        if not data:
            return {"template": ""}

        # Dict with an explicit template body
        if "template" in data:
            return self._apply_optional_template_fields(
                {"template": data["template"]}, data
            )

        # Dict with extends but no template: the body is inherited from base
        if "extends" in data:
            return self._apply_optional_template_fields(
                {"extends": data["extends"]}, data
            )

        raise ValueError(invalid_format_message)

    # ===== Abstract Methods (must be implemented by subclasses) =====
    # NOTE(review): if AbstractPromptLibrary declares these with
    # @abstractmethod, these concrete stubs satisfy that declaration, so a
    # subclass that forgets an override fails only when the method is called,
    # not at instantiation.

    def get_system_prompt(
        self,
        name: str,
        **kwargs: Any
    ) -> PromptTemplateDict | None:
        """Retrieve a system prompt template by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_system_prompt()"
        )

    def list_system_prompts(self) -> List[str]:
        """List all available system prompt names.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_system_prompts()"
        )

    def get_user_prompt(
        self,
        name: str,
        index: int = 0,
        **kwargs: Any
    ) -> PromptTemplateDict | None:
        """Retrieve a user prompt template by name and index.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_user_prompt()"
        )

    def list_user_prompts(self) -> List[str]:
        """List available user prompts.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_user_prompts()"
        )

    def get_message_index(
        self,
        name: str,
        **kwargs: Any
    ) -> MessageIndex | None:
        """Retrieve a message index by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_message_index()"
        )

    def list_message_indexes(self) -> List[str]:
        """List all available message index names.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_message_indexes()"
        )

    def get_rag_config(
        self,
        name: str,
        **kwargs: Any
    ) -> RAGConfig | None:
        """Retrieve a standalone RAG configuration by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_rag_config()"
        )

    def get_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str = "user",
        index: int = 0,
        **kwargs: Any
    ) -> List[RAGConfig]:
        """Retrieve RAG configurations for a specific prompt.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_prompt_rag_configs()"
        )