Coverage for src / dataknobs_llm / prompts / rendering / template_renderer.py: 28%

163 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:29 -0700

1"""Template renderer with validation support and Jinja2 integration. 

2 

3This module provides template rendering with: 

4- Custom (( )) conditional syntax (backward compatible) 

5- Jinja2 integration for filters, advanced conditionals, includes 

6- Two rendering modes: "mixed" (both syntaxes) and "jinja2" (pure Jinja2) 

7- Validation capabilities for missing parameters 

8""" 

9 

10import re 

11import logging 

12from typing import Any, Callable, Dict, List, Set, Tuple 

13from dataclasses import dataclass 

14 

15from jinja2 import Environment, TemplateSyntaxError as Jinja2SyntaxError, Undefined 

16 

17from dataknobs_llm.template_utils import render_conditional_template 

18from ..base.types import ( 

19 ValidationLevel, 

20 ValidationConfig, 

21 PromptTemplateDict, 

22 RenderResult, 

23 TemplateMode 

24) 

25 

# Module-level logger; TemplateRenderer.render emits missing-required-parameter
# warnings through it when the effective validation level is WARN.
logger = logging.getLogger(__name__)

27 

28 

class PreserveUndefined(Undefined):
    """Jinja2 Undefined handler that preserves placeholders for undefined variables.

    This maintains backward compatibility with the old template behavior where
    undefined variables are left as {{variable}} instead of being rendered as
    empty strings.

    Jinja2 sometimes creates an Undefined that carries only a hint and no
    variable name (for example the ``first``/``last``/``min``/``max`` filters
    applied to an empty sequence). There is no placeholder to restore in that
    case, so the standard Jinja2 behavior is used instead of emitting a
    literal ``{{None}}`` into the prompt.
    """

    def _placeholder(self) -> str | None:
        """Return ``{{name}}`` for a named undefined, or None when unnamed."""
        name = self._undefined_name
        if name is None:
            return None
        return f"{{{{{name}}}}}"

    def __str__(self) -> str:
        """Return the original placeholder for undefined variables."""
        placeholder = self._placeholder()
        # Unnamed undefined: defer to Jinja2's default (empty string).
        return placeholder if placeholder is not None else super().__str__()

    def __repr__(self) -> str:
        """Return the original placeholder for undefined variables."""
        placeholder = self._placeholder()
        return placeholder if placeholder is not None else super().__repr__()

43 

44 

@dataclass
class TemplateSyntaxError:
    """A single template syntax problem plus where it occurred.

    Instances are plain records: a human-readable ``message``, the ``line`` and
    ``column`` of the problem, a short ``snippet`` of surrounding template text,
    and an ``error_type`` tag classifying the problem.
    """

    message: str
    line: int
    column: int
    snippet: str
    # One of 'unmatched_brace', 'unmatched_conditional', 'malformed_variable'
    error_type: str

    def __str__(self) -> str:
        """Return '<type> at line L, column C: <message>' and the snippet on the next line."""
        location = f"line {self.line}, column {self.column}"
        headline = f"{self.error_type} at {location}: {self.message}"
        return headline + "\n" + f" {self.snippet}"

60 

61 

class TemplateRenderer:
    """Template renderer with configurable validation and Jinja2 support.

    This class provides:
    - Two rendering modes: "mixed" (default) and "jinja2"
    - Custom (( )) conditional syntax (backward compatible)
    - Jinja2 filters, conditionals, loops, macros, and - when a loader is
      supplied - includes/imports/inheritance
    - Validation of required parameters (ERROR/WARN/IGNORE)
    - Tracking of used and missing parameters
    - Detailed render results with metadata
    """

    # Every (( ... )) block; single '(' / ')' characters may appear inside.
    _CONDITIONAL_BLOCK_RE = re.compile(r'\(\(((?:[^()]|\((?!\()|(?<!\))\))*)\)\)')
    # A "{{ name |" filter application (used to reject filters inside (( ))).
    _FILTER_IN_BLOCK_RE = re.compile(r'\{\{\s*\w+\s*\|')
    # {{var}} or {{var|filter}} - group 1 is the variable name.
    _SIMPLE_VAR_RE = re.compile(r'\{\{\s*(\w+)(?:\s*\|[^}]*)?\s*\}\}')
    # A lone '{' or '}' that is not part of a Jinja2 delimiter
    # ('{{ }}', '{% %}', '{# #}').
    _LONE_BRACE_RE = re.compile(r'(?<!\{)\{(?![{%#])|(?<![}%#])\}(?!\})')
    # Well-formed contents of {{ ... }}: a name, optional .attribute access,
    # then an optional chain of |filter or |filter(args) applications.
    _VALID_VARIABLE_RE = re.compile(
        r'\w+(?:\.\w+)*(?:\s*\|\s*\w+\s*(?:\([^)]*\))?)*'
    )

    def __init__(
        self,
        default_validation: ValidationLevel = ValidationLevel.WARN,
        default_mode: TemplateMode = TemplateMode.MIXED,
        loader: Any = None,
    ):
        """Initialize the template renderer with Jinja2.

        Args:
            default_validation: Default validation level for templates
                without explicit validation configuration
            default_mode: Default template mode (mixed or jinja2)
            loader: Optional Jinja2 loader (e.g. ``jinja2.FileSystemLoader``).
                ``{% include %}``, ``{% import %}`` and ``{% extends %}`` need
                a loader to resolve other templates; with the default (None)
                those tags cannot find anything to load.
        """
        self._default_validation = default_validation
        self._default_mode = default_mode

        # Initialize Jinja2 environment
        self._jinja_env = Environment(
            # Keep same delimiters as our custom syntax
            variable_start_string='{{',
            variable_end_string='}}',
            block_start_string='{%',
            block_end_string='%}',
            comment_start_string='{#',
            comment_end_string='#}',
            # Prompt generation, not HTML - no autoescaping
            autoescape=False,
            # Better whitespace handling
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
            # Preserve undefined variables (backward compatibility)
            undefined=PreserveUndefined,
            loader=loader,
        )

        # Register custom filters (domain-specific)
        self._register_custom_filters()

    def render(
        self,
        template: str,
        params: Dict[str, Any],
        validation: ValidationConfig | None = None,
        template_metadata: Dict[str, Any] | None = None,
        mode: TemplateMode | None = None
    ) -> RenderResult:
        """Render a template with parameters and validation.

        Args:
            template: Template string with {{variables}} and ((conditionals))
            params: Parameters to substitute in the template
            validation: Optional validation configuration (overrides template default)
            template_metadata: Optional metadata about the template
            mode: Template mode (mixed or jinja2), defaults to renderer default

        Returns:
            RenderResult with rendered content and validation information

        Raises:
            ValueError: If Jinja2 syntax appears inside (( )) in mixed mode,
                the template has a syntax error, rendering fails, or required
                parameters are missing at ERROR validation level
        """
        effective_mode = mode if mode is not None else self._default_mode

        # Use provided validation or create default
        if validation is None:
            validation = ValidationConfig(level=self._default_validation)

        # A config without an explicit level inherits the renderer default.
        effective_level = (
            validation.level if validation.level is not None
            else self._default_validation
        )

        # Done outside the try below so this specific, actionable ValueError
        # reaches the caller as-is instead of being re-wrapped as a generic
        # "Template rendering error".
        if effective_mode == TemplateMode.MIXED:
            self._validate_no_jinja_in_conditionals(template)

        try:
            if effective_mode == TemplateMode.MIXED:
                # Step 1: resolve (( )) conditionals; Jinja2 handles the rest.
                # NOTE(review): if render_conditional_template substitutes
                # parameter values into its output, any '{{' / '{%' inside
                # those values is interpreted by Jinja2 below - confirm that
                # untrusted parameter values cannot reach this path.
                intermediate = render_conditional_template(template, params)
            else:
                # Pure Jinja2 mode - use template as-is
                intermediate = template

            # Step 2: Render with Jinja2. The mapping is passed positionally:
            # with **params, a parameter named "self" would collide with
            # Template.render's own first argument.
            jinja_template = self._jinja_env.from_string(intermediate)
            content = jinja_template.render(params)

        except Jinja2SyntaxError as e:
            # NOTE: in mixed mode e.lineno refers to the preprocessed text,
            # which may not line up exactly with the original template.
            raise ValueError(
                f"Template syntax error at line {e.lineno}: {e.message}\n"
                f"Template: {e.source or 'N/A'}"
            ) from e
        except Exception as e:
            raise ValueError(f"Template rendering error: {e}") from e

        # Step 3: Validation
        template_vars = self._find_template_variables(template)
        params_used = {k: v for k, v in params.items() if k in template_vars}

        # A required parameter counts as missing when absent or explicitly None.
        params_missing = [
            var for var in validation.required_params
            if params.get(var) is None
        ]

        validation_warnings = []
        if params_missing:
            missing_str = ", ".join(params_missing)
            if effective_level == ValidationLevel.ERROR:
                raise ValueError(
                    f"Missing required parameters: {missing_str}"
                )
            elif effective_level == ValidationLevel.WARN:
                warning_msg = f"Missing required parameters: {missing_str}"
                validation_warnings.append(warning_msg)
                logger.warning(warning_msg)
            # IGNORE level: do nothing

        # Step 4: Build result
        return RenderResult(
            content=content,
            params_used=params_used,
            params_missing=params_missing,
            validation_warnings=validation_warnings,
            metadata={
                "validation_level": effective_level.value,
                "template_vars": list(template_vars),
                "template_mode": effective_mode.value,
                **(template_metadata or {})
            }
        )

    def render_prompt_template(
        self,
        prompt_template: PromptTemplateDict,
        params: Dict[str, Any],
        validation_override: ValidationLevel | None = None,
        mode_override: TemplateMode | None = None
    ) -> RenderResult:
        """Render a PromptTemplateDict structure with validation.

        Args:
            prompt_template: PromptTemplateDict dictionary with template, defaults, validation
            params: Parameters to substitute (merged with template defaults)
            validation_override: Optional runtime validation level override
            mode_override: Optional template mode override

        Returns:
            RenderResult with rendered content and validation information
        """
        # Merge defaults with provided params (params take priority). A
        # "defaults" key explicitly set to None is treated as no defaults.
        merged_params = {
            **(prompt_template.get("defaults") or {}),
            **params
        }

        # Get template mode from template (None/absent means "mixed") or override
        template_mode_str = prompt_template.get("template_mode") or "mixed"
        template_mode = TemplateMode.from_string(template_mode_str)
        effective_mode = mode_override if mode_override is not None else template_mode

        template_validation = prompt_template.get("validation")

        if validation_override is not None:
            if template_validation is not None:
                # Keep the template's parameter lists, override only the level
                validation = ValidationConfig(
                    level=validation_override,
                    required_params=list(template_validation.required_params),
                    optional_params=list(template_validation.optional_params)
                )
            else:
                validation = ValidationConfig(level=validation_override)
        else:
            validation = template_validation

        return self.render(
            template=prompt_template["template"],
            params=merged_params,
            validation=validation,
            template_metadata=prompt_template.get("metadata"),
            mode=effective_mode
        )

    def batch_render(
        self,
        templates: List[str],
        params: Dict[str, Any],
        validation: ValidationConfig | None = None,
        mode: TemplateMode | None = None
    ) -> List[RenderResult]:
        """Render multiple templates with the same parameters.

        Args:
            templates: List of template strings
            params: Parameters to substitute in all templates
            validation: Optional validation configuration for all templates
            mode: Optional template mode for all templates (defaults to the
                renderer default, as in ``render``)

        Returns:
            List of RenderResult objects, one per template
        """
        return [
            self.render(template, params, validation, mode=mode)
            for template in templates
        ]

    @staticmethod
    def _extract_variables(template: str) -> Set[str]:
        """Extract variable names used as {{var}} or {{var|filter}}.

        This is a lightweight regex scan: dotted access, expressions and
        variables referenced only from {% %} tags are not detected.

        Args:
            template: Template string with {{variable}} or {{variable|filter}} syntax

        Returns:
            Set of variable names found in the template
        """
        return {
            match.group(1)
            for match in TemplateRenderer._SIMPLE_VAR_RE.finditer(template)
        }

    def _find_template_variables(self, template: str) -> Set[str]:
        """Collect the names of context variables the template reads.

        Combines Jinja2's AST analysis, which covers ``{{ user.name }}``,
        expressions and ``{% if %}``/``{% for %}`` tags, with the regex scan
        of ``_extract_variables``, so every name the regex finds is still
        reported. If the raw template does not parse, only the regex result
        is returned.
        """
        from jinja2 import meta  # local import: only needed for this analysis

        regex_vars = self._extract_variables(template)
        try:
            ast = self._jinja_env.parse(template)
        except Jinja2SyntaxError:
            return regex_vars
        return regex_vars | set(meta.find_undeclared_variables(ast))

    def _validate_no_jinja_in_conditionals(self, template: str):
        """Validate that no Jinja2 syntax appears inside (( )) blocks.

        Args:
            template: Template string to validate

        Raises:
            ValueError: If Jinja2 syntax found inside (( ))
        """
        for match in self._CONDITIONAL_BLOCK_RE.finditer(template):
            block_content = match.group(1)

            # Check for {% %} blocks
            if '{%' in block_content:
                raise ValueError(
                    f"Jinja2 block syntax ('{{% %}}') not allowed inside "
                    f"conditional blocks '(( ))'.\n"
                    f"Found in: ((... {block_content[:50]} ...))\n"
                    f"Hint: Move {{% %}} blocks outside (( )) or use "
                    f"template_mode='jinja2' for pure Jinja2 templates."
                )

            # Check for filters (| after {{)
            if self._FILTER_IN_BLOCK_RE.search(block_content):
                raise ValueError(
                    f"Jinja2 filters (|filter) not allowed inside "
                    f"conditional blocks '(( ))'.\n"
                    f"Found in: ((... {block_content[:50]} ...))\n"
                    f"Hint: Apply filters outside (( )) blocks or use "
                    f"template_mode='jinja2' for pure Jinja2 templates."
                )

    def _register_custom_filters(self):
        """Register custom domain-specific filters."""

        def count_tokens(text: str, model: str = "gpt-4") -> int:
            """Count approximate tokens in text (~4 chars per token).

            ``model`` is accepted for template compatibility but is
            currently not used by the approximation.
            """
            return len(text) // 4

        self._jinja_env.filters['count_tokens'] = count_tokens

        def format_code(code: str, language: str = "python") -> str:
            """Format code in markdown code block."""
            return f"```{language}\n{code}\n```"

        self._jinja_env.filters['format_code'] = format_code

        # Users can add more via add_custom_filter()

    def add_custom_filter(
        self,
        name: str,
        filter_func: Callable[..., Any]
    ):
        """Register a custom filter with Jinja2.

        Registering an existing name replaces the previous filter.

        Args:
            name: Filter name (used in templates as |name)
            filter_func: Filter function (first arg is value to filter)

        Example:
            >>> renderer.add_custom_filter(
            ...     'double',
            ...     lambda x: x * 2
            ... )
            >>> result = renderer.render("{{count|double}}", {"count": 5})
            >>> result.content
            '10'
        """
        self._jinja_env.filters[name] = filter_func

    @staticmethod
    def _get_line_col(template: str, position: int) -> Tuple[int, int]:
        """Get line and column number for a position in the template.

        Args:
            template: Template string
            position: Character position in template

        Returns:
            Tuple of (line_number, column_number) (1-indexed)
        """
        lines = template[:position].split('\n')
        line = len(lines)
        column = len(lines[-1]) + 1
        return line, column

    @staticmethod
    def _get_snippet(template: str, position: int, context: int = 20) -> str:
        """Get a snippet of text around a position.

        Newlines are shown as a literal ``\\n``. The marker is inserted at
        ``position`` in the raw text *before* newlines are escaped, so it stays
        at the right spot even when the snippet contains line breaks.

        Args:
            template: Template string
            position: Character position
            context: Number of characters to show before/after

        Returns:
            Snippet with error position marked
        """
        start = max(0, position - context)
        end = min(len(template), position + context)

        before = template[start:position].replace('\n', '\\n')
        after = template[position:end].replace('\n', '\\n')

        # No marker when the position is at/after the end of the snippet.
        if position < end:
            return before + '⮜HERE⮞' + after
        return before + after

    @staticmethod
    def validate_template_syntax_detailed(template: str) -> List[TemplateSyntaxError]:
        """Validate template syntax and return detailed errors with locations.

        Checks performed, in this order:
        1. Lone '{' / '}' characters. Jinja2 delimiters ({{ }}, {% %},
           {# #}) are not reported.
        2. Unbalanced (( / )) conditional markers.
        3. {{ ... }} expressions that are empty or not of the form
           ``name``, ``name.attr`` or ``name|filter`` / ``name|filter(args)``.

        Args:
            template: Template string to validate

        Returns:
            List of TemplateSyntaxError objects (empty if valid)
        """
        errors: List[TemplateSyntaxError] = []

        def add_error(position: int, message: str, error_type: str,
                      context: int = 20) -> None:
            # Record one error with its 1-indexed location and a snippet.
            line, col = TemplateRenderer._get_line_col(template, position)
            errors.append(TemplateSyntaxError(
                message=message,
                line=line,
                column=col,
                snippet=TemplateRenderer._get_snippet(template, position, context),
                error_type=error_type
            ))

        # 1. Lone braces
        for match in TemplateRenderer._LONE_BRACE_RE.finditer(template):
            add_error(
                match.start(),
                "Unmatched brace. Use {{ }} for variables, not { }.",
                "unmatched_brace"
            )

        # 2. Conditional sections: stack-based matching of (( and ))
        events = sorted(
            [(m.start(), 'open') for m in re.finditer(r'\(\(', template)] +
            [(m.start(), 'close') for m in re.finditer(r'\)\)', template)]
        )
        stack: List[int] = []
        for position, bracket_type in events:
            if bracket_type == 'open':
                stack.append(position)
            elif stack:
                stack.pop()
            else:
                add_error(
                    position,
                    "Closing '))' without matching opening '(('.",
                    "unmatched_conditional"
                )

        for position in stack:
            add_error(
                position,
                "Opening '((' without matching closing '))'.",
                "unmatched_conditional"
            )

        # 3. Variable expressions
        for match in re.finditer(r'\{\{[^}]*\}\}', template):
            var_content = match.group(0)[2:-2].strip()  # Remove {{ }}

            if not var_content:
                add_error(
                    match.start(),
                    "Empty variable {{}}. Variables must have a name.",
                    "malformed_variable"
                )
            elif not TemplateRenderer._VALID_VARIABLE_RE.fullmatch(var_content):
                add_error(
                    match.start(),
                    f"Malformed variable '{{{{ {var_content} }}}}'. "
                    "Variables should be a name (letters, numbers, underscores), "
                    "optionally with .attribute access and |filters.",
                    "malformed_variable",
                    context=30
                )

        return errors

    @staticmethod
    def validate_template_syntax(template: str) -> List[str]:
        """Validate template syntax and return error messages.

        This is a convenience wrapper around validate_template_syntax_detailed()
        that returns simple string messages instead of detailed error objects.

        Args:
            template: Template string to validate

        Returns:
            List of error messages (empty if valid)
        """
        detailed_errors = TemplateRenderer.validate_template_syntax_detailed(template)
        return [str(error) for error in detailed_errors]

542 

543 

544# Convenience functions for one-off rendering 

545 

def render_template(
    template: str,
    params: Dict[str, Any],
    validation_level: ValidationLevel = ValidationLevel.WARN
) -> str:
    """Render ``template`` with ``params`` in a single call and return the text.

    A throwaway TemplateRenderer is built with ``validation_level`` as its
    default validation level, and the template is rendered without a
    per-call validation config, so that default is what applies.

    Args:
        template: Template string with {{variables}} and ((conditionals))
        params: Parameters to substitute
        validation_level: Validation level to use (default: WARN)

    Returns:
        Rendered template string (the ``content`` of the render result)

    Example:
        >>> render_template(
        ...     "Hello {{name}}((, you are {{age}} years old))",
        ...     {"name": "Alice", "age": 30}
        ... )
        'Hello Alice, you are 30 years old'
    """
    one_off = TemplateRenderer(default_validation=validation_level)
    return one_off.render(template, params).content

572 

573 

def render_template_strict(
    template: str,
    params: Dict[str, Any],
    required_params: List[str]
) -> str:
    """Render a template, failing hard if any required parameter is missing.

    Validation runs at ERROR level with ``required_params`` as the required
    set, so a missing (or None-valued) required parameter raises instead of
    producing a warning.

    Args:
        template: Template string
        params: Parameters to substitute
        required_params: List of required parameter names

    Returns:
        Rendered template string

    Raises:
        ValueError: If any required parameters are missing
    """
    strict_config = ValidationConfig(
        level=ValidationLevel.ERROR,
        required_params=required_params,
    )
    rendered = TemplateRenderer().render(template, params, strict_config)
    return rendered.content

598 return result.content