Coverage for src/dataknobs_llm/prompts/base/base_prompt_library.py: 22%

141 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-31 16:04 -0600

"""Shared base implementation for prompt libraries.

Defines BasePromptLibrary, a concrete subclass of AbstractPromptLibrary that
adds optional caching together with common parsing and utility helpers.
"""

import logging
from typing import Any, Dict, List, Optional, Union

from .abstract_prompt_library import AbstractPromptLibrary
from .types import (
    MessageIndex,
    PromptTemplate,
    RAGConfig,
    ValidationConfig,
    ValidationLevel,
)

logger = logging.getLogger(__name__)

14 

15 

class BasePromptLibrary(AbstractPromptLibrary):
    """Base implementation with caching and common functionality.

    This class provides:
    - Optional per-instance caching of loaded prompts, message indexes and
      RAG configurations
    - Helper methods for cache management
    - Shared metadata handling
    - Shared parsing helpers for prompt templates, validation configs and
      RAG configs

    The retrieval and listing methods (``get_*`` / ``list_*``) raise
    ``NotImplementedError`` here. Subclasses override them to provide the
    actual prompt loading logic, and may use the ``_get_cached_*`` /
    ``_cache_*`` helpers around their own loading code.
    """

    def __init__(
        self,
        enable_cache: bool = True,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Initialize the base prompt library.

        Args:
            enable_cache: Whether to cache loaded prompts (default: True).
                When False, the ``_get_cached_*`` helpers always return None
                and the ``_cache_*`` helpers store nothing.
            metadata: Optional metadata dictionary merged into the result of
                ``get_metadata()``. A shallow copy is stored, so later
                mutation of the caller's dict does not affect the library.
        """
        self._enable_cache = enable_cache
        # Shallow copy so the library does not alias a caller-owned dict.
        self._metadata: Dict[str, Any] = dict(metadata) if metadata else {}

        # Caches for loaded prompts and indexes, keyed by identifier.
        self._system_prompt_cache: Dict[str, PromptTemplate] = {}
        self._user_prompt_cache: Dict[str, PromptTemplate] = {}
        self._message_index_cache: Dict[str, MessageIndex] = {}
        # Standalone RAG configs, keyed by config name.
        self._rag_config_cache: Dict[str, RAGConfig] = {}
        # RAG configs attached to prompts, keyed by (prompt_name, prompt_type).
        self._prompt_rag_cache: Dict[tuple, List[RAGConfig]] = {}

    # ===== Cache Management =====

    def clear_cache(self) -> None:
        """Clear all cached prompts, message indexes and RAG configs."""
        self._system_prompt_cache.clear()
        self._user_prompt_cache.clear()
        self._message_index_cache.clear()
        self._rag_config_cache.clear()
        self._prompt_rag_cache.clear()
        logger.debug("Cleared cache for %s", self.__class__.__name__)

    def reload(self) -> None:
        """Reload the library by clearing the cache (when caching is enabled).

        Subclasses can override to perform additional reload logic.
        """
        if self._enable_cache:
            self.clear_cache()
        logger.info("Reloaded %s", self.__class__.__name__)

    # ===== Metadata =====

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about this prompt library.

        Returns:
            A new dictionary containing ``"class"`` (the concrete class name)
            and ``"cache_enabled"``, followed by the user-supplied metadata.
            Because the user metadata is unpacked last, it can override
            either of those two keys.
        """
        return {
            "class": self.__class__.__name__,
            "cache_enabled": self._enable_cache,
            **self._metadata
        }

    # ===== Cache Helpers =====

    def _get_cached_system_prompt(self, name: str) -> Optional[PromptTemplate]:
        """Get system prompt from cache if caching is enabled.

        Args:
            name: System prompt identifier

        Returns:
            Cached PromptTemplate if found and caching is enabled, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._system_prompt_cache.get(name)

    def _cache_system_prompt(self, name: str, template: PromptTemplate) -> None:
        """Cache a system prompt if caching is enabled (no-op otherwise).

        Args:
            name: System prompt identifier
            template: PromptTemplate to cache (stored by reference)
        """
        if self._enable_cache:
            self._system_prompt_cache[name] = template

    def _get_cached_user_prompt(self, name: str) -> Optional[PromptTemplate]:
        """Get user prompt from cache if caching is enabled.

        Args:
            name: User prompt identifier

        Returns:
            Cached PromptTemplate if found and caching is enabled, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._user_prompt_cache.get(name)

    def _cache_user_prompt(self, name: str, template: PromptTemplate) -> None:
        """Cache a user prompt if caching is enabled (no-op otherwise).

        Args:
            name: User prompt identifier
            template: PromptTemplate to cache (stored by reference)
        """
        if self._enable_cache:
            self._user_prompt_cache[name] = template

    def _get_cached_message_index(self, name: str) -> Optional[MessageIndex]:
        """Get message index from cache if caching is enabled.

        Args:
            name: Message index identifier

        Returns:
            Cached MessageIndex if found and caching is enabled, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._message_index_cache.get(name)

    def _cache_message_index(self, name: str, index: MessageIndex) -> None:
        """Cache a message index if caching is enabled (no-op otherwise).

        Args:
            name: Message index identifier
            index: MessageIndex to cache (stored by reference)
        """
        if self._enable_cache:
            self._message_index_cache[name] = index

    def _get_cached_rag_config(self, name: str) -> Optional[RAGConfig]:
        """Get standalone RAG config from cache if caching is enabled.

        Args:
            name: RAG config identifier

        Returns:
            Cached RAGConfig if found and caching is enabled, None otherwise
        """
        if not self._enable_cache:
            return None
        return self._rag_config_cache.get(name)

    def _cache_rag_config(self, name: str, config: RAGConfig) -> None:
        """Cache a standalone RAG config if caching is enabled (no-op otherwise).

        Args:
            name: RAG config identifier
            config: RAGConfig to cache (stored by reference)
        """
        if self._enable_cache:
            self._rag_config_cache[name] = config

    def _get_cached_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str
    ) -> Optional[List[RAGConfig]]:
        """Get prompt RAG configs from cache if caching is enabled.

        Args:
            prompt_name: Prompt identifier
            prompt_type: Type of prompt ("user" or "system")

        Returns:
            Cached list of RAGConfig if found and caching is enabled,
            None otherwise
        """
        if not self._enable_cache:
            return None
        return self._prompt_rag_cache.get((prompt_name, prompt_type))

    def _cache_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str,
        configs: List[RAGConfig]
    ) -> None:
        """Cache prompt RAG configurations if caching is enabled (no-op otherwise).

        The cache key is ``(prompt_name, prompt_type)``. It does not include
        the ``index`` accepted by ``get_prompt_rag_configs``.
        NOTE(review): subclasses that cache per-index user prompt configs
        presumably need to encode the index in ``prompt_name``; confirm.

        Args:
            prompt_name: Prompt identifier
            prompt_type: Type of prompt ("user" or "system")
            configs: List of RAGConfig to cache (stored by reference)
        """
        if self._enable_cache:
            self._prompt_rag_cache[(prompt_name, prompt_type)] = configs

    # ===== Common Parsing Methods =====

    def _parse_validation_config(self, data: Union[Dict, ValidationConfig]) -> ValidationConfig:
        """Parse validation configuration from dict or ValidationConfig.

        This method is shared by all library implementations for consistent
        validation config parsing.

        Args:
            data: Validation data (dict or ValidationConfig instance). A dict
                may contain ``"level"`` (a ValidationLevel, or a string matched
                case-insensitively after lowercasing; None means unset),
                ``"required_params"`` and ``"optional_params"`` (both default
                to ``[]``).

        Returns:
            ``data`` itself if it is already a ValidationConfig, otherwise a
            new ValidationConfig built from the dict.

        Raises:
            ValueError: If ``data`` is not a dict/ValidationConfig, if
                ``"level"`` has an unsupported type, or if a level string
                does not name a ValidationLevel.
        """
        if isinstance(data, ValidationConfig):
            return data

        if not isinstance(data, dict):
            raise ValueError(
                "Invalid validation config: expected dict or ValidationConfig, "
                f"got {type(data)}"
            )

        level = None
        level_data = data.get("level")
        # Check the enum type first so an enum member is never re-parsed as text.
        if isinstance(level_data, ValidationLevel):
            level = level_data
        elif isinstance(level_data, str):
            level = ValidationLevel(level_data.lower())
        elif level_data is not None:
            # Previously ignored silently; surface bad config instead.
            raise ValueError(
                "Invalid validation level: expected str or ValidationLevel, "
                f"got {type(level_data)}"
            )

        return ValidationConfig(
            level=level,
            required_params=data.get("required_params", []),
            optional_params=data.get("optional_params", [])
        )

    def _parse_rag_config(self, data: Dict[str, Any]) -> RAGConfig:
        """Parse RAG configuration from dict.

        This method is shared by all library implementations for consistent
        RAG config parsing. ``"adapter_name"`` and ``"query"`` are always
        present in the result (defaulting to ``""``). ``"k"``, ``"filters"``,
        ``"placeholder"``, ``"header"`` and ``"item_template"`` are copied only
        when present in ``data``.

        Args:
            data: RAG config data dictionary

        Returns:
            RAGConfig dictionary

        Raises:
            ValueError: If ``data`` is not a dict
        """
        if not isinstance(data, dict):
            raise ValueError(
                f"Invalid RAG config: expected dict, got {type(data)}"
            )

        rag_config: RAGConfig = {
            "adapter_name": data.get("adapter_name", ""),
            "query": data.get("query", ""),
        }

        # Optional fields. Literal keys keep TypedDict type-checking happy.
        if "k" in data:
            rag_config["k"] = data["k"]
        if "filters" in data:
            rag_config["filters"] = data["filters"]
        if "placeholder" in data:
            rag_config["placeholder"] = data["placeholder"]
        if "header" in data:
            rag_config["header"] = data["header"]
        if "item_template" in data:
            rag_config["item_template"] = data["item_template"]

        return rag_config

    def _apply_optional_template_fields(
        self,
        data: Dict[str, Any],
        template: PromptTemplate
    ) -> None:
        """Copy optional prompt-template fields from ``data`` into ``template``.

        ``"validation"`` is parsed with ``_parse_validation_config``, and each
        entry of ``"rag_configs"`` is parsed with ``_parse_rag_config``. All
        other recognised keys are copied as-is. ``template`` is mutated in place.

        Args:
            data: Raw prompt template dictionary
            template: PromptTemplate being built

        Raises:
            ValueError: Propagated from the validation / RAG config parsers
        """
        if "defaults" in data:
            template["defaults"] = data["defaults"]

        if "validation" in data:
            template["validation"] = self._parse_validation_config(data["validation"])

        if "metadata" in data:
            template["metadata"] = data["metadata"]

        if "template_mode" in data:
            template["template_mode"] = data["template_mode"]

        # Composition fields
        if "sections" in data:
            template["sections"] = data["sections"]

        if "extends" in data:
            template["extends"] = data["extends"]

        # RAG configuration fields
        if "rag_config_refs" in data:
            template["rag_config_refs"] = data["rag_config_refs"]

        if "rag_configs" in data:
            template["rag_configs"] = [
                self._parse_rag_config(rag_data)
                for rag_data in data["rag_configs"]
            ]

    def _parse_prompt_template(self, data: Any) -> PromptTemplate:
        """Parse prompt template from various formats.

        This method is shared by all library implementations for consistent
        template parsing. Supports:
        - String templates (converted to {"template": string})
        - Dict with "template" key
        - Dict with "extends" key but no "template" (template inherited)
        - Empty dict (treated as {"template": ""})

        Args:
            data: Prompt template data (string or dict)

        Returns:
            PromptTemplate dictionary

        Raises:
            ValueError: If data format is invalid, or if a nested validation /
                RAG config is invalid
        """
        # A bare string is shorthand for the template text.
        if isinstance(data, str):
            return {"template": data}

        if isinstance(data, dict):
            # An empty mapping is treated as an empty template.
            if not data:
                return {"template": ""}

            if "template" in data or "extends" in data:
                template: PromptTemplate = {}
                # Without "template", the text is inherited from the prompt
                # named by "extends".
                if "template" in data:
                    template["template"] = data["template"]
                self._apply_optional_template_fields(data, template)
                return template

        raise ValueError(
            "Invalid prompt template data: expected dict with 'template' or 'extends' key, "
            f"or string, got {type(data)}"
        )

    # ===== Abstract Methods (must be implemented by subclasses) =====

    def get_system_prompt(
        self,
        name: str,
        **kwargs: Any
    ) -> Optional[PromptTemplate]:
        """Retrieve a system prompt template by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_system_prompt()"
        )

    def list_system_prompts(self) -> List[str]:
        """List all available system prompt names.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_system_prompts()"
        )

    def get_user_prompt(
        self,
        name: str,
        index: int = 0,
        **kwargs: Any
    ) -> Optional[PromptTemplate]:
        """Retrieve a user prompt template by name and index.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_user_prompt()"
        )

    def list_user_prompts(self, name: Optional[str] = None) -> List[str]:
        """List available user prompts.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_user_prompts()"
        )

    def get_message_index(
        self,
        name: str,
        **kwargs: Any
    ) -> Optional[MessageIndex]:
        """Retrieve a message index by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_message_index()"
        )

    def list_message_indexes(self) -> List[str]:
        """List all available message index names.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_message_indexes()"
        )

    def get_rag_config(
        self,
        name: str,
        **kwargs: Any
    ) -> Optional[RAGConfig]:
        """Retrieve a standalone RAG configuration by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_rag_config()"
        )

    def get_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str = "user",
        index: int = 0,
        **kwargs: Any
    ) -> List[RAGConfig]:
        """Retrieve RAG configurations for a specific prompt.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_prompt_rag_configs()"
        )