Coverage for src/dataknobs_llm/fsm_integration/functions.py: 18%
193 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
1"""Built-in LLM functions for FSM.
3This module provides LLM-related functions that can be referenced
4in FSM configurations for AI-powered workflows.
6Note: This module was migrated from dataknobs_fsm.functions.library.llm to
7consolidate all LLM functionality in the dataknobs-llm package.
8"""
import asyncio
import inspect
import json
import string
from typing import Any, Callable, Dict, List

from dataknobs_fsm.functions.base import (
    ITransformFunction,
    IValidationFunction,
    TransformError,
    ValidationError,
)
from dataknobs_llm.fsm_integration.resources import LLMResource
class _DottedKeyFormatter(string.Formatter):
    """Formatter that first tries a whole field name as a single keyword key."""

    def get_field(self, field_name, args, kwargs):
        # Plain ``str.format`` splits "user.name" into key "user" plus
        # attribute "name", so a value stored under the key "user.name" is
        # never found. Try the full name as a key before the default lookup.
        if field_name in kwargs:
            return kwargs[field_name], field_name
        return super().get_field(field_name, args, kwargs)


class PromptBuilder(ITransformFunction):
    """Build prompts for LLM calls from a ``str.format``-style template."""

    def __init__(
        self,
        template: str,
        system_prompt: str | None = None,
        variables: List[str] | None = None,
        format_spec: str | None = None,  # "json", "markdown", "plain"
    ):
        """Initialize the prompt builder.

        Args:
            template: Prompt template with {variable} placeholders.
            system_prompt: Optional system prompt.
            variables: List of variable names to extract from data. A name
                may be dotted (``"user.name"``) to reach into nested dicts.
            format_spec: Output format specification. Only "json" and
                "markdown" append an instruction to the prompt.
        """
        self.template = template
        self.system_prompt = system_prompt
        self.variables = variables or []
        self.format_spec = format_spec

    @staticmethod
    def _lookup_nested(data: Dict[str, Any], dotted_name: str) -> Any:
        """Walk ``data`` along the dot-separated path; return None if absent."""
        value: Any = data
        for part in dotted_name.split("."):
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return None
        return value

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by building prompt.

        Args:
            data: Input data containing variables for prompt.

        Returns:
            A copy of ``data`` with a "prompt" entry, plus "system_prompt"
            when one was configured.

        Raises:
            TransformError: If the template references a variable, index or
                attribute that could not be resolved.
        """
        # Extract variables: exact top-level key first, then nested access.
        variables: Dict[str, Any] = {}
        for var in self.variables:
            if var in data:
                variables[var] = data[var]
                continue
            value = self._lookup_nested(data, var)
            # A nested value of None is treated the same as a missing one.
            if value is not None:
                variables[var] = value

        # Build prompt. The custom formatter lets "{user.name}" resolve to the
        # value extracted for the dotted variable "user.name".
        try:
            prompt = _DottedKeyFormatter().vformat(self.template, (), variables)
        except (KeyError, IndexError, AttributeError) as e:
            raise TransformError(f"Missing variable for prompt: {e}") from e

        # Add format specification if provided
        if self.format_spec == "json":
            prompt += "\n\nPlease respond with valid JSON only."
        elif self.format_spec == "markdown":
            prompt += "\n\nPlease format your response using Markdown."

        result = {
            **data,
            "prompt": prompt,
        }

        if self.system_prompt:
            result["system_prompt"] = self.system_prompt

        return result

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Build prompt from template: {self.template}"
class LLMCaller(ITransformFunction):
    """Call an LLM with a prompt."""

    def __init__(
        self,
        resource_name: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        stream: bool = False,
        response_field: str = "llm_response",
    ):
        """Initialize the LLM caller.

        Args:
            resource_name: Name of the LLM resource to use.
            model: Model to use (if None, use resource default).
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.
            response_field: Field to store response in.
        """
        self.resource_name = resource_name
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.response_field = response_field

    @staticmethod
    def _total_tokens(response: Any) -> Any:
        """Return ``response["usage"]["total_tokens"]`` or None if unavailable."""
        try:
            return response.get("usage", {}).get("total_tokens")
        except AttributeError:
            # Response (or its "usage" entry) is not mapping-like, e.g. a
            # plain string or a null usage; token usage is simply unknown.
            return None

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling LLM.

        Args:
            data: Input data containing "prompt", optionally "system_prompt",
                and the resource under ``data["_resources"][resource_name]``.

        Returns:
            A copy of ``data`` with the response stored under
            ``response_field``. When streaming, "is_streaming" is set to
            True; otherwise "tokens_used" holds the reported total token
            count, or None when the response does not expose one.

        Raises:
            TransformError: If the resource or prompt is missing, or the
                call to the resource fails.
        """
        # Get resource from context (tolerate a null "_resources" entry).
        resource = (data.get("_resources") or {}).get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        # Get prompt
        prompt = data.get("prompt")
        if not prompt:
            raise TransformError("No prompt found in data")

        system_prompt = data.get("system_prompt")

        # Only the call itself is guarded, so a failure while post-processing
        # a successful response is never misreported as a failed LLM call.
        try:
            response = await resource.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=self.stream,
            )
        except Exception as e:
            raise TransformError(f"LLM call failed: {e}") from e

        if self.stream:
            # For streaming, the response is handed through as-is.
            return {
                **data,
                self.response_field: response,
                "is_streaming": True,
            }

        # For non-streaming, return the full response
        return {
            **data,
            self.response_field: response,
            "tokens_used": self._total_tokens(response),
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Call LLM resource '{self.resource_name}' with prompt"
class ResponseValidator(IValidationFunction):
    """Validate LLM responses."""

    def __init__(
        self,
        response_field: str = "llm_response",
        format: str | None = None,  # "json", "markdown", etc.
        schema: Dict[str, Any] | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        required_fields: List[str] | None = None,
    ):
        """Initialize the response validator.

        Args:
            response_field: Field containing LLM response.
            format: Expected response format. Only "json" triggers parsing.
            schema: Field definitions for validation (if format is JSON).
                They are passed as keyword arguments to
                ``pydantic.create_model``.
            min_length: Minimum response length.
            max_length: Maximum response length.
            required_fields: Required fields in parsed response.
        """
        self.response_field = response_field
        self.format = format
        self.schema = schema
        self.min_length = min_length
        self.max_length = max_length
        self.required_fields = required_fields or []

    def validate(self, data: Dict[str, Any]) -> bool:
        """Validate LLM response.

        Args:
            data: Data containing LLM response.

        Returns:
            True if valid.

        Raises:
            ValidationError: If validation fails.
        """
        response = data.get(self.response_field)
        if response is None:
            raise ValidationError(f"Response field '{self.response_field}' not found")

        # Extract text from response object if needed
        if isinstance(response, dict):
            text = response.get("text", response.get("content", str(response)))
        else:
            text = str(response)

        # Check length constraints. Compare against None so that an explicit
        # limit of 0 is still enforced rather than treated as "not set".
        if self.min_length is not None and len(text) < self.min_length:  # type: ignore
            raise ValidationError(
                f"Response too short: {len(text)} < {self.min_length}"  # type: ignore
            )

        if self.max_length is not None and len(text) > self.max_length:  # type: ignore
            raise ValidationError(
                f"Response too long: {len(text)} > {self.max_length}"  # type: ignore
            )

        # Validate format
        if self.format == "json":
            try:
                parsed = json.loads(text)  # type: ignore
            except json.JSONDecodeError as e:
                raise ValidationError(f"Invalid JSON response: {e}") from e

            # Schema and required-field checks only make sense on a JSON
            # object; reject arrays/scalars explicitly instead of failing
            # with a TypeError or doing a list/substring membership test.
            if (self.schema or self.required_fields) and not isinstance(parsed, dict):
                raise ValidationError(
                    f"Expected a JSON object, got {type(parsed).__name__}"
                )

            # Validate against schema if provided
            if self.schema:
                from pydantic import ValidationError as PydanticValidationError
                from pydantic import create_model

                model = create_model("ResponseSchema", **self.schema)
                try:
                    model(**parsed)
                except PydanticValidationError as e:
                    raise ValidationError(f"Schema validation failed: {e}") from e

            # Check required fields
            for field in self.required_fields:
                if field not in parsed:
                    raise ValidationError(f"Required field missing: {field}")

        return True

    def get_validation_rules(self) -> Dict[str, Any]:
        """Get the validation rules this function implements.

        Returns:
            Dictionary describing the validation rules.
        """
        return {
            "response_field": self.response_field,
            "format": self.format,
            "schema": self.schema,
            "min_length": self.min_length,
            "max_length": self.max_length,
            "required_fields": self.required_fields,
        }
class FunctionCaller(ITransformFunction):
    """Call functions/tools based on LLM output."""

    def __init__(
        self,
        response_field: str = "llm_response",
        function_registry: Dict[str, Callable] | None = None,
        result_field: str = "function_result",
    ):
        """Initialize the function caller.

        Args:
            response_field: Field containing LLM response with function call.
            function_registry: Registry of available functions.
            result_field: Field to store function result.
        """
        self.response_field = response_field
        self.function_registry = function_registry or {}
        self.result_field = result_field

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling function from LLM response.

        The response (a mapping, or a JSON string decoding to one) is read
        for a "function" name and an optional "arguments" mapping.

        Args:
            data: Input data containing LLM response with function call.

        Returns:
            ``data`` unchanged when the response carries no function call;
            otherwise a copy with the call result under ``result_field`` and
            the name under "function_called".

        Raises:
            TransformError: If the named function is not registered or the
                call itself fails.
        """
        response = data.get(self.response_field)
        if not response:
            return data

        # Parse function call from response
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                # Not a JSON response, no function to call
                return data

        # A JSON array/number/string, or any object without a mapping-style
        # ``get``, cannot describe a function call.
        if not callable(getattr(response, "get", None)):
            return data

        # Extract function call
        function_name = response.get("function")
        if not function_name:
            return data

        # An explicit null for "arguments" means "no arguments".
        function_args = response.get("arguments")
        if function_args is None:
            function_args = {}

        # Look up function (non-string names can never be registered and may
        # not even be hashable, so reject them before the dict lookup).
        if not isinstance(function_name, str) or function_name not in self.function_registry:
            raise TransformError(f"Function not found: {function_name}")

        func = self.function_registry[function_name]

        try:
            # Call function; await the result whenever it is awaitable so that
            # coroutine functions and sync wrappers returning awaitables both work.
            result = func(**function_args)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            raise TransformError(f"Function call failed: {e}") from e

        return {
            **data,
            self.result_field: result,
            "function_called": function_name,
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        available_funcs = ", ".join(self.function_registry) if self.function_registry else "none"
        return f"Call function from LLM response (available: {available_funcs})"
class ConversationManager(ITransformFunction):
    """Manage conversation history for multi-turn interactions."""

    def __init__(
        self,
        max_history: int = 10,
        history_field: str = "conversation_history",
        role_field: str = "role",
        content_field: str = "content",
    ):
        """Initialize the conversation manager.

        Args:
            max_history: Maximum number of messages to keep.
            history_field: Field to store conversation history.
            role_field: Field for message role.
            content_field: Field for message content.
        """
        self.max_history = max_history
        self.history_field = history_field
        self.role_field = role_field
        self.content_field = content_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by managing conversation history.

        Appends ``data["prompt"]`` as a "user" message and
        ``data["llm_response"]`` as an "assistant" message when present, then
        keeps only the most recent ``max_history`` messages.

        Args:
            data: Input data with new message.

        Returns:
            A copy of ``data`` with the updated conversation history. The
            history list found in ``data`` is not modified.
        """
        # Work on a copy so the caller's list is never mutated; a missing or
        # null history starts out empty.
        history = list(data.get(self.history_field) or [])

        # Add user message if present
        if "prompt" in data:
            history.append({
                self.role_field: "user",
                self.content_field: data["prompt"],
            })

        # Add assistant message if present
        if "llm_response" in data:
            response = data["llm_response"]
            if isinstance(response, dict):
                content = response.get("text", response.get("content", str(response)))
            else:
                content = str(response)

            history.append({
                self.role_field: "assistant",
                self.content_field: content,
            })

        # Trim history if needed. Slice from an explicit start index because
        # ``history[-0:]`` would keep the whole list when the limit is 0.
        limit = max(self.max_history, 0)
        if len(history) > limit:
            history = history[len(history) - limit:]

        return {
            **data,
            self.history_field: history,
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Manage conversation history (max {self.max_history} messages)"
class EmbeddingGenerator(ITransformFunction):
    """Produce embeddings for text through an LLM resource."""

    def __init__(
        self,
        resource_name: str,
        text_field: str = "text",
        embedding_field: str = "embedding",
        model: str | None = None,
        batch_size: int = 100,
    ):
        """Store the configuration for embedding generation.

        Args:
            resource_name: Key of the LLM resource inside ``_resources``.
            text_field: Key in the data holding the text (or list of texts).
            embedding_field: Key under which the embeddings are written.
            model: Embedding model forwarded to the resource.
            batch_size: Number of texts sent per call when the input is a list.
        """
        self.resource_name = resource_name
        self.model = model
        self.batch_size = batch_size
        self.text_field = text_field
        self.embedding_field = embedding_field

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Attach embeddings for the configured text field.

        Args:
            data: Input data holding the text and the ``_resources`` mapping.

        Returns:
            ``data`` itself when the text field is empty or missing, otherwise
            a copy with the embeddings stored under ``embedding_field``.

        Raises:
            TransformError: If the resource is missing or embedding fails.
        """
        llm = data.get("_resources", {}).get(self.resource_name)
        if not (llm and isinstance(llm, LLMResource)):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        payload = data.get(self.text_field)
        if not payload:
            return data

        try:
            if not isinstance(payload, list):
                # Single text: one call, result stored as returned.
                vectors = await llm.embed(payload, model=self.model)
            else:
                # List of texts: embed slice by slice and concatenate.
                vectors = []
                for start in range(0, len(payload), self.batch_size):
                    chunk = payload[start:start + self.batch_size]
                    vectors.extend(await llm.embed(chunk, model=self.model))
            enriched = {**data, self.embedding_field: vectors}
        except Exception as exc:
            raise TransformError(f"Embedding generation failed: {exc}") from exc

        return enriched

    def get_transform_description(self) -> str:
        """Describe this transformation.

        Returns:
            Human-readable summary naming the resource in use.
        """
        summary = f"Generate embeddings using resource '{self.resource_name}'"
        return summary
523# Convenience functions for creating LLM functions
def build_prompt(template: str, **kwargs) -> PromptBuilder:
    """Return a PromptBuilder for ``template``.

    Args:
        template: Prompt template passed as the first constructor argument.
        **kwargs: Forwarded unchanged to ``PromptBuilder``.
    """
    builder = PromptBuilder(template, **kwargs)
    return builder
def call_llm(resource: str, **kwargs) -> LLMCaller:
    """Return an LLMCaller bound to the named resource.

    Args:
        resource: Resource name passed as the first constructor argument.
        **kwargs: Forwarded unchanged to ``LLMCaller``.
    """
    caller = LLMCaller(resource, **kwargs)
    return caller
def validate_response(**kwargs) -> ResponseValidator:
    """Return a ResponseValidator.

    Args:
        **kwargs: Forwarded unchanged to ``ResponseValidator``.
    """
    validator = ResponseValidator(**kwargs)
    return validator
def call_function(**kwargs) -> FunctionCaller:
    """Return a FunctionCaller.

    Args:
        **kwargs: Forwarded unchanged to ``FunctionCaller``.
    """
    caller = FunctionCaller(**kwargs)
    return caller
def manage_conversation(**kwargs) -> ConversationManager:
    """Return a ConversationManager.

    Args:
        **kwargs: Forwarded unchanged to ``ConversationManager``.
    """
    manager = ConversationManager(**kwargs)
    return manager
def generate_embeddings(resource: str, **kwargs) -> EmbeddingGenerator:
    """Return an EmbeddingGenerator bound to the named resource.

    Args:
        resource: Resource name passed as the first constructor argument.
        **kwargs: Forwarded unchanged to ``EmbeddingGenerator``.
    """
    generator = EmbeddingGenerator(resource, **kwargs)
    return generator