Coverage for src/dataknobs_llm/fsm_integration/functions.py: 18%

193 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""Built-in LLM functions for FSM. 

2 

3This module provides LLM-related functions that can be referenced 

4in FSM configurations for AI-powered workflows. 

5 

6Note: This module was migrated from dataknobs_fsm.functions.library.llm to 

7consolidate all LLM functionality in the dataknobs-llm package. 

8""" 

9 

10import asyncio 

11import json 

12from typing import Any, Callable, Dict, List 

13 

14from dataknobs_fsm.functions.base import ( 

15 ITransformFunction, 

16 IValidationFunction, 

17 TransformError, 

18 ValidationError, 

19) 

20from dataknobs_llm.fsm_integration.resources import LLMResource 

21 

22 

class PromptBuilder(ITransformFunction):
    """Build prompts for LLM calls.

    Renders ``template`` with ``str.format``-style ``{placeholder}`` fields.
    Each name in ``variables`` is looked up in the input data, first as a
    top-level key and, failing that, as a dot-separated path through nested
    dicts (``"user.name"`` reads ``data["user"]["name"]``). Dotted names can be
    referenced verbatim in the template, e.g. ``"Hello {user.name}"``.
    """

    # Sentinel distinguishing "not found" from a stored value of ``None``.
    _MISSING = object()

    def __init__(
        self,
        template: str,
        system_prompt: str | None = None,
        variables: List[str] | None = None,
        format_spec: str | None = None,  # "json", "markdown", "plain"
    ):
        """Initialize the prompt builder.

        Args:
            template: Prompt template with {variable} placeholders.
            system_prompt: Optional system prompt, copied to the output's
                ``system_prompt`` key when non-empty.
            variables: Names (optionally dot-separated nested paths) to
                extract from the input data.
            format_spec: Output format hint. "json" and "markdown" append an
                instruction to the prompt; any other value appends nothing.
        """
        self.template = template
        self.system_prompt = system_prompt
        self.variables = variables or []
        self.format_spec = format_spec

    def _resolve(self, data: Dict[str, Any], name: str) -> Any:
        """Look up ``name`` as a top-level key, then as a dotted nested path.

        Returns ``self._MISSING`` when no value can be found.
        """
        if name in data:
            return data[name]
        value: Any = data
        for part in name.split("."):
            if not isinstance(value, dict) or part not in value:
                return self._MISSING
            value = value[part]
        return value

    def _render(self, variables: Dict[str, Any]) -> str:
        """Format the template, treating dotted field names as plain keys."""
        import string

        class _KeyFirstFormatter(string.Formatter):
            # Plain str.format parses "{user.name}" as attribute access on
            # "user", so a dotted variable extracted by _resolve() could never
            # be substituted. Try the full field name as a key first and only
            # fall back to normal attribute/index parsing when it is absent.
            def get_field(self, field_name, args, kwargs):
                if field_name in kwargs:
                    return kwargs[field_name], field_name
                return super().get_field(field_name, args, kwargs)

        return _KeyFirstFormatter().vformat(self.template, (), variables)

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by building the prompt.

        Args:
            data: Input data containing the variables for the prompt.

        Returns:
            A shallow copy of ``data`` with ``prompt`` set, plus
            ``system_prompt`` when one was configured.

        Raises:
            TransformError: If a template field has no value or the template
                cannot be rendered.
        """
        variables: Dict[str, Any] = {}
        for var in self.variables:
            value = self._resolve(data, var)
            if value is not self._MISSING:
                variables[var] = value

        try:
            prompt = self._render(variables)
        except KeyError as e:
            raise TransformError(f"Missing variable for prompt: {e}") from e
        except (IndexError, ValueError, AttributeError) as e:
            # Positional "{}" fields, malformed braces/format specs, or
            # attribute access on a value that lacks the attribute.
            raise TransformError(f"Invalid prompt template: {e}") from e

        if self.format_spec == "json":
            prompt += "\n\nPlease respond with valid JSON only."
        elif self.format_spec == "markdown":
            prompt += "\n\nPlease format your response using Markdown."

        result = {**data, "prompt": prompt}
        if self.system_prompt:
            result["system_prompt"] = self.system_prompt
        return result

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Build prompt from template: {self.template}"

103 

104 

class LLMCaller(ITransformFunction):
    """Call an LLM resource with the prompt found in the input data."""

    def __init__(
        self,
        resource_name: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        stream: bool = False,
        response_field: str = "llm_response",
    ):
        """Initialize the LLM caller.

        Args:
            resource_name: Key of the LLM resource inside ``data["_resources"]``.
            model: Model to use (if None, use resource default).
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.
            stream: Whether to request a streaming response.
            response_field: Field to store the response in.
        """
        self.resource_name = resource_name
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.response_field = response_field

    @staticmethod
    def _total_tokens(response: Any) -> int | None:
        """Return ``response["usage"]["total_tokens"]`` when available.

        Non-dict responses and missing or malformed ``usage`` entries yield
        ``None`` instead of raising. This way a successful LLM call is never
        reported as a failure because of token bookkeeping.
        """
        if not isinstance(response, dict):
            return None
        usage = response.get("usage")
        if not isinstance(usage, dict):
            return None
        return usage.get("total_tokens")

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling the LLM.

        Args:
            data: Input data containing ``prompt``, optionally
                ``system_prompt``, and the resources under ``_resources``.

        Returns:
            A shallow copy of ``data`` with the response stored under
            ``response_field``. Streaming calls also set ``is_streaming``;
            non-streaming calls set ``tokens_used`` (``None`` if unknown).

        Raises:
            TransformError: If the resource or prompt is missing, or the
                ``generate`` call fails.
        """
        # ``_resources`` may be absent or explicitly None.
        resources = data.get("_resources") or {}
        resource = resources.get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        prompt = data.get("prompt")
        if not prompt:
            raise TransformError("No prompt found in data")

        system_prompt = data.get("system_prompt")

        try:
            response = await resource.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=self.stream,
            )
        except Exception as e:
            raise TransformError(f"LLM call failed: {e}") from e

        if self.stream:
            # The response is returned as-is (e.g. an async iterator) for the
            # consumer to drain.
            return {
                **data,
                self.response_field: response,
                "is_streaming": True,
            }
        return {
            **data,
            self.response_field: response,
            "tokens_used": self._total_tokens(response),
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Call LLM resource '{self.resource_name}' with prompt"

191 

192 

class ResponseValidator(IValidationFunction):
    """Validate LLM responses (length limits, JSON format, schema, fields)."""

    def __init__(
        self,
        response_field: str = "llm_response",
        format: str | None = None,  # "json", "markdown", etc.
        schema: Dict[str, Any] | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        required_fields: List[str] | None = None,
    ):
        """Initialize the response validator.

        Args:
            response_field: Field containing the LLM response.
            format: Expected response format; only "json" is checked.
            schema: Pydantic field definitions passed as keyword arguments to
                ``pydantic.create_model`` (e.g. ``{"name": (str, ...)}``);
                applied only when ``format == "json"``.
            min_length: Minimum response length in characters (None = no limit).
            max_length: Maximum response length in characters (None = no limit).
            required_fields: Keys that must be present in the parsed JSON object.
        """
        self.response_field = response_field
        self.format = format
        self.schema = schema
        self.min_length = min_length
        self.max_length = max_length
        self.required_fields = required_fields or []

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the response text, preferring the ``text``, then ``content`` keys."""
        if isinstance(response, dict):
            text = response.get("text", response.get("content", str(response)))
        else:
            text = response
        if text is None:
            # A dict with an explicit ``"text": None`` carries no text.
            return ""
        return text if isinstance(text, str) else str(text)

    def _validate_json(self, text: str) -> None:
        """Parse ``text`` as JSON and apply the schema and required-field checks."""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as e:
            raise ValidationError(f"Invalid JSON response: {e}") from e

        # Schema and field checks need a JSON object; a list or scalar would
        # otherwise leak a TypeError out of the validator.
        if (self.schema or self.required_fields) and not isinstance(parsed, dict):
            raise ValidationError(
                f"Expected a JSON object, got {type(parsed).__name__}"
            )

        if self.schema:
            from pydantic import create_model, ValidationError as PydanticValidationError
            model = create_model("ResponseSchema", **self.schema)
            try:
                model(**parsed)
            except PydanticValidationError as e:
                raise ValidationError(f"Schema validation failed: {e}") from e

        for field in self.required_fields:
            if field not in parsed:
                raise ValidationError(f"Required field missing: {field}")

    def validate(self, data: Dict[str, Any]) -> bool:
        """Validate the LLM response.

        Args:
            data: Data containing the LLM response.

        Returns:
            True if valid.

        Raises:
            ValidationError: If validation fails.
        """
        response = data.get(self.response_field)
        if response is None:
            raise ValidationError(f"Response field '{self.response_field}' not found")

        text = self._extract_text(response)

        # ``is not None`` so that an explicit limit of 0 is still enforced.
        if self.min_length is not None and len(text) < self.min_length:
            raise ValidationError(
                f"Response too short: {len(text)} < {self.min_length}"
            )

        if self.max_length is not None and len(text) > self.max_length:
            raise ValidationError(
                f"Response too long: {len(text)} > {self.max_length}"
            )

        if self.format == "json":
            self._validate_json(text)

        return True

    def get_validation_rules(self) -> Dict[str, Any]:
        """Get the validation rules this function implements.

        Returns:
            Dictionary describing the validation rules.
        """
        return {
            "response_field": self.response_field,
            "format": self.format,
            "schema": self.schema,
            "min_length": self.min_length,
            "max_length": self.max_length,
            "required_fields": self.required_fields,
        }

293 

294 

class FunctionCaller(ITransformFunction):
    """Call functions/tools named in an LLM response.

    The response (a dict, or a JSON string decoding to one) is expected to
    contain ``"function"`` (a registry key) and optionally ``"arguments"``.
    The arguments may be a mapping or a JSON-encoded object string.
    """

    def __init__(
        self,
        response_field: str = "llm_response",
        function_registry: Dict[str, Callable] | None = None,
        result_field: str = "function_result",
    ):
        """Initialize the function caller.

        Args:
            response_field: Field containing the LLM response with a function call.
            function_registry: Mapping of function names to callables
                (sync or async).
            result_field: Field to store the function result in.
        """
        self.response_field = response_field
        self.function_registry = function_registry or {}
        self.result_field = result_field

    @staticmethod
    def _parse_arguments(raw: Any) -> Any:
        """Normalize the ``arguments`` entry into a mapping usable with ``**``.

        Raises:
            TransformError: If the arguments are not (or do not decode to) a
                JSON object / mapping.
        """
        from collections.abc import Mapping

        if raw is None:
            return {}
        if isinstance(raw, str):
            # Tool-call arguments are often delivered as a JSON-encoded string.
            try:
                raw = json.loads(raw) if raw.strip() else {}
            except json.JSONDecodeError as e:
                raise TransformError(f"Invalid function arguments: {e}") from e
        if not isinstance(raw, Mapping):
            raise TransformError(
                f"Function arguments must be an object, got {type(raw).__name__}"
            )
        return raw

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling the function requested in the LLM response.

        Args:
            data: Input data containing the LLM response.

        Returns:
            ``data`` unchanged when no function call is present; otherwise a
            shallow copy with the result under ``result_field`` and the name
            under ``function_called``.

        Raises:
            TransformError: If the function is not registered, its arguments
                are malformed, or the call raises.
        """
        import inspect

        response = data.get(self.response_field)
        if not response:
            return data

        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                # Not a JSON response, no function to call
                return data

        # Only a JSON object can describe a function call. Lists, scalars
        # and opaque response objects mean there is nothing to call.
        if not isinstance(response, dict):
            return data

        function_name = response.get("function")
        if not function_name:
            return data

        if function_name not in self.function_registry:
            raise TransformError(f"Function not found: {function_name}")
        func = self.function_registry[function_name]

        function_args = self._parse_arguments(response.get("arguments"))

        try:
            result = func(**function_args)
            # Covers async functions, partials of them, and sync callables
            # that return an awaitable.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            raise TransformError(f"Function call failed: {e}") from e

        return {
            **data,
            self.result_field: result,
            "function_called": function_name,
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        available_funcs = ", ".join(self.function_registry.keys()) if self.function_registry else "none"
        return f"Call function from LLM response (available: {available_funcs})"

373 

374 

class ConversationManager(ITransformFunction):
    """Manage conversation history for multi-turn interactions.

    Appends the current prompt (as a "user" message) and the LLM response
    (as an "assistant" message) to the history, then keeps only the most
    recent ``max_history`` messages.
    """

    def __init__(
        self,
        max_history: int = 10,
        history_field: str = "conversation_history",
        role_field: str = "role",
        content_field: str = "content",
        prompt_field: str = "prompt",
        response_field: str = "llm_response",
    ):
        """Initialize the conversation manager.

        Args:
            max_history: Maximum number of messages to keep; values <= 0
                disable trimming.
            history_field: Field to store the conversation history in.
            role_field: Message key holding the role.
            content_field: Message key holding the content.
            prompt_field: Input field recorded as the user message.
            response_field: Input field recorded as the assistant message.
        """
        self.max_history = max_history
        self.history_field = history_field
        self.role_field = role_field
        self.content_field = content_field
        self.prompt_field = prompt_field
        self.response_field = response_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by updating the conversation history.

        Args:
            data: Input data with the new prompt and/or response.

        Returns:
            A shallow copy of ``data`` with a new, updated history list.
        """
        # Copy so the caller's list (still referenced by ``data``) is not
        # mutated in place.
        history = list(data.get(self.history_field) or [])

        if self.prompt_field in data:
            history.append({
                self.role_field: "user",
                self.content_field: data[self.prompt_field],
            })

        if self.response_field in data:
            response = data[self.response_field]
            if isinstance(response, dict):
                content = response.get("text", response.get("content", str(response)))
            else:
                content = str(response)
            history.append({
                self.role_field: "assistant",
                self.content_field: content,
            })

        # Guard on > 0: history[-0:] is the whole list, and a negative
        # value would slice away the oldest messages instead of limiting.
        if self.max_history > 0 and len(history) > self.max_history:
            history = history[-self.max_history:]

        return {
            **data,
            self.history_field: history,
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Manage conversation history (max {self.max_history} messages)"

446 

447 

class EmbeddingGenerator(ITransformFunction):
    """Generate embeddings for text using an LLM resource."""

    def __init__(
        self,
        resource_name: str,
        text_field: str = "text",
        embedding_field: str = "embedding",
        model: str | None = None,
        batch_size: int = 100,
    ):
        """Initialize the embedding generator.

        Args:
            resource_name: Key of the LLM resource inside ``data["_resources"]``.
            text_field: Field containing the text (a string or list of strings).
            embedding_field: Field to store the embeddings in.
            model: Embedding model to use.
            batch_size: Number of list items sent per ``embed`` call; must be
                a positive integer when ``text_field`` holds a list.
        """
        self.resource_name = resource_name
        self.text_field = text_field
        self.embedding_field = embedding_field
        self.model = model
        self.batch_size = batch_size

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by generating embeddings.

        Args:
            data: Input data containing the text and ``_resources``.

        Returns:
            ``data`` unchanged when the text is empty; otherwise a shallow copy
            with embeddings under ``embedding_field``. For list input, the
            per-batch results are concatenated in order.

        Raises:
            TransformError: If the resource is missing, ``batch_size`` is
                invalid for list input, or embedding fails.
        """
        # ``_resources`` may be absent or explicitly None.
        resources = data.get("_resources") or {}
        resource = resources.get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        text = data.get(self.text_field)
        if not text:
            return data

        is_batch = isinstance(text, list)
        # A non-positive step would make range() raise (0) or silently
        # produce no batches (negative), i.e. an empty embedding list.
        if is_batch and (not isinstance(self.batch_size, int) or self.batch_size < 1):
            raise TransformError(
                f"batch_size must be a positive integer, got {self.batch_size!r}"
            )

        try:
            if is_batch:
                embeddings = []
                for start in range(0, len(text), self.batch_size):
                    batch = text[start:start + self.batch_size]
                    embeddings.extend(await resource.embed(batch, model=self.model))
            else:
                embeddings = await resource.embed(text, model=self.model)
        except Exception as e:
            raise TransformError(f"Embedding generation failed: {e}") from e

        return {
            **data,
            self.embedding_field: embeddings,
        }

    def get_transform_description(self) -> str:
        """Get a description of the transformation.

        Returns:
            String describing what this transform does.
        """
        return f"Generate embeddings using resource '{self.resource_name}'"

521 

522 

523# Convenience functions for creating LLM functions 

def build_prompt(template: str, **kwargs) -> PromptBuilder:
    """Return a new ``PromptBuilder`` for ``template``.

    Args:
        template: Prompt template with ``{variable}`` placeholders.
        **kwargs: Passed unchanged to the ``PromptBuilder`` constructor.

    Returns:
        The constructed ``PromptBuilder``.
    """
    builder = PromptBuilder(template, **kwargs)
    return builder

527 

528 

def call_llm(resource: str, **kwargs) -> LLMCaller:
    """Return a new ``LLMCaller`` bound to the named LLM resource.

    Args:
        resource: Name of the LLM resource, used as the caller's
            ``resource_name``.
        **kwargs: Passed unchanged to the ``LLMCaller`` constructor.

    Returns:
        The constructed ``LLMCaller``.
    """
    caller = LLMCaller(resource, **kwargs)
    return caller

532 

533 

def validate_response(**kwargs) -> ResponseValidator:
    """Return a new ``ResponseValidator``.

    Args:
        **kwargs: Passed unchanged to the ``ResponseValidator`` constructor.

    Returns:
        The constructed ``ResponseValidator``.
    """
    validator = ResponseValidator(**kwargs)
    return validator

537 

538 

def call_function(**kwargs) -> FunctionCaller:
    """Return a new ``FunctionCaller``.

    Args:
        **kwargs: Passed unchanged to the ``FunctionCaller`` constructor.

    Returns:
        The constructed ``FunctionCaller``.
    """
    caller = FunctionCaller(**kwargs)
    return caller

542 

543 

def manage_conversation(**kwargs) -> ConversationManager:
    """Return a new ``ConversationManager``.

    Args:
        **kwargs: Passed unchanged to the ``ConversationManager`` constructor.

    Returns:
        The constructed ``ConversationManager``.
    """
    manager = ConversationManager(**kwargs)
    return manager

547 

548 

def generate_embeddings(resource: str, **kwargs) -> EmbeddingGenerator:
    """Return a new ``EmbeddingGenerator`` bound to the named LLM resource.

    Args:
        resource: Name of the LLM resource, used as the generator's
            ``resource_name``.
        **kwargs: Passed unchanged to the ``EmbeddingGenerator`` constructor.

    Returns:
        The constructed ``EmbeddingGenerator``.
    """
    generator = EmbeddingGenerator(resource, **kwargs)
    return generator