Coverage for src / dataknobs_llm / llm / utils.py: 28%

207 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:29 -0700

1"""Utility functions for LLM operations. 

2 

3This module provides utility functions for working with LLMs. 

4Template rendering utilities have been moved to dataknobs_llm.template_utils 

5to avoid circular dependencies. 

6""" 

7 

8import re 

9import json 

10from typing import Any, Dict, List, Union 

11from dataclasses import dataclass, field 

12 

13from .base import LLMMessage, LLMResponse 

14from ..template_utils import TemplateStrategy, render_conditional_template 

15 

16 

@dataclass
class MessageTemplate:
    """Template for generating message content with multiple rendering strategies.

    Supports two template strategies:

    1. SIMPLE (default): Uses Python ``str.format()`` with ``{variable}`` syntax.
       - All variables must be provided
       - Clean and straightforward
       - Example: ``"Hello {name}!"``

    2. CONDITIONAL: Advanced conditional rendering with ``{{variable}}`` and
       ``((conditional))`` syntax, delegated to ``render_conditional_template``.
       - Variables can be optional
       - Conditional sections with ``(( ... ))``
       - Whitespace-aware substitution
       - Example: ``"Hello {{name}}((, you have {{count}} messages))"``
    """

    # Raw template text, written in the syntax of ``strategy``.
    template: str
    # Placeholder names; extracted from ``template`` when left empty.
    variables: List[str] = field(default_factory=list)
    # Strategy used by ``format`` and ``partial``.
    strategy: TemplateStrategy = TemplateStrategy.SIMPLE

    def __post_init__(self):
        """Extract variables from the template (if none were given) and de-duplicate them."""
        if not self.variables:
            if self.strategy == TemplateStrategy.SIMPLE:
                # Single-brace placeholders: {variable}
                self.variables = re.findall(r'\{(\w+)\}', self.template)
            elif self.strategy == TemplateStrategy.CONDITIONAL:
                # Double-brace placeholders: {{ variable }} (group 2 is the name)
                self.variables = [
                    match.group(2)
                    for match in re.finditer(r'\{\{(\s*)(\w+)(\s*)\}\}', self.template)
                ]
        # A name used several times must be listed once, otherwise ``partial``
        # (which removes a filled name a single time) would leave a stale entry
        # behind. ``dict.fromkeys`` keeps first-seen order.
        self.variables = list(dict.fromkeys(self.variables or []))

    def format(self, **kwargs: Any) -> str:
        """Format template with variables using the selected strategy.

        Args:
            **kwargs: Variable values

        Returns:
            Formatted prompt

        Raises:
            ValueError: If using SIMPLE strategy and required variables are
                missing, or if the strategy is not recognised
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            # Simple strategy: every declared variable must be provided
            missing = set(self.variables) - set(kwargs.keys())
            if missing:
                raise ValueError(f"Missing variables: {missing}")
            return self.template.format(**kwargs)

        if self.strategy == TemplateStrategy.CONDITIONAL:
            # Conditional strategy: rendering is delegated to the shared helper
            return render_conditional_template(self.template, kwargs)

        raise ValueError(f"Unknown template strategy: {self.strategy}")

    def partial(self, **kwargs: Any) -> 'MessageTemplate':
        """Create partial template with some variables filled.

        Only names present in ``self.variables`` are substituted; other
        keyword arguments are ignored. Values are inserted as ``str(value)``.

        Args:
            **kwargs: Variable values to fill

        Returns:
            New template with the given values substituted and the remaining
            variables still open

        Raises:
            ValueError: If the strategy is not recognised
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            def substitute(text: str, key: str, literal: str) -> str:
                # NOTE(review): the value is inserted verbatim, so braces inside
                # it will be interpreted by a later ``format()`` call.
                return text.replace('{' + key + '}', literal)
        elif self.strategy == TemplateStrategy.CONDITIONAL:
            def substitute(text: str, key: str, literal: str) -> str:
                pattern = r'\{\{\s*' + re.escape(key) + r'\s*\}\}'
                # A callable replacement inserts the value literally; a plain
                # string would have its backslashes / group references
                # (e.g. "C:\new", "\1") processed by ``re.sub``.
                return re.sub(pattern, lambda _match: literal, text)
        else:
            raise ValueError(f"Unknown template strategy: {self.strategy}")

        new_template = self.template
        remaining = list(self.variables)
        for key, value in kwargs.items():
            if key in remaining:
                new_template = substitute(new_template, key, str(value))
                remaining.remove(key)

        return MessageTemplate(new_template, remaining, self.strategy)

    @classmethod
    def from_conditional(cls, template: str, variables: List[str] | None = None) -> 'MessageTemplate':
        """Create a MessageTemplate using the CONDITIONAL strategy.

        Convenience method for creating templates with advanced conditional rendering.

        Args:
            template: Template string with {{variable}} and ((conditional)) syntax
            variables: Optional explicit list of variables

        Returns:
            MessageTemplate configured with CONDITIONAL strategy

        Example:
            ```python
            template = MessageTemplate.from_conditional(
                "Hello {{name}}((, you have {{count}} messages))"
            )
            template.format(name="Alice", count=5)
            # "Hello Alice, you have 5 messages"
            template.format(name="Bob")
            # "Hello Bob"
            ```
        """
        return cls(
            template=template,
            variables=variables or [],
            strategy=TemplateStrategy.CONDITIONAL,
        )

153 

154 

class MessageBuilder:
    """Fluent builder that collects ``LLMMessage`` objects in insertion order.

    Every ``add``-style method returns the builder itself so calls can be
    chained; ``build()`` hands back a copy of the collected list.
    """

    def __init__(self):
        self.messages: List[LLMMessage] = []

    def _add(self, **fields: Any) -> 'MessageBuilder':
        """Append one ``LLMMessage`` built from ``fields`` and return the builder."""
        self.messages.append(LLMMessage(**fields))
        return self

    def system(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``system``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._add(role='system', content=content)

    def user(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``user``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._add(role='user', content=content)

    def assistant(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``assistant``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        return self._add(role='assistant', content=content)

    def function(
        self,
        name: str,
        content: str,
        function_call: Dict[str, Any] | None = None
    ) -> 'MessageBuilder':
        """Append a message with role ``function``.

        Args:
            name: Function name
            content: Function result
            function_call: Function call details

        Returns:
            Self for chaining
        """
        return self._add(
            role='function',
            name=name,
            content=content,
            function_call=function_call,
        )

    def from_template(
        self,
        role: str,
        template: MessageTemplate,
        **kwargs: Any
    ) -> 'MessageBuilder':
        """Render ``template`` with ``kwargs`` and append the result under ``role``.

        Args:
            role: Message role
            template: Message template
            **kwargs: Template variables

        Returns:
            Self for chaining
        """
        rendered = template.format(**kwargs)
        return self._add(role=role, content=rendered)

    def build(self) -> List[LLMMessage]:
        """Return the collected messages.

        Returns:
            A shallow copy of the message list, so later builder calls do not
            alter the returned list
        """
        return list(self.messages)

    def clear(self) -> 'MessageBuilder':
        """Remove every collected message (the list is emptied in place).

        Returns:
            Self for chaining
        """
        del self.messages[:]
        return self

257 

258 

259class ResponseParser: 

260 """Parser for LLM responses.""" 

261 

262 @staticmethod 

263 def extract_json(response: Union[str, LLMResponse]) -> Dict[str, Any] | None: 

264 """Extract JSON from response. 

265  

266 Args: 

267 response: LLM response 

268  

269 Returns: 

270 Extracted JSON or None 

271 """ 

272 text = response.content if isinstance(response, LLMResponse) else response 

273 

274 # Try to find JSON in the text 

275 json_patterns = [ 

276 r'\{[^}]*\}', # Simple object 

277 r'\[[^\]]*\]', # Array 

278 r'```json\s*(.*?)\s*```', # Markdown code block 

279 r'```\s*(.*?)\s*```', # Generic code block 

280 ] 

281 

282 for pattern in json_patterns: 

283 matches = re.findall(pattern, text, re.DOTALL) 

284 for match in matches: 

285 try: 

286 return json.loads(match) 

287 except json.JSONDecodeError: 

288 continue 

289 

290 # Try parsing the entire text as JSON 

291 try: 

292 return json.loads(text) 

293 except json.JSONDecodeError: 

294 return None 

295 

296 @staticmethod 

297 def extract_code( 

298 response: Union[str, LLMResponse], 

299 language: str | None = None 

300 ) -> List[str]: 

301 """Extract code blocks from response. 

302  

303 Args: 

304 response: LLM response 

305 language: Optional language filter 

306  

307 Returns: 

308 List of code blocks 

309 """ 

310 text = response.content if isinstance(response, LLMResponse) else response 

311 

312 if language: 

313 # Language-specific code blocks 

314 pattern = rf'```{language}\s*(.*?)\s*```' 

315 else: 

316 # All code blocks 

317 pattern = r'```(?:\w+)?\s*(.*?)\s*```' 

318 

319 matches = re.findall(pattern, text, re.DOTALL) 

320 return [m.strip() for m in matches] 

321 

322 @staticmethod 

323 def extract_list( 

324 response: Union[str, LLMResponse], 

325 numbered: bool = False 

326 ) -> List[str]: 

327 """Extract list items from response. 

328  

329 Args: 

330 response: LLM response 

331 numbered: Whether to look for numbered lists 

332  

333 Returns: 

334 List of items 

335 """ 

336 text = response.content if isinstance(response, LLMResponse) else response 

337 

338 if numbered: 

339 # Numbered list (1. item, 2. item, etc.) 

340 pattern = r'^\d+\.\s+(.+)$' 

341 else: 

342 # Bullet points (-, *, •) 

343 pattern = r'^[-*•]\s+(.+)$' 

344 

345 matches = re.findall(pattern, text, re.MULTILINE) 

346 return [m.strip() for m in matches] 

347 

348 @staticmethod 

349 def extract_sections( 

350 response: Union[str, LLMResponse] 

351 ) -> Dict[str, str]: 

352 """Extract sections from response. 

353  

354 Args: 

355 response: LLM response 

356  

357 Returns: 

358 Dictionary of section name to content 

359 """ 

360 text = response.content if isinstance(response, LLMResponse) else response 

361 

362 # Split by headers (# Header, ## Header, etc.) 

363 sections = {} 

364 current_section = 'main' 

365 current_content = [] 

366 

367 for line in text.split('\n'): 

368 header_match = re.match(r'^#+\s+(.+)$', line) 

369 if header_match: 

370 # Save previous section 

371 if current_content: 

372 sections[current_section] = '\n'.join(current_content).strip() 

373 # Start new section 

374 current_section = header_match.group(1).strip() 

375 current_content = [] 

376 else: 

377 current_content.append(line) 

378 

379 # Save last section 

380 if current_content: 

381 sections[current_section] = '\n'.join(current_content).strip() 

382 

383 return sections 

384 

385 

class TokenCounter:
    """Estimate token counts for different models.

    The estimate is a character-count heuristic (characters multiplied by a
    per-model ratio); no tokenizer is involved.
    """

    # Approximate tokens per character, keyed by a substring of the model name
    TOKENS_PER_CHAR = {
        'gpt-4': 0.25,
        'gpt-3.5': 0.25,
        'claude': 0.25,
        'llama': 0.3,
        'default': 0.25
    }

    @classmethod
    def estimate_tokens(
        cls,
        text: str,
        model: str = 'default'
    ) -> int:
        """Estimate token count for text.

        Args:
            text: Input text; ``None`` or an empty string counts as 0 tokens
            model: Model name, matched case-insensitively by substring against
                the keys of ``TOKENS_PER_CHAR``; ``None`` or an unknown name
                uses the ``'default'`` ratio

        Returns:
            Estimated token count (character count times ratio, truncated)
        """
        if not text:
            return 0

        # Lower-case once instead of on every key comparison
        model_name = (model or 'default').lower()
        ratio = next(
            (
                per_char
                for pattern, per_char in cls.TOKENS_PER_CHAR.items()
                if pattern in model_name
            ),
            cls.TOKENS_PER_CHAR['default'],
        )
        return int(len(text) * ratio)

    @classmethod
    def estimate_messages_tokens(
        cls,
        messages: 'List[LLMMessage]',
        model: str = 'default'
    ) -> int:
        """Estimate token count for messages.

        Each message contributes a flat 4 tokens for its role, plus the
        estimate for its content and, when set, for its name.

        Args:
            messages: List of messages
            model: Model name

        Returns:
            Estimated token count
        """
        total = 0
        for msg in messages:
            # Role overhead (approximately 4 tokens) plus content
            total += 4 + cls.estimate_tokens(msg.content, model)
            if msg.name:
                total += cls.estimate_tokens(msg.name, model)
        return total

    @classmethod
    def fits_in_context(
        cls,
        text: str,
        model: str,
        max_tokens: int
    ) -> bool:
        """Check if text fits in context window.

        Args:
            text: Input text
            model: Model name
            max_tokens: Maximum token limit

        Returns:
            True if the estimated token count does not exceed ``max_tokens``
        """
        return cls.estimate_tokens(text, model) <= max_tokens

469 

470 

471class CostCalculator: 

472 """Calculate costs for LLM usage.""" 

473 

474 # Cost per 1K tokens (in USD) 

475 PRICING = { 

476 'gpt-4': {'input': 0.03, 'output': 0.06}, 

477 'gpt-4-32k': {'input': 0.06, 'output': 0.12}, 

478 'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002}, 

479 'claude-3-opus': {'input': 0.015, 'output': 0.075}, 

480 'claude-3-sonnet': {'input': 0.003, 'output': 0.015}, 

481 'claude-3-haiku': {'input': 0.00025, 'output': 0.00125}, 

482 } 

483 

484 @classmethod 

485 def calculate_cost( 

486 cls, 

487 response: LLMResponse, 

488 model: str | None = None 

489 ) -> float | None: 

490 """Calculate cost for LLM response. 

491  

492 Args: 

493 response: LLM response with usage info 

494 model: Model name (if not in response) 

495  

496 Returns: 

497 Cost in USD or None if cannot calculate 

498 """ 

499 if not response.usage: 

500 return None 

501 

502 model = model or response.model 

503 

504 # Find matching pricing 

505 pricing = None 

506 for pattern, prices in cls.PRICING.items(): 

507 if pattern in model.lower(): 

508 pricing = prices 

509 break 

510 

511 if not pricing: 

512 return None 

513 

514 # Calculate cost 

515 input_cost = (response.usage.get('prompt_tokens', 0) / 1000) * pricing['input'] 

516 output_cost = (response.usage.get('completion_tokens', 0) / 1000) * pricing['output'] 

517 

518 return input_cost + output_cost 

519 

520 @classmethod 

521 def estimate_cost( 

522 cls, 

523 text: str, 

524 model: str, 

525 expected_output_tokens: int = 100 

526 ) -> float | None: 

527 """Estimate cost for text completion. 

528  

529 Args: 

530 text: Input text 

531 model: Model name 

532 expected_output_tokens: Expected output length 

533  

534 Returns: 

535 Estimated cost in USD 

536 """ 

537 # Find matching pricing 

538 pricing = None 

539 for pattern, prices in cls.PRICING.items(): 

540 if pattern in model.lower(): 

541 pricing = prices 

542 break 

543 

544 if not pricing: 

545 return None 

546 

547 # Estimate tokens 

548 input_tokens = TokenCounter.estimate_tokens(text, model) 

549 

550 # Calculate cost 

551 input_cost = (input_tokens / 1000) * pricing['input'] 

552 output_cost = (expected_output_tokens / 1000) * pricing['output'] 

553 

554 return input_cost + output_cost 

555 

556 

def chain_prompts(
    *templates: MessageTemplate
) -> MessageTemplate:
    """Concatenate several message templates into one.

    The template texts are joined with a blank line between them and the
    variable lists are merged, keeping the first occurrence of each name.
    The result carries the strategy shared by all inputs; calling with no
    templates yields an empty template.

    Args:
        *templates: Templates to chain

    Returns:
        Combined template

    Raises:
        ValueError: If templates use different strategies
    """
    if len(templates) == 0:
        return MessageTemplate("", [])

    # Every template must share the strategy of the first one
    shared_strategy = templates[0].strategy
    for tmpl in templates:
        if not (tmpl.strategy == shared_strategy):
            raise ValueError(
                "Cannot chain templates with different strategies. "
                "All templates must use the same TemplateStrategy."
            )

    # A dict keeps insertion order, giving an order-preserving de-duplication
    merged_names = {}
    for tmpl in templates:
        for name in tmpl.variables:
            merged_names.setdefault(name, None)

    joined_text = '\n\n'.join(tmpl.template for tmpl in templates)
    return MessageTemplate(joined_text, list(merged_names), shared_strategy)

596 

597 

def create_few_shot_prompt(
    instruction: str,
    examples: List[Dict[str, str]],
    query_key: str = 'input',
    response_key: str = 'output'
) -> MessageTemplate:
    """Create few-shot learning prompt.

    The result is a ``str.format``-style template whose only declared
    variable is ``query``. Example inputs and outputs are treated as literal
    data: any braces in them are escaped so that examples containing JSON or
    code survive the later ``format()`` call unchanged. The instruction is
    inserted as-is and is therefore still interpreted as template text.

    Args:
        instruction: Task instruction
        examples: List of example input/output pairs
        query_key: Key for input in examples
        response_key: Key for output in examples

    Returns:
        Few-shot prompt template

    Raises:
        KeyError: If an example lacks ``query_key`` or ``response_key``
    """
    def as_literal(value: Any) -> str:
        # Double the braces so str.format() emits them verbatim
        return str(value).replace('{', '{{').replace('}', '}}')

    template_parts = [instruction, '']

    # One numbered block per example, each followed by a blank line
    for index, example in enumerate(examples, 1):
        template_parts.extend((
            f"Example {index}:",
            f"Input: {as_literal(example[query_key])}",
            f"Output: {as_literal(example[response_key])}",
            '',
        ))

    # Trailing section holding the placeholder for the actual query
    template_parts.extend((
        "Now, process this input:",
        "Input: {query}",
        "Output:",
    ))

    return MessageTemplate('\n'.join(template_parts), ['query'])