Coverage for src/dataknobs_llm/prompts/utils/template_composition.py: 89%

131 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1r"""Template composition utilities for building complex prompts from reusable parts. 

2 

3This module provides the TemplateComposer class which handles: 

4- Section substitution in templates 

5- Template inheritance (extends field) 

6- Config merging for derived templates 

7- Caching for performance 

8 

9Example: 

10 >>> from dataknobs_llm.prompts import TemplateComposer 

11 >>> 

12 >>> # Define base template with sections 

13 >>> base_config = { 

14 ... "sections": { 

15 ... "CODE_SECTION": "```{{language}}\\n{{code}}\\n```", 

16 ... "INSTRUCTIONS": "Analyze for quality" 

17 ... }, 

18 ... "user_prompts": [{ 

19 ... "template": "{{CODE_SECTION}}\\n\\n{{INSTRUCTIONS}}" 

20 ... }] 

21 ... } 

22 >>> 

23 >>> # Derived template overrides one section 

24 >>> derived_config = { 

25 ... "extends": "base_analysis", 

26 ... "sections": { 

27 ... "INSTRUCTIONS": "Analyze for security issues" 

28 ... } 

29 ... } 

30 >>> 

31 >>> composer = TemplateComposer(library) 

32 >>> merged = composer.merge_prompt_configs(base_config, derived_config) 

33""" 

34 

35import logging 

36from typing import Dict, Any 

37 

38logger = logging.getLogger(__name__) 

39 

40 

41class TemplateComposer: 

42 """Handles template composition and inheritance. 

43 

44 This class provides functionality for: 

45 1. Section substitution: Replace {{SECTION_NAME}} placeholders with section content 

46 2. Template inheritance: Support for 'extends' field to inherit from base templates 

47 3. Config merging: Merge derived template configs with base configs 

48 4. Caching: Cache composed templates for performance 

49 

50 The composer works with prompt libraries to retrieve base templates and 

51 their configurations, then composes them according to inheritance rules. 

52 """ 

53 

54 def __init__(self, library: Any | None = None): 

55 """Initialize the template composer. 

56 

57 Args: 

58 library: Optional prompt library for retrieving base templates. 

59 If provided, enables template inheritance via 'extends' field. 

60 """ 

61 self.library = library 

62 self._composition_cache: Dict[str, str] = {} 

63 self._config_cache: Dict[str, Dict[str, Any]] = {} 

64 

65 def compose_template( 

66 self, 

67 template: str, 

68 sections: Dict[str, str] | None = None, 

69 prompt_name: str | None = None 

70 ) -> str: 

71 r"""Compose a template by replacing section placeholders. 

72 

73 Replaces all {{SECTION_NAME}} placeholders in the template with 

74 their corresponding section content from the sections dictionary. 

75 

76 Args: 

77 template: Template string with section placeholders 

78 sections: Dictionary mapping section names to their content 

79 prompt_name: Optional prompt name for caching 

80 

81 Returns: 

82 Composed template with sections expanded 

83 

84 Example: 

85 >>> sections = {"CODE": "```python\\ncode\\n```", "NOTES": "Important!"} 

86 >>> template = "{{CODE}}\\n\\n{{NOTES}}" 

87 >>> composed = composer.compose_template(template, sections) 

88 >>> print(composed) 

89 ```python 

90 code 

91 ``` 

92 

93 Important! 

94 """ 

95 # Check cache if prompt name provided 

96 if prompt_name: 

97 cache_key = f"{prompt_name}:{template[:50]}" 

98 if cache_key in self._composition_cache: 

99 return self._composition_cache[cache_key] 

100 

101 if not sections: 

102 return template 

103 

104 # Replace section placeholders with section content 

105 composed = template 

106 for section_name, section_content in sections.items(): 

107 # Match both {{SECTION_NAME}} and {{ SECTION_NAME }} 

108 placeholder = f"{{{{{section_name}}}}}" 

109 placeholder_with_spaces = f"{{{{ {section_name} }}}}" 

110 

111 composed = composed.replace(placeholder, section_content) 

112 composed = composed.replace(placeholder_with_spaces, section_content) 

113 

114 # Cache if prompt name provided 

115 if prompt_name: 

116 self._composition_cache[cache_key] = composed 

117 

118 return composed 

119 

120 def get_sections_for_prompt( 

121 self, 

122 prompt_name: str, 

123 prompt_config: Dict[str, Any] 

124 ) -> Dict[str, str]: 

125 """Get all sections for a prompt, including inherited sections. 

126 

127 Handles template inheritance by: 

128 1. Getting base template sections if 'extends' field exists 

129 2. Merging with prompt's own sections (prompt sections override base) 

130 3. Recursively resolving inheritance chains 

131 

132 Args: 

133 prompt_name: Name of the prompt 

134 prompt_config: Prompt configuration dictionary 

135 

136 Returns: 

137 Dictionary of all sections (base + overrides) 

138 

139 Raises: 

140 ValueError: If inheritance chain is circular 

141 """ 

142 # Track visited prompts to detect circular inheritance 

143 visited = set() 

144 

145 return self._get_sections_recursive( 

146 prompt_name, 

147 prompt_config, 

148 visited 

149 ) 

150 

151 def _get_sections_recursive( 

152 self, 

153 prompt_name: str, 

154 prompt_config: Dict[str, Any], 

155 visited: set 

156 ) -> Dict[str, str]: 

157 """Recursively resolve sections with inheritance. 

158 

159 Args: 

160 prompt_name: Name of the current prompt 

161 prompt_config: Configuration for the current prompt 

162 visited: Set of already visited prompts (for cycle detection) 

163 

164 Returns: 

165 Merged sections dictionary 

166 

167 Raises: 

168 ValueError: If circular inheritance detected 

169 """ 

170 # Check for circular inheritance 

171 if prompt_name in visited: 

172 raise ValueError( 

173 f"Circular inheritance detected: {prompt_name} already in " 

174 f"inheritance chain {visited}" 

175 ) 

176 

177 visited.add(prompt_name) 

178 

179 # Start with empty sections 

180 all_sections = {} 

181 

182 # If this prompt extends another, get base sections first 

183 extends = prompt_config.get("extends") 

184 if extends and self.library: 

185 # Try to get the base prompt configuration 

186 # We'll try both system and user prompts 

187 base_config = None 

188 

189 # Try system prompts first 

190 try: 

191 base_config = self.library.get_system_prompt(extends) 

192 except (ValueError, KeyError): 

193 pass 

194 

195 # Try user prompts if system didn't work 

196 if not base_config: 

197 try: 

198 base_config = self.library.get_user_prompt(extends, index=0) 

199 except (ValueError, KeyError): 

200 pass 

201 

202 if base_config: 

203 # Recursively get base sections 

204 base_sections = self._get_sections_recursive( 

205 extends, 

206 base_config, 

207 visited.copy() # Copy to avoid affecting sibling branches 

208 ) 

209 all_sections.update(base_sections) 

210 else: 

211 logger.warning( 

212 f"Cannot find base template '{extends}' for '{prompt_name}'" 

213 ) 

214 

215 # Overlay this prompt's sections (overrides base) 

216 prompt_sections = prompt_config.get("sections", {}) 

217 all_sections.update(prompt_sections) 

218 

219 return all_sections 

220 

221 def merge_prompt_configs( 

222 self, 

223 base_config: Dict[str, Any], 

224 derived_config: Dict[str, Any] 

225 ) -> Dict[str, Any]: 

226 """Merge derived prompt config with base config. 

227 

228 Merging rules: 

229 1. Sections: Child sections override parent sections 

230 2. Defaults: Child defaults override parent defaults 

231 3. Validation: Child validation overrides parent validation 

232 4. RAG configs: Child configs are appended to parent configs 

233 5. User/system prompts: Child prompts replace parent prompts 

234 6. Metadata: Merged with child taking priority 

235 

236 Args: 

237 base_config: Base template configuration 

238 derived_config: Derived template configuration 

239 

240 Returns: 

241 Merged configuration dictionary 

242 

243 Example: 

244 >>> base = { 

245 ... "defaults": {"lang": "python"}, 

246 ... "sections": {"CODE": "{{code}}"} 

247 ... } 

248 >>> derived = { 

249 ... "defaults": {"style": "PEP8"}, 

250 ... "sections": {"NOTES": "{{notes}}"} 

251 ... } 

252 >>> merged = composer.merge_prompt_configs(base, derived) 

253 >>> merged["defaults"] 

254 {"lang": "python", "style": "PEP8"} 

255 >>> merged["sections"] 

256 {"CODE": "{{code}}", "NOTES": "{{notes}}"} 

257 """ 

258 merged = {} 

259 

260 # 1. Merge sections (child overrides parent) 

261 if "sections" in base_config or "sections" in derived_config: 

262 merged["sections"] = { 

263 **base_config.get("sections", {}), 

264 **derived_config.get("sections", {}) 

265 } 

266 

267 # 2. Merge defaults (child overrides parent) 

268 if "defaults" in base_config or "defaults" in derived_config: 

269 merged["defaults"] = { 

270 **base_config.get("defaults", {}), 

271 **derived_config.get("defaults", {}) 

272 } 

273 

274 # 3. Merge validation (child overrides parent completely) 

275 if "validation" in derived_config: 

276 merged["validation"] = derived_config["validation"] 

277 elif "validation" in base_config: 

278 merged["validation"] = base_config["validation"] 

279 

280 # 4. Merge RAG configs (append child to parent) 

281 base_rag_refs = base_config.get("rag_config_refs", []) 

282 derived_rag_refs = derived_config.get("rag_config_refs", []) 

283 if base_rag_refs or derived_rag_refs: 

284 # Combine refs, removing duplicates while preserving order 

285 seen = set() 

286 merged["rag_config_refs"] = [] 

287 for ref in base_rag_refs + derived_rag_refs: 

288 if ref not in seen: 

289 seen.add(ref) 

290 merged["rag_config_refs"].append(ref) 

291 

292 base_rag_configs = base_config.get("rag_configs", []) 

293 derived_rag_configs = derived_config.get("rag_configs", []) 

294 if base_rag_configs or derived_rag_configs: 

295 merged["rag_configs"] = base_rag_configs + derived_rag_configs 

296 

297 # 5. User/system prompts - child replaces parent 

298 # (This is for the template dict itself, not the list of prompts) 

299 if "template" in derived_config: 

300 merged["template"] = derived_config["template"] 

301 elif "template" in base_config: 

302 merged["template"] = base_config["template"] 

303 

304 if "user_prompts" in derived_config: 

305 merged["user_prompts"] = derived_config["user_prompts"] 

306 elif "user_prompts" in base_config: 

307 merged["user_prompts"] = base_config["user_prompts"] 

308 

309 if "system_prompts" in derived_config: 

310 merged["system_prompts"] = derived_config["system_prompts"] 

311 elif "system_prompts" in base_config: 

312 merged["system_prompts"] = base_config["system_prompts"] 

313 

314 # 6. Merge metadata (child takes priority) 

315 if "metadata" in base_config or "metadata" in derived_config: 

316 merged["metadata"] = { 

317 **base_config.get("metadata", {}), 

318 **derived_config.get("metadata", {}) 

319 } 

320 

321 # 7. Copy extends field if present 

322 if "extends" in derived_config: 

323 merged["extends"] = derived_config["extends"] 

324 

325 return merged 

326 

327 def resolve_inheritance( 

328 self, 

329 prompt_name: str, 

330 prompt_config: Dict[str, Any] 

331 ) -> Dict[str, Any]: 

332 """Fully resolve a prompt's inheritance chain. 

333 

334 This method walks up the inheritance chain (via 'extends' fields) 

335 and merges all configs from base to derived. 

336 

337 Args: 

338 prompt_name: Name of the prompt to resolve 

339 prompt_config: Initial prompt configuration 

340 

341 Returns: 

342 Fully resolved configuration with all inheritance applied 

343 

344 Raises: 

345 ValueError: If circular inheritance detected 

346 

347 Example: 

348 >>> # grandparent -> parent -> child 

349 >>> resolved = composer.resolve_inheritance("child", child_config) 

350 >>> # Returns merged config with all three levels 

351 """ 

352 # Check cache 

353 cache_key = f"{prompt_name}" 

354 if cache_key in self._config_cache: 

355 return self._config_cache[cache_key] 

356 

357 # Track visited to detect cycles 

358 visited = [] 

359 current_name = prompt_name 

360 current_config = prompt_config 

361 

362 # Collect all configs in the inheritance chain 

363 configs_to_merge = [current_config] 

364 

365 while current_config.get("extends"): 

366 base_name = current_config["extends"] 

367 

368 # Check for circular inheritance 

369 if base_name in visited: 

370 raise ValueError( 

371 f"Circular inheritance detected: {base_name} -> " 

372 f"{' -> '.join(visited)} -> {base_name}" 

373 ) 

374 

375 visited.append(current_name) 

376 

377 # Get base config from library 

378 if not self.library: 

379 logger.warning( 

380 f"Cannot resolve inheritance for '{current_name}': " 

381 f"no library provided" 

382 ) 

383 break 

384 

385 # Try to get base config 

386 base_config = None 

387 try: 

388 base_config = self.library.get_system_prompt(base_name) 

389 except (ValueError, KeyError): 

390 pass 

391 

392 if not base_config: 

393 try: 

394 base_config = self.library.get_user_prompt(base_name, index=0) 

395 except (ValueError, KeyError): 

396 pass 

397 

398 if not base_config: 

399 logger.warning( 

400 f"Cannot find base template '{base_name}' for '{current_name}'" 

401 ) 

402 break 

403 

404 # Add to chain 

405 configs_to_merge.insert(0, base_config) # Insert at front (base first) 

406 

407 # Move up the chain 

408 current_name = base_name 

409 current_config = base_config 

410 

411 # Merge all configs from base to derived 

412 resolved = {} 

413 for config in configs_to_merge: 

414 resolved = self.merge_prompt_configs(resolved, config) 

415 

416 # Cache the result 

417 self._config_cache[cache_key] = resolved 

418 

419 return resolved 

420 

421 def clear_cache(self): 

422 """Clear all caches. 

423 

424 Call this if templates or configs are modified after composition. 

425 """ 

426 self._composition_cache.clear() 

427 self._config_cache.clear()