Coverage for src / dataknobs_llm / llm / utils.py: 28%
207 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
1"""Utility functions for LLM operations.
3This module provides utility functions for working with LLMs.
4Template rendering utilities have been moved to dataknobs_llm.template_utils
5to avoid circular dependencies.
6"""
8import re
9import json
10from typing import Any, Dict, List, Union
11from dataclasses import dataclass, field
13from .base import LLMMessage, LLMResponse
14from ..template_utils import TemplateStrategy, render_conditional_template
@dataclass
class MessageTemplate:
    """Template for generating message content with multiple rendering strategies.

    Supports two template strategies:

    1. SIMPLE (default): Uses Python str.format() with {variable} syntax.
       - All variables must be provided
       - Clean and straightforward
       - Example: "Hello {name}!"

    2. CONDITIONAL: Advanced conditional rendering with {{variable}} and
       ((conditional)) syntax.
       - Variables can be optional
       - Conditional sections with (( ... ))
       - Whitespace-aware substitution
       - Example: "Hello {{name}}((, you have {{count}} messages))"
    """

    template: str
    variables: List[str] = field(default_factory=list)
    strategy: TemplateStrategy = TemplateStrategy.SIMPLE

    def __post_init__(self):
        """Extract variable names from the template when none were supplied.

        An explicitly provided, non-empty ``variables`` list is kept as-is.
        Otherwise names are collected with a strategy-specific pattern and
        de-duplicated while preserving first-seen order.
        """
        if self.variables:
            return

        if self.strategy == TemplateStrategy.SIMPLE:
            # {variable} patterns (single braces). Note that this pattern
            # also fires on the inner part of an escaped "{{variable}}".
            found = re.findall(r'\{(\w+)\}', self.template)
        elif self.strategy == TemplateStrategy.CONDITIONAL:
            # {{ variable }} patterns (double braces, optional padding).
            found = re.findall(r'\{\{\s*(\w+)\s*\}\}', self.template)
        else:
            # Unknown strategy: nothing to extract here; format()/partial()
            # report the problem when they are called.
            return

        # dict.fromkeys keeps insertion order, so this is an order-preserving
        # de-duplication.
        self.variables = list(dict.fromkeys(found))

    def format(self, **kwargs: Any) -> str:
        """Format template with variables using the selected strategy.

        Args:
            **kwargs: Variable values

        Returns:
            Formatted prompt

        Raises:
            ValueError: If using SIMPLE strategy and required variables are
                missing, or if the strategy is not a known one
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            # Simple strategy: all variables must be provided
            missing = set(self.variables) - set(kwargs.keys())
            if missing:
                raise ValueError(f"Missing variables: {missing}")
            return self.template.format(**kwargs)

        if self.strategy == TemplateStrategy.CONDITIONAL:
            # Conditional strategy: delegate to render_conditional_template
            return render_conditional_template(self.template, kwargs)

        raise ValueError(f"Unknown template strategy: {self.strategy}")

    def partial(self, **kwargs: Any) -> 'MessageTemplate':
        """Create partial template with some variables filled.

        Only keys that are currently listed in ``variables`` are substituted;
        other keyword arguments are ignored. Substituted values are inserted
        verbatim via ``str(value)``, so any template syntax contained in a
        value will be interpreted by a later ``format()`` call.

        Args:
            **kwargs: Variable values to fill

        Returns:
            New template with the given values filled in and the remaining
            variables still open

        Raises:
            ValueError: If the strategy is not a known one
        """
        if self.strategy == TemplateStrategy.SIMPLE:
            simple_mode = True
        elif self.strategy == TemplateStrategy.CONDITIONAL:
            simple_mode = False
        else:
            raise ValueError(f"Unknown template strategy: {self.strategy}")

        new_template = self.template
        remaining = list(self.variables)

        for key, value in kwargs.items():
            if key not in remaining:
                continue

            text = str(value)
            if simple_mode:
                # Replace the literal "{key}" placeholder.
                new_template = new_template.replace('{' + key + '}', text)
            else:
                # Replace "{{ key }}" (optional inner whitespace). The name is
                # escaped so it is matched literally, and the replacement is
                # passed as a callable so backslashes in the value are not
                # treated as regex escape sequences / group references.
                pattern = r'\{\{\s*' + re.escape(key) + r'\s*\}\}'
                new_template = re.sub(
                    pattern,
                    lambda _match, _text=text: _text,
                    new_template,
                )

            # Drop every occurrence of the name, not just the first, so an
            # explicitly supplied list with duplicates cannot leave a stale
            # entry behind.
            remaining = [name for name in remaining if name != key]

        return MessageTemplate(new_template, remaining, self.strategy)

    @classmethod
    def from_conditional(cls, template: str, variables: List[str] | None = None) -> 'MessageTemplate':
        """Create a MessageTemplate using the CONDITIONAL strategy.

        Convenience method for creating templates with advanced conditional rendering.

        Args:
            template: Template string with {{variable}} and ((conditional)) syntax
            variables: Optional explicit list of variables

        Returns:
            MessageTemplate configured with CONDITIONAL strategy

        Example:
            ```python
            template = MessageTemplate.from_conditional(
                "Hello {{name}}((, you have {{count}} messages))"
            )
            template.format(name="Alice", count=5)
            # "Hello Alice, you have 5 messages"
            template.format(name="Bob")
            # "Hello Bob"
            ```
        """
        return cls(template=template, variables=variables or [], strategy=TemplateStrategy.CONDITIONAL)
class MessageBuilder:
    """Fluent builder that accumulates LLM messages in order."""

    def __init__(self):
        self.messages = []

    def system(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``system``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        message = LLMMessage(role='system', content=content)
        self.messages.append(message)
        return self

    def user(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``user``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        message = LLMMessage(role='user', content=content)
        self.messages.append(message)
        return self

    def assistant(self, content: str) -> 'MessageBuilder':
        """Append a message with role ``assistant``.

        Args:
            content: Message content

        Returns:
            Self for chaining
        """
        message = LLMMessage(role='assistant', content=content)
        self.messages.append(message)
        return self

    def function(
        self,
        name: str,
        content: str,
        function_call: Dict[str, Any] | None = None
    ) -> 'MessageBuilder':
        """Append a message with role ``function``.

        Args:
            name: Function name
            content: Function result
            function_call: Function call details

        Returns:
            Self for chaining
        """
        message = LLMMessage(
            role='function',
            name=name,
            content=content,
            function_call=function_call,
        )
        self.messages.append(message)
        return self

    def from_template(
        self,
        role: str,
        template: MessageTemplate,
        **kwargs: Any
    ) -> 'MessageBuilder':
        """Render a template and append the result as a message.

        Args:
            role: Message role
            template: Message template
            **kwargs: Template variables

        Returns:
            Self for chaining
        """
        rendered = template.format(**kwargs)
        message = LLMMessage(role=role, content=rendered)
        self.messages.append(message)
        return self

    def build(self) -> List[LLMMessage]:
        """Return the accumulated messages.

        Returns:
            A new list holding the messages added so far (shallow copy)
        """
        return list(self.messages)

    def clear(self) -> 'MessageBuilder':
        """Remove every accumulated message in place.

        Returns:
            Self for chaining
        """
        del self.messages[:]
        return self
class ResponseParser:
    """Parser for LLM responses."""

    @staticmethod
    def _text(response: Union[str, LLMResponse]) -> str:
        """Return the raw text of *response*.

        A plain string is returned unchanged; for an ``LLMResponse`` its
        ``content`` attribute is used. ``None`` is mapped to the empty string
        so the extractors return their "nothing found" value instead of
        raising.
        """
        if isinstance(response, str):
            return response
        text = response.content if isinstance(response, LLMResponse) else response
        return '' if text is None else text

    @staticmethod
    def _decode_first(text: str, opener: str) -> Any:
        """Decode the first valid JSON value in *text* that starts at *opener*.

        Every occurrence of *opener* (``'{'`` or ``'['``) is tried from left to
        right with ``JSONDecoder.raw_decode``, which parses a complete value
        and ignores trailing text, so nested objects/arrays are handled.

        Returns:
            The decoded value, or None when no occurrence decodes
        """
        decoder = json.JSONDecoder()
        start = text.find(opener)
        while start != -1:
            try:
                value, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find(opener, start + 1)
            else:
                return value
        return None

    @staticmethod
    def extract_json(response: Union[str, LLMResponse]) -> Dict[str, Any] | None:
        """Extract JSON from response.

        Candidates are tried in this order, returning the first that parses:

        1. the first JSON object (``{...}``) found anywhere in the text,
        2. the first JSON array (``[...]``) found anywhere in the text,
        3. the body of a fenced code block (```` ```json ```` first, then any),
        4. the entire text.

        Objects and arrays may be arbitrarily nested. Because objects are
        preferred, an array of objects yields its first object.

        Args:
            response: LLM response

        Returns:
            Extracted JSON or None. Usually a dict, but steps 2-4 can yield a
            list or a JSON scalar.
        """
        text = ResponseParser._text(response)

        # Steps 1 and 2: embedded object, then embedded array.
        for opener in ('{', '['):
            found = ResponseParser._decode_first(text, opener)
            if found is not None:
                return found

        # Step 3: fenced code blocks (covers scalar payloads inside fences).
        fence_patterns = [
            r'```json\s*(.*?)\s*```',  # Markdown code block
            r'```\s*(.*?)\s*```',  # Generic code block
        ]
        for pattern in fence_patterns:
            for candidate in re.findall(pattern, text, re.DOTALL):
                try:
                    return json.loads(candidate)
                except json.JSONDecodeError:
                    continue

        # Step 4: the entire text as JSON.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def extract_code(
        response: Union[str, LLMResponse],
        language: str | None = None
    ) -> List[str]:
        """Extract code blocks from response.

        Args:
            response: LLM response
            language: Optional language filter. Matched literally against the
                fence tag (so ``"c++"`` works) and as a whole tag (so ``"py"``
                does not match a ```` ```python ```` fence).

        Returns:
            List of code blocks, each stripped of surrounding whitespace
        """
        text = ResponseParser._text(response)

        # Characters that may appear in a fence tag such as "c++", "c#",
        # "objective-c".
        tag_chars = r'[\w+#.-]'

        if language:
            # Language-specific code blocks: escape the tag and make sure it
            # is not merely a prefix of a longer tag.
            pattern = (
                r'```' + re.escape(language) + r'(?!' + tag_chars + r')'
                + r'\s*(.*?)\s*```'
            )
        else:
            # All code blocks, with an optional language tag after the fence.
            pattern = r'```(?:' + tag_chars + r'+)?\s*(.*?)\s*```'

        matches = re.findall(pattern, text, re.DOTALL)
        return [m.strip() for m in matches]

    @staticmethod
    def extract_list(
        response: Union[str, LLMResponse],
        numbered: bool = False
    ) -> List[str]:
        """Extract list items from response.

        Only items that start at the beginning of a line are recognised.

        Args:
            response: LLM response
            numbered: Whether to look for numbered lists ("1. item") instead
                of bullet points ("-", "*", "•")

        Returns:
            List of items, each stripped of surrounding whitespace
        """
        text = ResponseParser._text(response)

        if numbered:
            # Numbered list (1. item, 2. item, etc.)
            pattern = r'^\d+\.\s+(.+)$'
        else:
            # Bullet points (-, *, •)
            pattern = r'^[-*•]\s+(.+)$'

        matches = re.findall(pattern, text, re.MULTILINE)
        return [m.strip() for m in matches]

    @staticmethod
    def extract_sections(
        response: Union[str, LLMResponse]
    ) -> Dict[str, str]:
        """Extract sections from response.

        The text is split at Markdown-style headers ("# Header",
        "## Header", ...). Lines before the first header are stored under
        ``'main'``. A section with no lines at all is omitted, and a repeated
        header name overwrites the earlier section of the same name.

        Args:
            response: LLM response

        Returns:
            Dictionary of section name to content
        """
        text = ResponseParser._text(response)

        sections = {}
        current_section = 'main'
        current_content = []

        for line in text.split('\n'):
            header_match = re.match(r'^#+\s+(.+)$', line)
            if header_match is None:
                current_content.append(line)
                continue

            # A header closes the section collected so far ...
            if current_content:
                sections[current_section] = '\n'.join(current_content).strip()
            # ... and opens a new one.
            current_section = header_match.group(1).strip()
            current_content = []

        # Save last section
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections
class TokenCounter:
    """Estimate token counts for different models."""

    # Approximate tokens per character for different models
    TOKENS_PER_CHAR = {
        'gpt-4': 0.25,
        'gpt-3.5': 0.25,
        'claude': 0.25,
        'llama': 0.3,
        'default': 0.25
    }

    @classmethod
    def estimate_tokens(
        cls,
        text: str,
        model: str = 'default'
    ) -> int:
        """Estimate token count for text.

        The estimate is ``len(text)`` multiplied by a per-model
        tokens-per-character ratio and truncated to an int. The ratio is
        taken from the first ``TOKENS_PER_CHAR`` key that occurs as a
        substring of the lower-cased model name; the ``'default'`` ratio is
        used when nothing matches.

        Args:
            text: Input text. Empty or None text counts as 0 tokens.
            model: Model name. None is treated as ``'default'``.

        Returns:
            Estimated token count
        """
        if not text:
            return 0

        model_name = (model or 'default').lower()

        # Find matching model pattern
        ratio = cls.TOKENS_PER_CHAR['default']
        for pattern, per_char in cls.TOKENS_PER_CHAR.items():
            if pattern in model_name:
                ratio = per_char
                break

        # Estimate based on character count
        return int(len(text) * ratio)

    @classmethod
    def estimate_messages_tokens(
        cls,
        messages: List[LLMMessage],
        model: str = 'default'
    ) -> int:
        """Estimate token count for messages.

        Each message contributes a flat 4 tokens for its role, plus the
        estimate for its content, plus the estimate for its name when the
        name is set. A message whose content is empty or None contributes
        only the flat overhead (and its name, if any).

        Args:
            messages: List of messages
            model: Model name

        Returns:
            Estimated token count
        """
        total = 0
        for msg in messages:
            # Add role tokens (approximately 4 tokens)
            total += 4
            # Add content tokens
            total += cls.estimate_tokens(msg.content, model)
            # Add name tokens if present
            if msg.name:
                total += cls.estimate_tokens(msg.name, model)

        return total

    @classmethod
    def fits_in_context(
        cls,
        text: str,
        model: str,
        max_tokens: int
    ) -> bool:
        """Check if text fits in context window.

        Args:
            text: Input text
            model: Model name
            max_tokens: Maximum token limit

        Returns:
            True if the estimated token count is at most ``max_tokens``
        """
        return cls.estimate_tokens(text, model) <= max_tokens
class CostCalculator:
    """Calculate costs for LLM usage."""

    # Cost per 1K tokens (in USD)
    PRICING = {
        'gpt-4': {'input': 0.03, 'output': 0.06},
        'gpt-4-32k': {'input': 0.06, 'output': 0.12},
        'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002},
        'claude-3-opus': {'input': 0.015, 'output': 0.075},
        'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
        'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
    }

    @classmethod
    def _find_pricing(cls, model: str | None) -> Dict[str, float] | None:
        """Return the PRICING entry that best matches *model*, or None.

        A key matches when it occurs as a substring of the lower-cased model
        name. When several keys match, the longest one wins so that a more
        specific entry (``'gpt-4-32k'``) is not shadowed by a shorter key it
        contains (``'gpt-4'``).
        """
        if not model:
            return None

        model_name = model.lower()
        candidates = [key for key in cls.PRICING if key in model_name]
        if not candidates:
            return None

        return cls.PRICING[max(candidates, key=len)]

    @classmethod
    def calculate_cost(
        cls,
        response: LLMResponse,
        model: str | None = None
    ) -> float | None:
        """Calculate cost for LLM response.

        Token counts are read from ``response.usage`` under the keys
        ``'prompt_tokens'`` and ``'completion_tokens'``; a missing or None
        count is treated as 0.

        Args:
            response: LLM response with usage info
            model: Model name (if not in response); overrides
                ``response.model`` when given

        Returns:
            Cost in USD or None if cannot calculate (no usage info, no model
            name, or no pricing entry matching the model)
        """
        usage = response.usage
        if not usage:
            return None

        pricing = cls._find_pricing(model or response.model)
        if not pricing:
            return None

        # Prices are per 1K tokens.
        input_cost = ((usage.get('prompt_tokens', 0) or 0) / 1000) * pricing['input']
        output_cost = ((usage.get('completion_tokens', 0) or 0) / 1000) * pricing['output']

        return input_cost + output_cost

    @classmethod
    def estimate_cost(
        cls,
        text: str,
        model: str,
        expected_output_tokens: int = 100
    ) -> float | None:
        """Estimate cost for text completion.

        The input token count is estimated with
        ``TokenCounter.estimate_tokens``; the output side uses
        ``expected_output_tokens`` directly.

        Args:
            text: Input text
            model: Model name
            expected_output_tokens: Expected output length

        Returns:
            Estimated cost in USD, or None when no pricing entry matches the
            model
        """
        pricing = cls._find_pricing(model)
        if not pricing:
            return None

        # Estimate tokens
        input_tokens = TokenCounter.estimate_tokens(text, model)

        # Prices are per 1K tokens.
        input_cost = (input_tokens / 1000) * pricing['input']
        output_cost = (expected_output_tokens / 1000) * pricing['output']

        return input_cost + output_cost
def chain_prompts(
    *templates: MessageTemplate
) -> MessageTemplate:
    """Join several message templates into one.

    The template strings are concatenated with a blank line between them and
    the variable lists are merged, keeping the first occurrence of each name.
    Every template must share the strategy of the first one, which is also
    the strategy of the result. With no arguments an empty template is
    returned.

    Args:
        *templates: Templates to chain

    Returns:
        Combined template

    Raises:
        ValueError: If templates use different strategies
    """
    if len(templates) == 0:
        return MessageTemplate("", [])

    # Every template has to agree with the first one on strategy.
    first_strategy = templates[0].strategy
    for candidate in templates:
        if candidate.strategy == first_strategy:
            continue
        raise ValueError(
            "Cannot chain templates with different strategies. "
            "All templates must use the same TemplateStrategy."
        )

    pieces = [candidate.template for candidate in templates]
    combined_template = '\n\n'.join(pieces)

    # Order-preserving union of all variable names.
    combined_variables = list(dict.fromkeys(
        name for candidate in templates for name in candidate.variables
    ))

    return MessageTemplate(combined_template, combined_variables, first_strategy)
def create_few_shot_prompt(
    instruction: str,
    examples: List[Dict[str, str]],
    query_key: str = 'input',
    response_key: str = 'output'
) -> MessageTemplate:
    """Create few-shot learning prompt.

    The result is a template with a single ``{query}`` variable, laid out as
    the instruction, one numbered "Example N / Input / Output" group per
    example, and a final "Now, process this input" section.

    Example inputs and outputs are data, so any ``{`` or ``}`` they contain
    is escaped (doubled) and survives the later ``format(query=...)`` call
    literally. ``instruction`` is inserted verbatim: literal braces in it
    must already be doubled by the caller.

    Args:
        instruction: Task instruction
        examples: List of example input/output pairs
        query_key: Key for input in examples
        response_key: Key for output in examples

    Returns:
        Few-shot prompt template

    Raises:
        KeyError: If an example lacks ``query_key`` or ``response_key``
    """
    def _literal(value: Any) -> str:
        # Double the braces so str.format() emits them unchanged instead of
        # parsing them as replacement fields.
        return str(value).replace('{', '{{').replace('}', '}}')

    template_parts = [instruction, '']

    # Add examples
    for i, example in enumerate(examples, 1):
        template_parts.append(f"Example {i}:")
        template_parts.append(f"Input: {_literal(example[query_key])}")
        template_parts.append(f"Output: {_literal(example[response_key])}")
        template_parts.append('')

    # Add query placeholder
    template_parts.append("Now, process this input:")
    template_parts.append("Input: {query}")
    template_parts.append("Output:")

    return MessageTemplate('\n'.join(template_parts), ['query'])