Coverage for src/dataknobs_llm/fsm_integration/functions.py: 0%
180 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 15:21 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 15:21 -0600
1"""Built-in LLM functions for FSM.
3This module provides LLM-related functions that can be referenced
4in FSM configurations for AI-powered workflows.
6Note: This module was migrated from dataknobs_fsm.functions.library.llm to
7consolidate all LLM functionality in the dataknobs-llm package.
8"""
10import asyncio
11import json
12from typing import Any, Callable, Dict, List
14from dataknobs_fsm.functions.base import (
15 ITransformFunction,
16 IValidationFunction,
17 TransformError,
18 ValidationError,
19)
20from dataknobs_llm.fsm_integration.resources import LLMResource
class PromptBuilder(ITransformFunction):
    """Build prompts for LLM calls.

    Selected values are read from the input record, substituted into a
    ``str.format``-style template, and the rendered text is stored under the
    ``"prompt"`` key of a shallow copy of the record.
    """

    def __init__(
        self,
        template: str,
        system_prompt: str | None = None,
        variables: List[str] | None = None,
        format_spec: str | None = None,  # "json", "markdown", "plain"
    ):
        """Initialize the prompt builder.

        Args:
            template: Prompt template with ``{variable}`` placeholders. A dotted
                variable such as ``"user.name"`` is referenced in the template
                as ``{user.name}``.
            system_prompt: Optional system prompt, copied to the output under
                ``"system_prompt"``.
            variables: Names to extract from the input data. Each is looked up
                as a top-level key first, then as a dotted path through nested
                dicts.
            format_spec: Output format hint. ``"json"`` and ``"markdown"``
                append a formatting instruction to the prompt; any other value
                (such as ``"plain"``) appends nothing.
        """
        self.template = template
        self.system_prompt = system_prompt
        self.variables = variables or []
        self.format_spec = format_spec

    @staticmethod
    def _resolve_path(data: Dict[str, Any], path: str) -> Any:
        """Follow a dotted *path* through nested dicts; return None if absent."""
        value: Any = data
        for part in path.split("."):
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return None
        return value

    @staticmethod
    def _render(template: str, variables: Dict[str, Any]) -> str:
        """Format *template* with *variables*, honouring dotted variable names.

        Plain ``str.format`` parses ``{user.name}`` as
        ``getattr(variables["user"], "name")``, which can never see a value
        extracted under the key ``"user.name"``. This formatter tries the
        whole field name as a key first and otherwise falls back to the
        standard attribute/index resolution.
        """
        import string  # local import keeps this helper self-contained

        class _DottedKeyFormatter(string.Formatter):
            def get_field(self, field_name, args, kwargs):
                if field_name in kwargs:
                    return kwargs[field_name], field_name
                return super().get_field(field_name, args, kwargs)

        return _DottedKeyFormatter().vformat(template, (), variables)

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by building a prompt.

        Args:
            data: Input data containing variables for the prompt.

        Returns:
            A shallow copy of ``data`` with ``"prompt"`` added, plus
            ``"system_prompt"`` when one is configured.

        Raises:
            TransformError: If a placeholder has no value or the template is
                malformed.
        """
        variables: Dict[str, Any] = {}
        for var in self.variables:
            if var in data:
                # A top-level key wins and is kept even when its value is None.
                variables[var] = data[var]
                continue
            value = self._resolve_path(data, var)
            if value is not None:
                variables[var] = value

        try:
            prompt = self._render(self.template, variables)
        except KeyError as e:
            raise TransformError(f"Missing variable for prompt: {e}") from e
        except (AttributeError, IndexError, ValueError) as e:
            # Malformed templates (bad braces, positional fields, bad format
            # specs, attribute access on a plain value) surface as TransformError
            # like every other failure of this transform.
            raise TransformError(f"Invalid prompt template: {e}") from e

        if self.format_spec == "json":
            prompt += "\n\nPlease respond with valid JSON only."
        elif self.format_spec == "markdown":
            prompt += "\n\nPlease format your response using Markdown."

        result = {**data, "prompt": prompt}
        if self.system_prompt:
            result["system_prompt"] = self.system_prompt
        return result
class LLMCaller(ITransformFunction):
    """Call an LLM with a prompt taken from the input record."""

    def __init__(
        self,
        resource_name: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        stream: bool = False,
        response_field: str = "llm_response",
    ):
        """Initialize the LLM caller.

        Args:
            resource_name: Key of the ``LLMResource`` inside ``data["_resources"]``.
            model: Model to use (if None, use resource default).
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.
            stream: Whether to request a streaming response.
            response_field: Field to store the response in.
        """
        self.resource_name = resource_name
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.response_field = response_field

    @staticmethod
    def _total_tokens(response: Any) -> Any:
        """Return ``response["usage"]["total_tokens"]`` or None if unavailable."""
        if not isinstance(response, dict):
            return None
        usage = response.get("usage")
        if not isinstance(usage, dict):
            return None
        return usage.get("total_tokens")

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling the LLM.

        Args:
            data: Input data containing ``"prompt"``, optionally
                ``"system_prompt"``, and the resource map ``"_resources"``.

        Returns:
            A shallow copy of ``data`` with the response under
            ``response_field``. Streaming calls also set
            ``"is_streaming": True``; non-streaming calls set ``"tokens_used"``
            (None when the response carries no usage information).

        Raises:
            TransformError: If the resource or prompt is missing, or the
                resource call itself fails.
        """
        # ``or {}`` also tolerates an explicit ``"_resources": None``.
        resource = (data.get("_resources") or {}).get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        prompt = data.get("prompt")
        if not prompt:
            raise TransformError("No prompt found in data")

        system_prompt = data.get("system_prompt")

        # Only the resource call is guarded, so post-processing problems are
        # not misreported as a failed LLM call.
        try:
            response = await resource.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=self.stream,
            )
        except Exception as e:
            raise TransformError(f"LLM call failed: {e}") from e

        if self.stream:
            # Stored unchanged; the original intent is that this is an
            # async generator for the consumer to iterate.
            return {
                **data,
                self.response_field: response,
                "is_streaming": True,
            }

        return {
            **data,
            self.response_field: response,
            "tokens_used": self._total_tokens(response),
        }
class ResponseValidator(IValidationFunction):
    """Validate LLM responses against length and (optionally) JSON constraints."""

    def __init__(
        self,
        response_field: str = "llm_response",
        format: str | None = None,  # "json", "markdown", etc.
        schema: Dict[str, Any] | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        required_fields: List[str] | None = None,
    ):
        """Initialize the response validator.

        Args:
            response_field: Field containing the LLM response.
            format: Expected response format. Only ``"json"`` triggers
                structural checks; other values are accepted without checks.
            schema: Field definitions forwarded as keyword arguments to
                ``pydantic.create_model`` (e.g. ``{"name": (str, ...)}``).
                Only applied when ``format == "json"``.
            min_length: Minimum response length in characters.
            max_length: Maximum response length in characters.
            required_fields: Keys that must be present in the parsed JSON
                object. Only applied when ``format == "json"``.
        """
        self.response_field = response_field
        self.format = format
        self.schema = schema
        self.min_length = min_length
        self.max_length = max_length
        self.required_fields = required_fields or []

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the response text as a string.

        For dict responses, ``"text"`` is preferred, then ``"content"``. None
        values are skipped, and the dict's ``str()`` is the last resort.
        """
        if isinstance(response, dict):
            for key in ("text", "content"):
                value = response.get(key)
                if value is not None:
                    return value if isinstance(value, str) else str(value)
        return str(response)

    def _validate_json(self, text: str) -> None:
        """Parse *text* as JSON and apply the schema/required-field checks."""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as e:
            raise ValidationError(f"Invalid JSON response: {e}") from e

        if (self.schema or self.required_fields) and not isinstance(parsed, dict):
            raise ValidationError(
                f"Expected a JSON object, got {type(parsed).__name__}"
            )

        if self.schema:
            # Alias pydantic's exception: importing it under its own name would
            # shadow the FSM ValidationError that this class raises.
            from pydantic import ValidationError as PydanticValidationError
            from pydantic import create_model

            model = create_model("ResponseSchema", **self.schema)
            try:
                model(**parsed)
            except PydanticValidationError as e:
                raise ValidationError(f"Schema validation failed: {e}") from e

        for field_name in self.required_fields:
            if field_name not in parsed:
                raise ValidationError(f"Required field missing: {field_name}")

    def validate(self, data: Dict[str, Any]) -> bool:
        """Validate the LLM response.

        Args:
            data: Data containing the LLM response.

        Returns:
            True if valid.

        Raises:
            ValidationError: If any validation check fails.
        """
        response = data.get(self.response_field)
        if response is None:
            raise ValidationError(f"Response field '{self.response_field}' not found")

        text = self._extract_text(response)

        if self.min_length is not None and len(text) < self.min_length:
            raise ValidationError(
                f"Response too short: {len(text)} < {self.min_length}"
            )

        if self.max_length is not None and len(text) > self.max_length:
            raise ValidationError(
                f"Response too long: {len(text)} > {self.max_length}"
            )

        if self.format == "json":
            self._validate_json(text)

        return True
class FunctionCaller(ITransformFunction):
    """Call functions/tools named in an LLM response."""

    def __init__(
        self,
        response_field: str = "llm_response",
        function_registry: Dict[str, Callable] | None = None,
        result_field: str = "function_result",
    ):
        """Initialize the function caller.

        Args:
            response_field: Field containing the LLM response with a function
                call, either a dict or a JSON string of the form
                ``{"function": name, "arguments": {...}}``.
            function_registry: Mapping of callable names to callables.
            result_field: Field to store the function result in.
        """
        self.response_field = response_field
        self.function_registry = function_registry or {}
        self.result_field = result_field

    @staticmethod
    def _coerce_arguments(raw: Any) -> Dict[str, Any]:
        """Normalize the ``"arguments"`` value to a keyword-argument dict.

        Missing/None and blank strings mean "no arguments". A string is parsed
        as JSON, since tool-call APIs commonly encode arguments that way.

        Raises:
            TypeError: If the arguments are not (or do not decode to) a dict.
            ValueError: If a string argument payload is not valid JSON.
        """
        if raw is None:
            return {}
        if isinstance(raw, str):
            raw = json.loads(raw) if raw.strip() else {}
        if not isinstance(raw, dict):
            raise TypeError(
                f"arguments must be a JSON object, got {type(raw).__name__}"
            )
        return raw

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling the function named in the LLM response.

        Args:
            data: Input data containing the LLM response.

        Returns:
            ``data`` unchanged when there is no function call to make.
            Otherwise, a shallow copy with the result under ``result_field``
            and the name under ``"function_called"``.

        Raises:
            TransformError: If the function is not registered, its arguments
                are malformed, or the call raises.
        """
        import inspect  # local import keeps this block self-contained

        response = data.get(self.response_field)
        if not response:
            return data

        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                # Not a JSON response, no function to call
                return data

        if not isinstance(response, dict):
            # JSON arrays/scalars or unrecognized response objects carry no call.
            return data

        function_name = response.get("function")
        if not function_name:
            return data

        if function_name not in self.function_registry:
            raise TransformError(f"Function not found: {function_name}")

        func = self.function_registry[function_name]

        try:
            kwargs = self._coerce_arguments(response.get("arguments"))
            result = func(**kwargs)
            # Await anything awaitable, not just results of `async def`
            # functions (e.g. partials or callables returning coroutines).
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            raise TransformError(f"Function call failed: {e}") from e

        return {
            **data,
            self.result_field: result,
            "function_called": function_name,
        }
class ConversationManager(ITransformFunction):
    """Manage conversation history for multi-turn interactions."""

    def __init__(
        self,
        max_history: int = 10,
        history_field: str = "conversation_history",
        role_field: str = "role",
        content_field: str = "content",
        prompt_field: str = "prompt",
        response_field: str = "llm_response",
    ):
        """Initialize the conversation manager.

        Args:
            max_history: Maximum number of messages to keep; values <= 0 keep
                none.
            history_field: Field to store the conversation history in.
            role_field: Key used for a message's role.
            content_field: Key used for a message's content.
            prompt_field: Field whose value is recorded as the user message.
            response_field: Field whose value is recorded as the assistant
                message (should match the ``LLMCaller`` ``response_field``).
        """
        self.max_history = max_history
        self.history_field = history_field
        self.role_field = role_field
        self.content_field = content_field
        self.prompt_field = prompt_field
        self.response_field = response_field

    @staticmethod
    def _response_content(response: Any) -> Any:
        """Extract message content from an LLM response.

        For dicts, ``"text"`` is preferred, then ``"content"``, skipping None
        values; otherwise the response's ``str()`` is used.
        """
        if isinstance(response, dict):
            for key in ("text", "content"):
                value = response.get(key)
                if value is not None:
                    return value
        return str(response)

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by appending the latest turn to the history.

        Args:
            data: Input data, optionally holding the prompt and the response.

        Returns:
            A shallow copy of ``data`` with a new, trimmed history list. The
            input's history list is left unmodified.
        """
        # Copy: appending to the caller's list would leak this turn into any
        # other holder of it (and duplicate it if the same data is re-run).
        history = list(data.get(self.history_field) or [])

        if self.prompt_field in data:
            history.append({
                self.role_field: "user",
                self.content_field: data[self.prompt_field],
            })

        if self.response_field in data:
            history.append({
                self.role_field: "assistant",
                self.content_field: self._response_content(data[self.response_field]),
            })

        # `history[-0:]` is the whole list, so a non-positive limit needs its
        # own branch.
        history = history[-self.max_history:] if self.max_history > 0 else []

        return {
            **data,
            self.history_field: history,
        }
class EmbeddingGenerator(ITransformFunction):
    """Generate embeddings for text using an LLM resource."""

    def __init__(
        self,
        resource_name: str,
        text_field: str = "text",
        embedding_field: str = "embedding",
        model: str | None = None,
        batch_size: int = 100,
    ):
        """Initialize the embedding generator.

        Args:
            resource_name: Key of the ``LLMResource`` inside ``data["_resources"]``.
            text_field: Field containing the text (or list of texts) to embed.
            embedding_field: Field to store the embeddings in.
            model: Embedding model to use.
            batch_size: Number of texts per ``embed`` call when embedding a
                list; must be positive.
        """
        self.resource_name = resource_name
        self.text_field = text_field
        self.embedding_field = embedding_field
        self.model = model
        self.batch_size = batch_size

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by generating embeddings.

        Args:
            data: Input data containing the text and the resource map
                ``"_resources"``.

        Returns:
            ``data`` unchanged if the text is empty or missing; otherwise a
            shallow copy with the embeddings under ``embedding_field``. A list
            input yields the concatenated per-batch results.

        Raises:
            TransformError: If the resource is missing, ``batch_size`` is not
                positive for list input, or embedding fails.
        """
        # ``or {}`` also tolerates an explicit ``"_resources": None``.
        resource = (data.get("_resources") or {}).get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        text = data.get(self.text_field)
        if not text:
            return data

        is_batch = isinstance(text, list)
        if is_batch and self.batch_size < 1:
            # range() would reject a step of 0 and silently yield nothing for
            # negative steps (dropping every embedding), so fail clearly.
            raise TransformError(
                f"batch_size must be a positive integer, got {self.batch_size}"
            )

        try:
            if is_batch:
                embeddings: Any = []
                for start in range(0, len(text), self.batch_size):
                    batch = text[start:start + self.batch_size]
                    embeddings.extend(await resource.embed(batch, model=self.model))
            else:
                embeddings = await resource.embed(text, model=self.model)
        except Exception as e:
            raise TransformError(f"Embedding generation failed: {e}") from e

        return {
            **data,
            self.embedding_field: embeddings,
        }
# Convenience factories for the LLM transform/validation classes in this module.
def build_prompt(template: str, **kwargs) -> PromptBuilder:
    """Return a ``PromptBuilder`` for *template*.

    Any additional keyword arguments are forwarded to the constructor as-is.
    """
    builder = PromptBuilder(template=template, **kwargs)
    return builder
def call_llm(resource: str, **kwargs) -> LLMCaller:
    """Return an ``LLMCaller`` bound to the resource named *resource*.

    Any additional keyword arguments are forwarded to the constructor as-is.
    """
    caller = LLMCaller(resource_name=resource, **kwargs)
    return caller
def validate_response(**kwargs) -> ResponseValidator:
    """Return a ``ResponseValidator`` configured from keyword arguments."""
    validator = ResponseValidator(**kwargs)
    return validator
def call_function(**kwargs) -> FunctionCaller:
    """Return a ``FunctionCaller`` configured from keyword arguments."""
    caller = FunctionCaller(**kwargs)
    return caller
def manage_conversation(**kwargs) -> ConversationManager:
    """Return a ``ConversationManager`` configured from keyword arguments."""
    manager = ConversationManager(**kwargs)
    return manager
def generate_embeddings(resource: str, **kwargs) -> EmbeddingGenerator:
    """Return an ``EmbeddingGenerator`` bound to the resource named *resource*.

    Any additional keyword arguments are forwarded to the constructor as-is.
    """
    generator = EmbeddingGenerator(resource_name=resource, **kwargs)
    return generator