Coverage for src / dataknobs_llm / prompts / base / base_prompt_library.py: 38%
142 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
1"""Base implementation of prompt library with shared functionality.
3This module provides BasePromptLibrary, a concrete implementation of
4AbstractPromptLibrary that includes caching and common utilities.
5"""
7from typing import Any, Dict, List, Union
8import logging
10from .abstract_prompt_library import AbstractPromptLibrary
11from .types import PromptTemplateDict, MessageIndex, RAGConfig, ValidationConfig, ValidationLevel
13logger = logging.getLogger(__name__)
class BasePromptLibrary(AbstractPromptLibrary):
    """Base implementation with caching and common functionality.

    This class provides:
    - Optional caching of loaded prompts, message indexes and RAG configs
    - Helper methods for cache management
    - Shared metadata handling
    - Shared parsers for validation, RAG and prompt-template data
    - Placeholder implementations of the lookup methods that raise
      ``NotImplementedError``

    Subclasses should override the lookup methods at the bottom of the
    class to provide the actual prompt loading logic.
    """

    # Optional RAG config keys copied verbatim by ``_parse_rag_config``.
    _OPTIONAL_RAG_KEYS = ("k", "filters", "placeholder", "header", "item_template")

    # Optional prompt template keys handled by ``_parse_prompt_template``.
    # The order is the order in which keys are inserted into the result.
    _OPTIONAL_TEMPLATE_KEYS = (
        "defaults",
        "validation",
        "metadata",
        "template_mode",
        "sections",
        "extends",
        "rag_config_refs",
        "rag_configs",
    )

    def __init__(
        self,
        enable_cache: bool = True,
        metadata: Dict[str, Any] | None = None
    ):
        """Initialize the base prompt library.

        Args:
            enable_cache: Whether to cache loaded prompts (default: True)
            metadata: Optional metadata dictionary. A non-empty dictionary is
                stored by reference (not copied); ``None`` or an empty
                dictionary is replaced by a fresh empty dictionary.
        """
        self._enable_cache = enable_cache
        self._metadata = metadata or {}

        # Caches for loaded prompts and indexes
        self._system_prompt_cache: Dict[str, PromptTemplateDict] = {}
        self._user_prompt_cache: Dict[str, PromptTemplateDict] = {}
        self._message_index_cache: Dict[str, MessageIndex] = {}
        self._rag_config_cache: Dict[str, RAGConfig] = {}  # Standalone RAG configs
        self._prompt_rag_cache: Dict[tuple, List[RAGConfig]] = {}  # (name, type)

    # ===== Cache Management =====

    def clear_cache(self) -> None:
        """Empty every cache held by this library and log at DEBUG level.

        The caches are cleared regardless of ``enable_cache``.
        """
        for cache in (
            self._system_prompt_cache,
            self._user_prompt_cache,
            self._message_index_cache,
            self._rag_config_cache,
            self._prompt_rag_cache,
        ):
            cache.clear()
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug("Cleared cache for %s", self.__class__.__name__)

    def reload(self) -> None:
        """Reload the library by clearing the cache.

        The cache is only cleared when caching is enabled; the INFO log
        line is emitted either way. Subclasses can override to perform
        additional reload logic.
        """
        if self._enable_cache:
            self.clear_cache()
        logger.info("Reloaded %s", self.__class__.__name__)

    # ===== Metadata =====

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about this prompt library.

        Returns:
            A new dictionary containing ``"class"`` (the concrete class
            name) and ``"cache_enabled"``, followed by the metadata given
            at construction. Because the constructor metadata is merged
            last, its entries win if they reuse either of those two keys.
        """
        return {
            "class": self.__class__.__name__,
            "cache_enabled": self._enable_cache,
            **self._metadata
        }

    # ===== Cache Helpers =====

    def _get_cached_system_prompt(self, name: str) -> PromptTemplateDict | None:
        """Get system prompt from cache if caching is enabled.

        Args:
            name: System prompt identifier

        Returns:
            Cached PromptTemplateDict if found, None otherwise (including
            whenever caching is disabled)
        """
        if not self._enable_cache:
            return None
        return self._system_prompt_cache.get(name)

    def _cache_system_prompt(self, name: str, template: PromptTemplateDict) -> None:
        """Cache a system prompt if caching is enabled.

        Args:
            name: System prompt identifier
            template: PromptTemplateDict to cache
        """
        if self._enable_cache:
            self._system_prompt_cache[name] = template

    def _get_cached_user_prompt(self, name: str) -> PromptTemplateDict | None:
        """Get user prompt from cache if caching is enabled.

        Args:
            name: User prompt identifier

        Returns:
            Cached PromptTemplateDict if found, None otherwise (including
            whenever caching is disabled)
        """
        if not self._enable_cache:
            return None
        return self._user_prompt_cache.get(name)

    def _cache_user_prompt(self, name: str, template: PromptTemplateDict) -> None:
        """Cache a user prompt if caching is enabled.

        Args:
            name: User prompt identifier
            template: PromptTemplateDict to cache
        """
        if self._enable_cache:
            self._user_prompt_cache[name] = template

    def _get_cached_message_index(self, name: str) -> MessageIndex | None:
        """Get message index from cache if caching is enabled.

        Args:
            name: Message index identifier

        Returns:
            Cached MessageIndex if found, None otherwise (including
            whenever caching is disabled)
        """
        if not self._enable_cache:
            return None
        return self._message_index_cache.get(name)

    def _cache_message_index(self, name: str, index: MessageIndex) -> None:
        """Cache a message index if caching is enabled.

        Args:
            name: Message index identifier
            index: MessageIndex to cache
        """
        if self._enable_cache:
            self._message_index_cache[name] = index

    def _get_cached_rag_config(self, name: str) -> RAGConfig | None:
        """Get standalone RAG config from cache if caching is enabled.

        Args:
            name: RAG config identifier

        Returns:
            Cached RAGConfig if found, None otherwise (including whenever
            caching is disabled)
        """
        if not self._enable_cache:
            return None
        return self._rag_config_cache.get(name)

    def _cache_rag_config(self, name: str, config: RAGConfig) -> None:
        """Cache a standalone RAG config if caching is enabled.

        Args:
            name: RAG config identifier
            config: RAGConfig to cache
        """
        if self._enable_cache:
            self._rag_config_cache[name] = config

    def _get_cached_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str
    ) -> List[RAGConfig] | None:
        """Get prompt RAG configs from cache if caching is enabled.

        Args:
            prompt_name: Prompt identifier
            prompt_type: Type of prompt ("user" or "system")

        Returns:
            Cached list of RAGConfig if found, None otherwise (including
            whenever caching is disabled)
        """
        if not self._enable_cache:
            return None
        # NOTE(review): the key is (name, type) only, while
        # get_prompt_rag_configs() also accepts an index - confirm that
        # entries for different indexes are not expected to be distinct.
        return self._prompt_rag_cache.get((prompt_name, prompt_type))

    def _cache_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str,
        configs: List[RAGConfig]
    ) -> None:
        """Cache prompt RAG configurations if caching is enabled.

        Args:
            prompt_name: Prompt identifier
            prompt_type: Type of prompt ("user" or "system")
            configs: List of RAGConfig to cache (stored by reference)
        """
        if self._enable_cache:
            self._prompt_rag_cache[(prompt_name, prompt_type)] = configs

    # ===== Common Parsing Methods =====

    def _parse_validation_config(self, data: Union[Dict, ValidationConfig]) -> ValidationConfig:
        """Parse validation configuration from dict or ValidationConfig.

        This method is shared by all library implementations for consistent
        validation config parsing.

        Args:
            data: Validation data (dict or ValidationConfig instance). A
                ValidationConfig is returned unchanged. For a dict, the keys
                read are "level", "required_params" and "optional_params".

        Returns:
            ValidationConfig instance

        Raises:
            ValueError: If data is neither a dict nor a ValidationConfig.
                A string level is lowercased and passed to ValidationLevel,
                which presumably rejects unknown values with its own error.
        """
        if isinstance(data, ValidationConfig):
            return data

        if not isinstance(data, dict):
            raise ValueError(
                f"Invalid validation config: expected dict or ValidationConfig, "
                f"got {type(data)}"
            )

        # Parse level
        level = None
        level_data = data.get("level")
        if isinstance(level_data, str):
            level = ValidationLevel(level_data.lower())
        elif isinstance(level_data, ValidationLevel):
            level = level_data
        elif level_data is not None:
            # The level stays None exactly as before; the warning only makes
            # an otherwise silent misconfiguration visible.
            logger.warning(
                "Ignoring validation level of unsupported type %s; "
                "expected str or ValidationLevel",
                type(level_data).__name__,
            )

        # Parse params
        return ValidationConfig(
            level=level,
            required_params=data.get("required_params", []),
            optional_params=data.get("optional_params", [])
        )

    def _parse_rag_config(self, data: Dict[str, Any]) -> RAGConfig:
        """Parse RAG configuration from dict.

        This method is shared by all library implementations for consistent
        RAG config parsing.

        Args:
            data: RAG config data dictionary

        Returns:
            RAGConfig dictionary. "adapter_name" and "query" are always
            present (defaulting to an empty string); "k", "filters",
            "placeholder", "header" and "item_template" are copied only
            when present in ``data``. Any other keys are dropped.
        """
        rag_config: RAGConfig = {
            "adapter_name": data.get("adapter_name", ""),
            "query": data.get("query", ""),
        }

        # Add optional fields
        for key in self._OPTIONAL_RAG_KEYS:
            if key in data:
                rag_config[key] = data[key]

        return rag_config

    def _copy_optional_template_fields(
        self,
        data: Dict[str, Any],
        template: Dict[str, Any]
    ) -> None:
        """Copy the optional prompt-template fields from data into template.

        "validation" is converted with ``_parse_validation_config`` and each
        item of "rag_configs" with ``_parse_rag_config``; every other key in
        ``_OPTIONAL_TEMPLATE_KEYS`` is copied as-is. Keys absent from
        ``data`` are skipped.
        """
        for key in self._OPTIONAL_TEMPLATE_KEYS:
            if key not in data:
                continue
            value = data[key]
            if key == "validation":
                value = self._parse_validation_config(value)
            elif key == "rag_configs":
                value = [self._parse_rag_config(rag_data) for rag_data in value]
            template[key] = value

    def _parse_prompt_template(self, data: Any) -> PromptTemplateDict:
        """Parse prompt template from various formats.

        This method is shared by all library implementations for consistent
        template parsing. Supports:
        - String templates (converted to {"template": string})
        - Dict with "template" key
        - Dict with "extends" key but no "template" (template inherited)
        - Empty dict (treated as {"template": ""})

        Args:
            data: Prompt template data (string or dict)

        Returns:
            PromptTemplateDict dictionary

        Raises:
            ValueError: If data is neither a string nor a dict, or is a
                non-empty dict with neither a "template" nor an "extends" key
        """
        # If just a string, treat as template
        if isinstance(data, str):
            return {"template": data}

        if isinstance(data, dict):
            # If empty dict, treat as empty template
            if not data:
                return {"template": ""}

            template: Dict[str, Any] | None = None
            if "template" in data:
                template = {"template": data["template"]}
            elif "extends" in data:
                # Template text will be inherited from the base prompt
                template = {"extends": data["extends"]}

            if template is not None:
                self._copy_optional_template_fields(data, template)
                return template

        raise ValueError(
            f"Invalid prompt template data: expected dict with 'template' or 'extends' key, "
            f"or string, got {type(data)}"
        )

    # ===== Abstract Methods (must be implemented by subclasses) =====

    def get_system_prompt(
        self,
        name: str,
        **kwargs: Any
    ) -> PromptTemplateDict | None:
        """Retrieve a system prompt template by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_system_prompt()"
        )

    def list_system_prompts(self) -> List[str]:
        """List all available system prompt names.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_system_prompts()"
        )

    def get_user_prompt(
        self,
        name: str,
        index: int = 0,
        **kwargs: Any
    ) -> PromptTemplateDict | None:
        """Retrieve a user prompt template by name and index.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_user_prompt()"
        )

    def list_user_prompts(self) -> List[str]:
        """List available user prompts.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_user_prompts()"
        )

    def get_message_index(
        self,
        name: str,
        **kwargs: Any
    ) -> MessageIndex | None:
        """Retrieve a message index by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_message_index()"
        )

    def list_message_indexes(self) -> List[str]:
        """List all available message index names.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement list_message_indexes()"
        )

    def get_rag_config(
        self,
        name: str,
        **kwargs: Any
    ) -> RAGConfig | None:
        """Retrieve a standalone RAG configuration by name.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_rag_config()"
        )

    def get_prompt_rag_configs(
        self,
        prompt_name: str,
        prompt_type: str = "user",
        index: int = 0,
        **kwargs: Any
    ) -> List[RAGConfig]:
        """Retrieve RAG configurations for a specific prompt.

        Subclasses must implement this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_prompt_rag_configs()"
        )