Coverage for src/dataknobs_llm/llm/utils.py: 97%
207 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
1"""Utility functions for LLM operations.
3This module provides utility functions for working with LLMs.
4Template rendering utilities have been moved to dataknobs_llm.template_utils
5to avoid circular dependencies.
6"""
8import re
9import json
10from typing import Any, Dict, List, Union
11from dataclasses import dataclass, field
13from .base import LLMMessage, LLMResponse
14from ..template_utils import TemplateStrategy, render_conditional_template
@dataclass
class MessageTemplate:
    """Template for generating message content with multiple rendering strategies.

    Supports two template strategies:

    1. SIMPLE (default): Uses Python str.format() with {variable} syntax.
       - All variables must be provided
       - Literal braces are written doubled ({{ and }}), exactly as in str.format()
       - Example: "Hello {name}!"

    2. CONDITIONAL: Advanced conditional rendering with {{variable}} and ((conditional)) syntax.
       - Variables can be optional
       - Conditional sections with (( ... ))
       - Whitespace-aware substitution
       - Example: "Hello {{name}}((, you have {{count}} messages))"

    Attributes:
        template: Raw template text.
        variables: Names of the template's variables. When empty they are
            extracted from ``template`` according to ``strategy``. Duplicates
            are always removed (first occurrence wins).
        strategy: Rendering strategy used by ``format`` and ``partial``.
    """
    template: str
    variables: List[str] = field(default_factory=list)
    strategy: TemplateStrategy = TemplateStrategy.SIMPLE

    def __post_init__(self):
        """Extract variables from the template (when not given) and de-duplicate them."""
        if not self.variables:
            if self.strategy == TemplateStrategy.SIMPLE:
                self.variables = self._simple_fields(self.template)
            elif self.strategy == TemplateStrategy.CONDITIONAL:
                # Extract {{variable}} patterns (double braces); group 2 is the
                # bare name, groups 1 and 3 are the optional inner whitespace.
                self.variables = [
                    match.group(2)
                    for match in re.finditer(r'\{\{(\s*)(\w+)(\s*)\}\}', self.template)
                ]
        # Remove duplicates while preserving first-seen order.
        self.variables = list(dict.fromkeys(self.variables))

    @staticmethod
    def _simple_fields(template: str) -> List[str]:
        """Return the root names of the str.format() fields used in ``template``.

        string.Formatter is used so that escaped braces ({{x}}) are treated as
        literal text, and fields carrying a conversion or format spec ({x!r},
        {x:>10}) are still detected. "{user.name}" and "{items[0]}" report the
        root name ("user", "items") because that is the keyword str.format()
        looks up. Auto-numbered "{}" fields are ignored.

        If the template is not a valid format string, the plain ``{name}``
        regex is used instead so that construction never fails; format()
        surfaces the error when it is called.
        """
        import string

        try:
            parsed = list(string.Formatter().parse(template))
        except ValueError:
            return re.findall(r'\{(\w+)\}', template)

        names: List[str] = []
        for _literal, field_name, format_spec, _conversion in parsed:
            if field_name:
                root = re.match(r'\w+', field_name)
                if root:
                    names.append(root.group(0))
            # Nested fields such as "{value:{width}}" live inside the spec.
            if format_spec and '{' in format_spec:
                names.extend(MessageTemplate._simple_fields(format_spec))
        return names

    def format(self, **kwargs: Any) -> str:
        """Format template with variables using the selected strategy.

        Args:
            **kwargs: Variable values

        Returns:
            Formatted prompt

        Raises:
            ValueError: If using SIMPLE strategy and required variables are missing,
                or if the strategy is unknown
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            # Simple strategy: all variables must be provided
            missing = set(self.variables) - set(kwargs.keys())
            if missing:
                raise ValueError(f"Missing variables: {missing}")
            return self.template.format(**kwargs)

        if self.strategy == TemplateStrategy.CONDITIONAL:
            # Conditional strategy: missing variables are allowed
            return render_conditional_template(self.template, kwargs)

        raise ValueError(f"Unknown template strategy: {self.strategy}")

    def partial(self, **kwargs: Any) -> 'MessageTemplate':
        """Create a partial template with some variables filled.

        Values are inserted as literal text. For SIMPLE templates, braces in a
        value are escaped so that a later format() call does not treat them as
        fields. Keys that are not among ``self.variables`` are ignored.

        Args:
            **kwargs: Variable values to fill

        Returns:
            New template (same strategy) with the remaining variables

        Raises:
            ValueError: If the strategy is unknown
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            fill = self._fill_simple
        elif self.strategy == TemplateStrategy.CONDITIONAL:
            fill = self._fill_conditional
        else:
            raise ValueError(f"Unknown template strategy: {self.strategy}")

        new_template = self.template
        new_variables = self.variables.copy()
        for key, value in kwargs.items():
            if key in new_variables:
                new_template = fill(new_template, key, value)
                new_variables.remove(key)

        return MessageTemplate(new_template, new_variables, self.strategy)

    @staticmethod
    def _fill_simple(template: str, key: str, value: Any) -> str:
        """Substitute every ``{key...}`` field of a str.format() template with ``value``."""
        # Escaped "{{" / "}}" pairs are matched first and kept unchanged, so a
        # literal "{{key}}" is never mistaken for the field "{key}".
        pattern = re.compile(
            r'\{\{|\}\}|\{(' + re.escape(key) + r'(?:[.\[!:][^{}]*)?)\}'
        )

        def _substitute(match: 're.Match[str]') -> str:
            field_text = match.group(1)
            if field_text is None:
                return match.group(0)
            # Render this single field so conversions and format specs
            # ({key!r}, {key:>5}) are honoured, just as format() would.
            rendered = ('{' + field_text + '}').format(**{key: value})
            # Escape braces so the value stays literal in the new template.
            return rendered.replace('{', '{{').replace('}', '}}')

        return pattern.sub(_substitute, template)

    @staticmethod
    def _fill_conditional(template: str, key: str, value: Any) -> str:
        """Replace every ``{{ key }}`` occurrence in a CONDITIONAL template with ``value``."""
        text = str(value)
        # Use a function replacement so that ``text`` is inserted verbatim.
        # If the string were passed directly, re.sub would interpret its
        # backslashes: "C:\new" would gain a newline, and "\d" would raise.
        return re.sub(
            r'\{\{\s*' + re.escape(key) + r'\s*\}\}',
            lambda _match: text,
            template,
        )

    @classmethod
    def from_conditional(cls, template: str, variables: List[str] | None = None) -> 'MessageTemplate':
        """Create a MessageTemplate using the CONDITIONAL strategy.

        Convenience method for creating templates with advanced conditional rendering.

        Args:
            template: Template string with {{variable}} and ((conditional)) syntax
            variables: Optional explicit list of variables

        Returns:
            MessageTemplate configured with CONDITIONAL strategy

        Example:
            >>> template = MessageTemplate.from_conditional(
            ...     "Hello {{name}}((, you have {{count}} messages))"
            ... )
            >>> template.format(name="Alice", count=5)
            "Hello Alice, you have 5 messages"
            >>> template.format(name="Bob")
            "Hello Bob"
        """
        return cls(template=template, variables=variables or [], strategy=TemplateStrategy.CONDITIONAL)
class MessageBuilder:
    """Fluent builder that accumulates ``LLMMessage`` objects in insertion order.

    Every mutating method returns the builder itself, so calls can be chained::

        messages = MessageBuilder().system("Be brief.").user("Hi").build()
    """

    def __init__(self):
        # Messages in the order they were added.
        self.messages = []

    def _push(self, **message_fields: Any) -> 'MessageBuilder':
        """Construct one LLMMessage from ``message_fields``, append it, and return self."""
        self.messages.append(LLMMessage(**message_fields))
        return self

    def system(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'system'``.

        Args:
            content: Message content

        Returns:
            This builder, for chaining
        """
        return self._push(role='system', content=content)

    def user(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'user'``.

        Args:
            content: Message content

        Returns:
            This builder, for chaining
        """
        return self._push(role='user', content=content)

    def assistant(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``'assistant'``.

        Args:
            content: Message content

        Returns:
            This builder, for chaining
        """
        return self._push(role='assistant', content=content)

    def function(
        self,
        name: str,
        content: str,
        function_call: Dict[str, Any] | None = None
    ) -> 'MessageBuilder':
        """Append a message with role ``'function'``.

        Args:
            name: Function name
            content: Function result
            function_call: Function call details, passed through unchanged

        Returns:
            This builder, for chaining
        """
        return self._push(
            role='function',
            name=name,
            content=content,
            function_call=function_call,
        )

    def from_template(
        self,
        role: str,
        template: MessageTemplate,
        **kwargs: Any
    ) -> 'MessageBuilder':
        """Render ``template`` with ``kwargs`` and append the result under ``role``.

        Args:
            role: Message role
            template: Message template
            **kwargs: Template variables, forwarded to ``template.format``

        Returns:
            This builder, for chaining
        """
        rendered = template.format(**kwargs)
        return self._push(role=role, content=rendered)

    def build(self) -> List[LLMMessage]:
        """Return a shallow copy of the accumulated messages.

        Later changes to the builder do not affect the returned list.

        Returns:
            List of messages
        """
        return list(self.messages)

    def clear(self) -> 'MessageBuilder':
        """Remove all accumulated messages in place.

        Returns:
            This builder, for chaining
        """
        del self.messages[:]
        return self
class ResponseParser:
    """Helpers for pulling structured data out of free-form LLM output."""

    # Characters allowed in a fenced code block's language tag, for example
    # "python", "c++", "c#", "objective-c" and "python3.11".
    _FENCE_TAG = r'[\w+#.-]'

    @staticmethod
    def _as_text(response: Union[str, LLMResponse]) -> str:
        """Return the response's ``content`` for an LLMResponse; otherwise return the value itself."""
        return response.content if isinstance(response, LLMResponse) else response

    @staticmethod
    def extract_json(response: Union[str, LLMResponse]) -> Dict[str, Any] | None:
        """Extract the first JSON value embedded in a response.

        Search order:

        1. The whole text, if it is valid JSON on its own.
        2. The first complete JSON object ({...}) found scanning left to right.
           Nested objects are decoded as a whole, not cut at the first "}".
        3. The first complete JSON array ([...]) found the same way.
        4. The contents of fenced code blocks (```json first, then any fence).
           This catches fenced scalars such as a block containing only 42.

        Args:
            response: LLM response or raw text

        Returns:
            The decoded value, or None if no JSON could be decoded. This is
            normally a dict, but a list or scalar is returned when that is
            what the text contains.
        """
        text = ResponseParser._as_text(response)
        if not text:
            return None

        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # raw_decode parses one complete value starting at an index, so
        # nested structures are handled correctly. Objects are tried before
        # arrays, which keeps the original preference order.
        decoder = json.JSONDecoder()
        for opener in '{[':
            start = text.find(opener)
            while start != -1:
                try:
                    value, _end = decoder.raw_decode(text, start)
                    return value
                except json.JSONDecodeError:
                    start = text.find(opener, start + 1)

        for pattern in (r'```json\s*(.*?)\s*```', r'```\s*(.*?)\s*```'):
            for match in re.findall(pattern, text, re.DOTALL):
                try:
                    return json.loads(match)
                except json.JSONDecodeError:
                    continue

        return None

    @staticmethod
    def extract_code(
        response: Union[str, LLMResponse],
        language: str | None = None
    ) -> List[str]:
        """Extract fenced code blocks from a response.

        Args:
            response: LLM response or raw text
            language: Optional language tag filter (e.g. "python", "c++").
                The tag must match exactly. "c" does not match a "cpp" block.

        Returns:
            List of code block bodies, stripped of surrounding whitespace
        """
        text = ResponseParser._as_text(response)
        tag = ResponseParser._FENCE_TAG

        if language:
            # Escape the tag ("c++" is not a valid regex). The lookahead stops
            # it from matching only a prefix of a longer tag.
            pattern = rf'```{re.escape(language)}(?!{tag})\s*(.*?)\s*```'
        else:
            # Skip an optional language tag (including "c++", "c#", ...).
            pattern = rf'```(?:{tag}+)?\s*(.*?)\s*```'

        return [block.strip() for block in re.findall(pattern, text, re.DOTALL)]

    @staticmethod
    def extract_list(
        response: Union[str, LLMResponse],
        numbered: bool = False
    ) -> List[str]:
        """Extract list items that start at the beginning of a line.

        Args:
            response: LLM response or raw text
            numbered: If True, match "1. item" lines. Otherwise match bullets
                starting with "-", "*" or "•".

        Returns:
            List of item texts, stripped
        """
        text = ResponseParser._as_text(response)

        if numbered:
            # Numbered list (1. item, 2. item, etc.)
            pattern = r'^\d+\.\s+(.+)$'
        else:
            # Bullet points (-, *, •)
            pattern = r'^[-*•]\s+(.+)$'

        return [item.strip() for item in re.findall(pattern, text, re.MULTILINE)]

    @staticmethod
    def extract_sections(
        response: Union[str, LLMResponse]
    ) -> Dict[str, str]:
        """Split a markdown-style response into sections keyed by header text.

        Text before the first header is stored under 'main'. A header is a
        line of one or more "#" followed by whitespace and a title. Lines
        inside fenced code blocks (```) are never treated as headers, so
        "# comment" lines in code stay part of their section. A section with
        no lines at all (a header immediately followed by another header) is
        omitted. A later header with the same title overwrites the earlier one.

        Args:
            response: LLM response or raw text

        Returns:
            Dictionary of section name to stripped content
        """
        text = ResponseParser._as_text(response)

        sections: Dict[str, str] = {}
        current_section = 'main'
        current_content: List[str] = []
        in_fence = False

        for line in text.split('\n'):
            if line.lstrip().startswith('```'):
                in_fence = not in_fence
            header_match = None if in_fence else re.match(r'^#+\s+(.+)$', line)
            if header_match:
                # Save the previous section, then start a new one.
                if current_content:
                    sections[current_section] = '\n'.join(current_content).strip()
                current_section = header_match.group(1).strip()
                current_content = []
            else:
                current_content.append(line)

        # Save the last section
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections
class TokenCounter:
    """Rough, character-count-based token estimates for common model families."""

    # Approximate tokens per character for different models. Keys are tested
    # as substrings of the lower-cased model name, in insertion order.
    TOKENS_PER_CHAR = {
        'gpt-4': 0.25,
        'gpt-3.5': 0.25,
        'claude': 0.25,
        'llama': 0.3,
        'default': 0.25
    }

    @classmethod
    def estimate_tokens(
        cls,
        text: str,
        model: str = 'default'
    ) -> int:
        """Estimate the token count of ``text`` for ``model``.

        The first key of ``TOKENS_PER_CHAR`` found inside the lower-cased model
        name selects the ratio; otherwise the 'default' ratio is used.

        Args:
            text: Input text
            model: Model name

        Returns:
            Estimated token count (character count times ratio, truncated)
        """
        fallback = cls.TOKENS_PER_CHAR['default']
        lowered = model.lower()
        ratio = next(
            (rate for key, rate in cls.TOKENS_PER_CHAR.items() if key in lowered),
            fallback,
        )
        return int(len(text) * ratio)

    @classmethod
    def estimate_messages_tokens(
        cls,
        messages: List[LLMMessage],
        model: str = 'default'
    ) -> int:
        """Estimate the token count of a message list.

        Each message costs a flat 4 tokens for its role, plus the estimate for
        its content and, if set, its name.

        Args:
            messages: List of messages
            model: Model name

        Returns:
            Estimated token count
        """
        return sum(
            4
            + cls.estimate_tokens(message.content, model)
            + (cls.estimate_tokens(message.name, model) if message.name else 0)
            for message in messages
        )

    @classmethod
    def fits_in_context(
        cls,
        text: str,
        model: str,
        max_tokens: int
    ) -> bool:
        """Return whether the estimated token count of ``text`` is at most ``max_tokens``.

        Args:
            text: Input text
            model: Model name
            max_tokens: Maximum token limit

        Returns:
            True if the estimate does not exceed the limit
        """
        return cls.estimate_tokens(text, model) <= max_tokens
class CostCalculator:
    """Calculate approximate USD costs for LLM usage."""

    # Cost per 1K tokens (in USD). Keys are matched as substrings of the
    # lower-cased model name. The longest matching key wins, so "gpt-4-32k"
    # is not priced as "gpt-4".
    PRICING = {
        'gpt-4': {'input': 0.03, 'output': 0.06},
        'gpt-4-32k': {'input': 0.06, 'output': 0.12},
        'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002},
        'claude-3-opus': {'input': 0.015, 'output': 0.075},
        'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
        'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
    }

    @classmethod
    def _find_pricing(cls, model: str | None) -> Dict[str, float] | None:
        """Return the PRICING entry for ``model`` using the longest matching key, or None."""
        if not model:
            return None
        lowered = model.lower()
        matches = [pattern for pattern in cls.PRICING if pattern in lowered]
        if not matches:
            return None
        # Prefer the most specific key (e.g. "gpt-4-32k" over "gpt-4").
        return cls.PRICING[max(matches, key=len)]

    @classmethod
    def calculate_cost(
        cls,
        response: LLMResponse,
        model: str | None = None
    ) -> float | None:
        """Calculate cost for an LLM response.

        Token counts are read from ``response.usage['prompt_tokens']`` and
        ``response.usage['completion_tokens']``. A missing key counts as 0.

        Args:
            response: LLM response with usage info
            model: Model name (if not in response)

        Returns:
            Cost in USD, or None if the response has no usage info, no model
            name is available, or the model has no known pricing
        """
        if not response.usage:
            return None

        pricing = cls._find_pricing(model or response.model)
        if pricing is None:
            return None

        input_cost = (response.usage.get('prompt_tokens', 0) / 1000) * pricing['input']
        output_cost = (response.usage.get('completion_tokens', 0) / 1000) * pricing['output']

        return input_cost + output_cost

    @classmethod
    def estimate_cost(
        cls,
        text: str,
        model: str,
        expected_output_tokens: int = 100
    ) -> float | None:
        """Estimate cost for a text completion before making the call.

        Input tokens are estimated with ``TokenCounter.estimate_tokens``.

        Args:
            text: Input text
            model: Model name
            expected_output_tokens: Expected output length in tokens

        Returns:
            Estimated cost in USD, or None if the model has no known pricing
        """
        pricing = cls._find_pricing(model)
        if pricing is None:
            return None

        input_tokens = TokenCounter.estimate_tokens(text, model)

        input_cost = (input_tokens / 1000) * pricing['input']
        output_cost = (expected_output_tokens / 1000) * pricing['output']

        return input_cost + output_cost
def chain_prompts(
    *templates: MessageTemplate
) -> MessageTemplate:
    """Concatenate message templates, separating them with a blank line.

    Every template must share one strategy, and the result uses that strategy.
    The variable list is the union of the inputs' variables, in first-seen
    order without duplicates.

    Args:
        *templates: Templates to chain

    Returns:
        Combined template (an empty template when called with no arguments)

    Raises:
        ValueError: If templates use different strategies
    """
    if not templates:
        return MessageTemplate("", [])

    strategy = templates[0].strategy
    if any(tmpl.strategy != strategy for tmpl in templates):
        raise ValueError(
            "Cannot chain templates with different strategies. "
            "All templates must use the same TemplateStrategy."
        )

    joined_text = '\n\n'.join(tmpl.template for tmpl in templates)
    # dict preserves insertion order, giving an ordered de-duplication.
    merged_variables = list(dict.fromkeys(
        name for tmpl in templates for name in tmpl.variables
    ))

    return MessageTemplate(joined_text, merged_variables, strategy)
def create_few_shot_prompt(
    instruction: str,
    examples: List[Dict[str, str]],
    query_key: str = 'input',
    response_key: str = 'output'
) -> MessageTemplate:
    """Create a few-shot learning prompt template.

    The result is a SIMPLE (str.format) template with the single variable
    ``query``. Example inputs and outputs are inserted as literal text, and
    any braces they contain are escaped. This means JSON or code examples do
    not break ``format(query=...)``. The ``instruction`` is used verbatim as
    template text, so literal braces in it must be doubled by the caller.

    Args:
        instruction: Task instruction (template text)
        examples: List of example input/output pairs
        query_key: Key for input in examples
        response_key: Key for output in examples

    Returns:
        Few-shot prompt template with a ``query`` placeholder

    Raises:
        KeyError: If an example lacks ``query_key`` or ``response_key``
    """
    def _literal(value: Any) -> str:
        # Double the braces so str.format renders them as-is instead of
        # treating them as replacement fields.
        return str(value).replace('{', '{{').replace('}', '}}')

    template_parts = [instruction, '']

    # Add examples
    for i, example in enumerate(examples, 1):
        template_parts.extend([
            f"Example {i}:",
            f"Input: {_literal(example[query_key])}",
            f"Output: {_literal(example[response_key])}",
            '',
        ])

    # Add query placeholder
    template_parts.extend([
        "Now, process this input:",
        "Input: {query}",
        "Output:",
    ])

    return MessageTemplate('\n'.join(template_parts), ['query'])