Coverage for src/dataknobs_llm/llm/utils.py: 97%

207 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""Utility functions for LLM operations. 

2 

3This module provides utility functions for working with LLMs. 

4Template rendering utilities have been moved to dataknobs_llm.template_utils 

5to avoid circular dependencies. 

6""" 

7 

8import re 

9import json 

10from typing import Any, Dict, List, Union 

11from dataclasses import dataclass, field 

12 

13from .base import LLMMessage, LLMResponse 

14from ..template_utils import TemplateStrategy, render_conditional_template 

15 

16 

# Tokenizer for str.format()-style SIMPLE templates. "{{" and "}}" are escaped
# literal braces and are listed first so they are consumed (and skipped) before
# the field alternative; only a single-brace "{name}" captures group 1.
_SIMPLE_FIELD_RE = re.compile(r'\{\{|\}\}|\{(\w+)\}')


@dataclass
class MessageTemplate:
    """Template for generating message content with multiple rendering strategies.

    Supports two template strategies:
    1. SIMPLE (default): Uses Python str.format() with {variable} syntax.
       - All variables must be provided
       - Clean and straightforward
       - Use {{ and }} for literal braces (they are not variables)
       - Example: "Hello {name}!"

    2. CONDITIONAL: Advanced conditional rendering with {{variable}} and ((conditional)) syntax.
       - Variables can be optional
       - Conditional sections with (( ... ))
       - Whitespace-aware substitution
       - Example: "Hello {{name}}((, you have {{count}} messages))"

    Attributes:
        template: Raw template text.
        variables: Variable names used by the template. When empty, they are
            extracted from ``template`` according to ``strategy``.
        strategy: Rendering strategy used by ``format`` and ``partial``.
    """
    template: str
    variables: List[str] = field(default_factory=list)
    strategy: TemplateStrategy = TemplateStrategy.SIMPLE

    def __post_init__(self):
        """Populate ``variables`` from the template if none were given, then de-duplicate.

        De-duplication keeps the first occurrence of each name and also applies
        to an explicitly supplied list (which is stored as a new list object).
        """
        if not self.variables:
            if self.strategy == TemplateStrategy.SIMPLE:
                # Single-brace {name} fields; escaped {{...}} braces are skipped,
                # matching how str.format() itself reads the template.
                self.variables = [
                    m.group(1)
                    for m in _SIMPLE_FIELD_RE.finditer(self.template)
                    if m.group(1)
                ]
            elif self.strategy == TemplateStrategy.CONDITIONAL:
                # {{name}} fields (inner whitespace allowed); group 2 is the name.
                self.variables = [
                    m.group(2)
                    for m in re.finditer(r'\{\{(\s*)(\w+)(\s*)\}\}', self.template)
                ]
        # dict preserves insertion order, so this is an order-stable de-dupe.
        self.variables = list(dict.fromkeys(self.variables))

    def format(self, **kwargs: Any) -> str:
        """Format template with variables using the selected strategy.

        Args:
            **kwargs: Variable values

        Returns:
            Formatted prompt

        Raises:
            ValueError: If using SIMPLE strategy and required variables are
                missing, or if the strategy is unknown
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            # Simple strategy: all variables must be provided
            missing = set(self.variables).difference(kwargs)
            if missing:
                raise ValueError(f"Missing variables: {missing}")
            return self.template.format(**kwargs)

        if self.strategy == TemplateStrategy.CONDITIONAL:
            # Conditional strategy: optional variables / sections are handled
            # by render_conditional_template.
            return render_conditional_template(self.template, kwargs)

        raise ValueError(f"Unknown template strategy: {self.strategy}")

    def partial(self, **kwargs: Any) -> 'MessageTemplate':
        """Create partial template with some variables filled.

        Only names listed in ``self.variables`` are substituted; other keyword
        arguments are ignored. Filled values become literal text in the new
        template and their names are removed from its ``variables``.

        Args:
            **kwargs: Variable values to fill

        Returns:
            New template (same strategy) with partial values

        Raises:
            ValueError: If the strategy is unknown
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            # Double any braces in the values so the later str.format() call
            # emits them literally instead of parsing them as fields.
            fills = {
                key: str(value).replace('{', '{{').replace('}', '}}')
                for key, value in kwargs.items()
                if key in self.variables
            }

            def _substitute(match: re.Match) -> str:
                # Escaped "{{"/"}}" tokens have no group 1 and are kept as-is.
                name = match.group(1)
                return fills[name] if name in fills else match.group(0)

            new_template = _SIMPLE_FIELD_RE.sub(_substitute, self.template)
            new_variables = [v for v in self.variables if v not in fills]
            return MessageTemplate(new_template, new_variables, self.strategy)

        if self.strategy == TemplateStrategy.CONDITIONAL:
            new_template = self.template
            new_variables = self.variables.copy()

            for key, value in kwargs.items():
                if key in new_variables:
                    # A callable replacement is inserted verbatim; a plain
                    # string would have backslashes and \g<...> group
                    # references interpreted by re.sub.
                    new_template = re.sub(
                        r'\{\{\s*' + re.escape(key) + r'\s*\}\}',
                        lambda _match, text=str(value): text,
                        new_template,
                    )
                    new_variables.remove(key)

            return MessageTemplate(new_template, new_variables, self.strategy)

        raise ValueError(f"Unknown template strategy: {self.strategy}")

    @classmethod
    def from_conditional(cls, template: str, variables: List[str] | None = None) -> 'MessageTemplate':
        """Create a MessageTemplate using the CONDITIONAL strategy.

        Convenience method for creating templates with advanced conditional rendering.

        Args:
            template: Template string with {{variable}} and ((conditional)) syntax
            variables: Optional explicit list of variables

        Returns:
            MessageTemplate configured with CONDITIONAL strategy

        Example:
            >>> template = MessageTemplate.from_conditional(
            ...     "Hello {{name}}((, you have {{count}} messages))"
            ... )
            >>> template.format(name="Alice", count=5)
            'Hello Alice, you have 5 messages'
            >>> template.format(name="Bob")
            'Hello Bob'
        """
        return cls(template=template, variables=variables or [], strategy=TemplateStrategy.CONDITIONAL)

151 

152 

class MessageBuilder:
    """Fluent builder that accumulates ``LLMMessage`` objects in call order.

    Every adder returns the builder itself so calls can be chained; ``build()``
    hands back a shallow copy of the accumulated list.
    """

    def __init__(self):
        self.messages = []

    def _add(self, message: LLMMessage) -> 'MessageBuilder':
        """Append one message and return the builder for chaining."""
        self.messages.append(message)
        return self

    def system(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'system'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._add(LLMMessage(role='system', content=content))

    def user(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'user'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._add(LLMMessage(role='user', content=content))

    def assistant(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'assistant'``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._add(LLMMessage(role='assistant', content=content))

    def function(
        self,
        name: str,
        content: str,
        function_call: Dict[str, Any] | None = None
    ) -> 'MessageBuilder':
        """Append a message with role ``'function'``.

        Args:
            name: Function name
            content: Function result
            function_call: Function call details (passed through unchanged)

        Returns:
            Self for chaining
        """
        message = LLMMessage(
            role='function',
            name=name,
            content=content,
            function_call=function_call,
        )
        return self._add(message)

    def from_template(
        self,
        role: str,
        template: MessageTemplate,
        **kwargs: Any
    ) -> 'MessageBuilder':
        """Render ``template`` with ``kwargs`` and append it under ``role``.

        Args:
            role: Message role
            template: Message template
            **kwargs: Template variables (forwarded to ``template.format``)

        Returns:
            Self for chaining
        """
        rendered = template.format(**kwargs)
        return self._add(LLMMessage(role=role, content=rendered))

    def build(self) -> List[LLMMessage]:
        """Return a shallow copy of the messages collected so far.

        Returns:
            List of messages
        """
        return list(self.messages)

    def clear(self) -> 'MessageBuilder':
        """Remove every collected message, keeping the same list object.

        Returns:
            Self for chaining
        """
        del self.messages[:]
        return self

255 

256 

class ResponseParser:
    """Helpers for pulling structured data out of free-form LLM responses.

    Every public method accepts either a plain string or an ``LLMResponse``,
    in which case its ``content`` is parsed.
    """

    @staticmethod
    def _first_json_value(text: str, opener: str) -> Any:
        """Decode the first JSON value that starts at an ``opener`` character.

        Each occurrence of ``opener`` (``'{'`` or ``'['``) is tried left to
        right with ``json.JSONDecoder.raw_decode``, which parses exactly one
        value and ignores whatever text follows it. Using the real parser
        instead of a bracket regex handles nested structures and brackets
        inside string values. Because the earliest start position wins, an
        enclosing object is preferred over the objects nested inside it.

        Returns:
            The decoded value, or None if no occurrence decodes.
        """
        decoder = json.JSONDecoder()
        start = text.find(opener)
        while start != -1:
            try:
                value, _ = decoder.raw_decode(text, start)
                return value
            except json.JSONDecodeError:
                start = text.find(opener, start + 1)
        return None

    @staticmethod
    def extract_json(response: Union[str, LLMResponse]) -> Dict[str, Any] | None:
        """Extract JSON from response.

        Search order:
        1. the first decodable JSON object (``{...}``) anywhere in the text;
        2. the first decodable JSON array (``[...]``) anywhere in the text;
        3. the contents of ```json fenced blocks, then of any fenced block;
        4. the entire text.

        Args:
            response: LLM response

        Returns:
            Extracted JSON (normally a dict; a list or scalar when that is the
            only JSON found) or None
        """
        text = response.content if isinstance(response, LLMResponse) else response

        for opener in ('{', '['):
            found = ResponseParser._first_json_value(text, opener)
            if found is not None:
                return found

        code_block_patterns = [
            r'```json\s*(.*?)\s*```',  # Markdown code block
            r'```\s*(.*?)\s*```',  # Generic code block
        ]
        for pattern in code_block_patterns:
            for block in re.findall(pattern, text, re.DOTALL):
                try:
                    return json.loads(block)
                except json.JSONDecodeError:
                    continue

        # Try parsing the entire text as JSON
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def extract_code(
        response: Union[str, LLMResponse],
        language: str | None = None
    ) -> List[str]:
        """Extract fenced (```) code blocks from response.

        Args:
            response: LLM response
            language: Optional language filter. Only fences whose info string
                is exactly this tag match (``"py"`` does not match
                ```` ```python ````).

        Returns:
            List of code blocks, each stripped of surrounding whitespace
        """
        text = response.content if isinstance(response, LLMResponse) else response

        if language:
            # re.escape: tags such as "c++" contain regex metacharacters.
            # The lookahead rejects a longer tag that merely starts with
            # `language` (e.g. "py" vs "python", "c" vs "c++").
            pattern = rf'```{re.escape(language)}(?![\w+#.-])\s*(.*?)\s*```'
        else:
            # All code blocks
            pattern = r'```(?:\w+)?\s*(.*?)\s*```'

        matches = re.findall(pattern, text, re.DOTALL)
        return [m.strip() for m in matches]

    @staticmethod
    def extract_list(
        response: Union[str, LLMResponse],
        numbered: bool = False
    ) -> List[str]:
        """Extract list items from response.

        Only items that start at the beginning of a line are found.

        Args:
            response: LLM response
            numbered: If True, match numbered items (``1. item``); otherwise
                match bullet items starting with ``-``, ``*`` or ``•``

        Returns:
            List of items, each stripped of surrounding whitespace
        """
        text = response.content if isinstance(response, LLMResponse) else response

        if numbered:
            # Numbered list (1. item, 2. item, etc.)
            pattern = r'^\d+\.\s+(.+)$'
        else:
            # Bullet points (-, *, •)
            pattern = r'^[-*•]\s+(.+)$'

        matches = re.findall(pattern, text, re.MULTILINE)
        return [m.strip() for m in matches]

    @staticmethod
    def extract_sections(
        response: Union[str, LLMResponse]
    ) -> Dict[str, str]:
        """Split a response into sections keyed by markdown header text.

        Text before the first header is stored under ``'main'``. Section
        content is stripped. A section is recorded only if at least one line
        followed its header, and a repeated header name overwrites the earlier
        section. Lines inside ``` fenced code blocks are never treated as
        headers, so comment lines such as ``# setup`` in shell or Python
        snippets stay in their section.

        Args:
            response: LLM response

        Returns:
            Dictionary of section name to content
        """
        text = response.content if isinstance(response, LLMResponse) else response

        sections: Dict[str, str] = {}
        current_section = 'main'
        current_content: List[str] = []
        in_fence = False

        for line in text.split('\n'):
            if line.lstrip().startswith('```'):
                in_fence = not in_fence
            header_match = None if in_fence else re.match(r'^#+\s+(.+)$', line)
            if header_match:
                # Save previous section
                if current_content:
                    sections[current_section] = '\n'.join(current_content).strip()
                # Start new section
                current_section = header_match.group(1).strip()
                current_content = []
            else:
                current_content.append(line)

        # Save last section
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections

382 

383 

class TokenCounter:
    """Rough, character-count-based token estimates for common model families."""

    # Approximate tokens per character for different models
    TOKENS_PER_CHAR = {
        'gpt-4': 0.25,
        'gpt-3.5': 0.25,
        'claude': 0.25,
        'llama': 0.3,
        'default': 0.25
    }

    @classmethod
    def estimate_tokens(
        cls,
        text: str,
        model: str = 'default'
    ) -> int:
        """Estimate how many tokens ``text`` would use for ``model``.

        The ratio comes from the first ``TOKENS_PER_CHAR`` key (in table order)
        that occurs in the lower-cased model name; if none occurs, the
        ``'default'`` ratio is used. The result is truncated to an int.

        Args:
            text: Input text
            model: Model name

        Returns:
            Estimated token count
        """
        lowered = model.lower()
        ratio = next(
            (value for key, value in cls.TOKENS_PER_CHAR.items() if key in lowered),
            cls.TOKENS_PER_CHAR['default'],
        )
        return int(len(text) * ratio)

    @classmethod
    def estimate_messages_tokens(
        cls,
        messages: List[LLMMessage],
        model: str = 'default'
    ) -> int:
        """Estimate the total token count for a list of messages.

        Each message contributes a flat 4 tokens for its role, plus the
        estimate for its content and, when set, for its name.

        Args:
            messages: List of messages
            model: Model name

        Returns:
            Estimated token count
        """
        return sum(
            4
            + cls.estimate_tokens(msg.content, model)
            + (cls.estimate_tokens(msg.name, model) if msg.name else 0)
            for msg in messages
        )

    @classmethod
    def fits_in_context(
        cls,
        text: str,
        model: str,
        max_tokens: int
    ) -> bool:
        """Report whether the estimated token count of ``text`` is within ``max_tokens``.

        Args:
            text: Input text
            model: Model name
            max_tokens: Maximum token limit (inclusive)

        Returns:
            True if fits
        """
        return cls.estimate_tokens(text, model) <= max_tokens

467 

468 

class CostCalculator:
    """Calculate approximate USD costs for LLM usage from a static price table."""

    # Cost per 1K tokens (in USD)
    PRICING = {
        'gpt-4': {'input': 0.03, 'output': 0.06},
        'gpt-4-32k': {'input': 0.06, 'output': 0.12},
        'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002},
        'claude-3-opus': {'input': 0.015, 'output': 0.075},
        'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
        'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
    }

    @classmethod
    def _find_pricing(cls, model: str | None) -> Dict[str, float] | None:
        """Return the price entry whose key is the longest substring of ``model``.

        Longest match matters because keys overlap: "gpt-4-32k" also contains
        "gpt-4", and a first-match scan would bill it at the cheaper gpt-4
        rate. Matching is case-insensitive.

        Returns:
            The ``{'input': ..., 'output': ...}`` entry, or None if ``model`` is
            empty/None or no key matches.
        """
        if not model:
            return None
        lowered = model.lower()
        best = max((key for key in cls.PRICING if key in lowered), key=len, default=None)
        return cls.PRICING[best] if best is not None else None

    @classmethod
    def calculate_cost(
        cls,
        response: LLMResponse,
        model: str | None = None
    ) -> float | None:
        """Calculate cost for LLM response.

        Uses ``response.usage['prompt_tokens']`` and
        ``response.usage['completion_tokens']`` (each defaulting to 0).

        Args:
            response: LLM response with usage info
            model: Model name (if not given, ``response.model`` is used)

        Returns:
            Cost in USD or None if cannot calculate (no usage, no model name,
            or no pricing entry for the model)
        """
        if not response.usage:
            return None

        pricing = cls._find_pricing(model or response.model)
        if pricing is None:
            return None

        # Prices are per 1K tokens
        input_cost = (response.usage.get('prompt_tokens', 0) / 1000) * pricing['input']
        output_cost = (response.usage.get('completion_tokens', 0) / 1000) * pricing['output']

        return input_cost + output_cost

    @classmethod
    def estimate_cost(
        cls,
        text: str,
        model: str,
        expected_output_tokens: int = 100
    ) -> float | None:
        """Estimate cost for text completion.

        Input tokens are estimated with ``TokenCounter.estimate_tokens``.

        Args:
            text: Input text
            model: Model name
            expected_output_tokens: Expected output length in tokens

        Returns:
            Estimated cost in USD, or None if the model has no pricing entry
        """
        pricing = cls._find_pricing(model)
        if pricing is None:
            return None

        input_tokens = TokenCounter.estimate_tokens(text, model)

        # Prices are per 1K tokens
        input_cost = (input_tokens / 1000) * pricing['input']
        output_cost = (expected_output_tokens / 1000) * pricing['output']

        return input_cost + output_cost

553 

554 

def chain_prompts(
    *templates: MessageTemplate
) -> MessageTemplate:
    """Concatenate several message templates into a single template.

    Template texts are joined with a blank line between them. Variable names
    are merged in order of first appearance, without duplicates. Every
    template must share one strategy, which the result inherits.

    Args:
        *templates: Templates to chain

    Returns:
        Combined template (an empty SIMPLE-default template when none given)

    Raises:
        ValueError: If templates use different strategies
    """
    if not templates:
        return MessageTemplate("", [])

    strategy = templates[0].strategy
    if any(t.strategy != strategy for t in templates):
        raise ValueError(
            "Cannot chain templates with different strategies. "
            "All templates must use the same TemplateStrategy."
        )

    joined_text = '\n\n'.join(t.template for t in templates)
    # dict.fromkeys keeps first-seen order while dropping repeats
    merged_variables = list(dict.fromkeys(
        name for t in templates for name in t.variables
    ))

    return MessageTemplate(joined_text, merged_variables, strategy)

594 

595 

def create_few_shot_prompt(
    instruction: str,
    examples: List[Dict[str, str]],
    query_key: str = 'input',
    response_key: str = 'output'
) -> MessageTemplate:
    """Create few-shot learning prompt.

    The result is a SIMPLE-strategy template whose only variable is ``query``.
    Example inputs and outputs are embedded as literal text. Any ``{``/``}``
    they contain (e.g. JSON) is doubled, so ``format(query=...)`` reproduces
    them verbatim instead of treating them as format fields. ``instruction``
    is inserted unchanged, so literal braces in it must already be doubled.

    Args:
        instruction: Task instruction
        examples: List of example input/output pairs
        query_key: Key for input in examples
        response_key: Key for output in examples

    Returns:
        Few-shot prompt template

    Raises:
        KeyError: If an example lacks ``query_key`` or ``response_key``
    """
    def _literal(value: Any) -> str:
        # Escape str.format braces so example text is not parsed as fields.
        return str(value).replace('{', '{{').replace('}', '}}')

    template_parts = [instruction, '']

    # Add examples
    for i, example in enumerate(examples, 1):
        template_parts.extend([
            f"Example {i}:",
            f"Input: {_literal(example[query_key])}",
            f"Output: {_literal(example[response_key])}",
            '',
        ])

    # Add query placeholder
    template_parts.extend([
        "Now, process this input:",
        "Input: {query}",
        "Output:",
    ])

    return MessageTemplate('\n'.join(template_parts), ['query'])