Coverage for src/dataknobs_llm/fsm_integration/functions.py: 0%

180 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-31 15:21 -0600

1"""Built-in LLM functions for FSM. 

2 

3This module provides LLM-related functions that can be referenced 

4in FSM configurations for AI-powered workflows. 

5 

6Note: This module was migrated from dataknobs_fsm.functions.library.llm to 

7consolidate all LLM functionality in the dataknobs-llm package. 

8""" 

9 

import asyncio
import inspect
import json
import string
from typing import Any, Callable, Dict, List

from dataknobs_fsm.functions.base import (
    ITransformFunction,
    IValidationFunction,
    TransformError,
    ValidationError,
)
from dataknobs_llm.fsm_integration.resources import LLMResource

21 

22 

class _DottedKeyFormatter(string.Formatter):
    """``str.format``-compatible formatter that also resolves dotted keys.

    ``str.format`` parses ``{user.name}`` as "look up ``user``, then read its
    ``name`` attribute", so a value stored under the literal key
    ``"user.name"`` could never be referenced from a template. This formatter
    first tries the complete field name as a keyword key and falls back to
    the standard parsing only when that key is absent.
    """

    def get_field(self, field_name: str, args: Any, kwargs: Any) -> Any:
        if field_name in kwargs:
            return kwargs[field_name], field_name
        return super().get_field(field_name, args, kwargs)


# The formatter keeps no per-call state, so one shared instance suffices.
_PROMPT_FORMATTER = _DottedKeyFormatter()


class PromptBuilder(ITransformFunction):
    """Build prompts for LLM calls.

    Values named in ``variables`` are read from the input data, either as
    top-level keys or as dot-separated paths into nested dicts. They are then
    substituted into ``template`` using ``str.format`` placeholder syntax.
    """

    # Instruction appended to the prompt for each recognized ``format_spec``;
    # any other value (including "plain") appends nothing.
    _FORMAT_INSTRUCTIONS = {
        "json": "\n\nPlease respond with valid JSON only.",
        "markdown": "\n\nPlease format your response using Markdown.",
    }

    def __init__(
        self,
        template: str,
        system_prompt: str | None = None,
        variables: List[str] | None = None,
        format_spec: str | None = None,  # "json", "markdown", "plain"
    ):
        """Initialize the prompt builder.

        Args:
            template: Prompt template with ``{variable}`` placeholders. A dotted
                variable such as ``"user.name"`` is referenced as ``{user.name}``.
            system_prompt: Optional system prompt copied into the output.
            variables: Names (or dot-separated paths) of values to extract
                from the input data.
            format_spec: Output format hint; "json" and "markdown" append an
                instruction to the prompt.
        """
        self.template = template
        self.system_prompt = system_prompt
        self.variables = variables or []
        self.format_spec = format_spec

    @staticmethod
    def _lookup(data: Dict[str, Any], name: str) -> tuple[bool, Any]:
        """Resolve ``name`` in ``data`` and return ``(found, value)``.

        An exact top-level key wins. Otherwise ``name`` is split on "." and
        followed through nested dicts. A resolved value of ``None`` counts as
        found, matching how top-level ``None`` values are treated.
        """
        if name in data:
            return True, data[name]
        value: Any = data
        for part in name.split("."):
            if not isinstance(value, dict) or part not in value:
                return False, None
            value = value[part]
        return True, value

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by building prompt.

        Args:
            data: Input data containing variables for prompt.

        Returns:
            A copy of ``data`` with ``prompt`` set, plus ``system_prompt``
            when one was configured.

        Raises:
            TransformError: If a placeholder has no value or the template is
                malformed.
        """
        variables: Dict[str, Any] = {}
        for name in self.variables:
            found, value = self._lookup(data, name)
            if found:
                variables[name] = value

        try:
            prompt = _PROMPT_FORMATTER.vformat(self.template, (), variables)
        except KeyError as e:
            raise TransformError(f"Missing variable for prompt: {e}") from e
        except (AttributeError, IndexError, TypeError, ValueError) as e:
            # Unbalanced braces, positional ``{0}`` fields or bad format specs
            # would otherwise escape as raw, non-FSM exceptions.
            raise TransformError(f"Invalid prompt template: {e}") from e

        if self.format_spec in self._FORMAT_INSTRUCTIONS:
            prompt += self._FORMAT_INSTRUCTIONS[self.format_spec]

        result = {**data, "prompt": prompt}
        if self.system_prompt:
            result["system_prompt"] = self.system_prompt
        return result

95 

96 

class LLMCaller(ITransformFunction):
    """Call an LLM with a prompt."""

    def __init__(
        self,
        resource_name: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        stream: bool = False,
        response_field: str = "llm_response",
    ):
        """Initialize the LLM caller.

        Args:
            resource_name: Key of the LLM resource inside ``data["_resources"]``.
            model: Model to use (if None, use resource default).
            temperature: Temperature for generation.
            max_tokens: Maximum tokens to generate.
            stream: Whether to request a streaming response.
            response_field: Field to store response in.
        """
        self.resource_name = resource_name
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.response_field = response_field

    @staticmethod
    def _total_tokens(response: Any) -> Any:
        """Return ``response["usage"]["total_tokens"]`` if present, else None.

        The response shape depends on the LLM resource, which is not visible
        here. A response without a dict-like ``usage`` entry yields None
        rather than turning a successful call into a failure.
        """
        try:
            usage = response.get("usage") or {}
            return usage.get("total_tokens")
        except AttributeError:
            return None

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling LLM.

        Reads ``data["prompt"]`` and the optional ``data["system_prompt"]``.
        It then awaits ``generate`` on the LLMResource stored under
        ``data["_resources"][resource_name]``.

        Args:
            data: Input data containing prompt.

        Returns:
            A copy of ``data`` with the response under ``response_field``.
            Streaming calls also set ``is_streaming=True``. Non-streaming calls
            set ``tokens_used``, which is None when no usage is reported.

        Raises:
            TransformError: If the resource or prompt is missing, or the call
                itself fails.
        """
        resources = data.get("_resources") or {}
        resource = resources.get(self.resource_name)
        if not resource or not isinstance(resource, LLMResource):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        prompt = data.get("prompt")
        if not prompt:
            raise TransformError("No prompt found in data")

        system_prompt = data.get("system_prompt")

        # Only the call itself is guarded, so post-processing problems are
        # not misreported as "LLM call failed".
        try:
            response = await resource.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=self.stream,
            )
        except Exception as e:
            raise TransformError(f"LLM call failed: {e}") from e

        if self.stream:
            # Stored exactly as returned by the resource (presumably an async
            # generator when streaming; not verifiable from this module).
            return {
                **data,
                self.response_field: response,
                "is_streaming": True,
            }

        return {
            **data,
            self.response_field: response,
            "tokens_used": self._total_tokens(response),
        }

175 

176 

class ResponseValidator(IValidationFunction):
    """Validate LLM responses."""

    def __init__(
        self,
        response_field: str = "llm_response",
        format: str | None = None,  # "json", "markdown", etc.
        schema: Dict[str, Any] | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        required_fields: List[str] | None = None,
    ):
        """Initialize the response validator.

        Args:
            response_field: Field containing LLM response.
            format: Expected response format. Only "json" triggers extra checks.
            schema: Pydantic field definitions, passed as keyword arguments
                to ``pydantic.create_model``. Only used when ``format`` is
                "json".
            min_length: Minimum response text length (None = no minimum).
            max_length: Maximum response text length (None = no maximum).
            required_fields: Keys that must be present in the parsed JSON
                object. Only checked when ``format`` is "json".
        """
        self.response_field = response_field
        self.format = format
        self.schema = schema
        self.min_length = min_length
        self.max_length = max_length
        self.required_fields = required_fields or []

    def validate(self, data: Dict[str, Any]) -> bool:
        """Validate LLM response.

        Args:
            data: Data containing LLM response.

        Returns:
            True if valid.

        Raises:
            ValidationError: If validation fails.
        """
        response = data.get(self.response_field)
        if response is None:
            raise ValidationError(f"Response field '{self.response_field}' not found")

        # Extract text from response object if needed
        if isinstance(response, dict):
            text = response.get("text", response.get("content", str(response)))
        else:
            text = str(response)
        if not isinstance(text, str):
            # e.g. {"text": None}: len() / json.loads() would raise TypeError.
            text = str(text)

        # ``is not None`` so that a limit of 0 is honored rather than ignored.
        if self.min_length is not None and len(text) < self.min_length:
            raise ValidationError(
                f"Response too short: {len(text)} < {self.min_length}"
            )

        if self.max_length is not None and len(text) > self.max_length:
            raise ValidationError(
                f"Response too long: {len(text)} > {self.max_length}"
            )

        if self.format == "json":
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError as e:
                raise ValidationError(f"Invalid JSON response: {e}") from e
            self._validate_json(parsed)

        return True

    def _validate_json(self, parsed: Any) -> None:
        """Apply the schema and required-field checks to a parsed JSON value."""
        if not (self.schema or self.required_fields):
            return
        if not isinstance(parsed, dict):
            raise ValidationError(
                f"Expected a JSON object, got {type(parsed).__name__}"
            )

        if self.schema:
            # Alias pydantic's error class. A plain
            # ``from pydantic import ValidationError`` would make
            # ``ValidationError`` local to this whole function and shadow
            # the FSM exception every other raise here relies on.
            from pydantic import ValidationError as PydanticValidationError
            from pydantic import create_model

            model = create_model("ResponseSchema", **self.schema)
            try:
                model(**parsed)
            except PydanticValidationError as e:
                raise ValidationError(f"Schema validation failed: {e}") from e

        for field in self.required_fields:
            if field not in parsed:
                raise ValidationError(f"Required field missing: {field}")

262 

263 

class FunctionCaller(ITransformFunction):
    """Call functions/tools based on LLM output."""

    def __init__(
        self,
        response_field: str = "llm_response",
        function_registry: Dict[str, Callable] | None = None,
        result_field: str = "function_result",
    ):
        """Initialize the function caller.

        Args:
            response_field: Field containing LLM response with function call.
                The response is expected to look like
                ``{"function": name, "arguments": {...}}``, either as a mapping
                or as a JSON string.
            function_registry: Mapping of function name to callable. Sync and
                async callables are both supported.
            result_field: Field to store function result.
        """
        self.response_field = response_field
        self.function_registry = function_registry or {}
        self.result_field = result_field

    @staticmethod
    def _parse_arguments(raw: Any) -> Any:
        """Normalize the ``arguments`` entry into keyword arguments.

        ``None`` means no arguments. A string is decoded as JSON, the form
        some tool-calling APIs use for arguments. Anything else is returned
        unchanged.

        Raises:
            TransformError: If a string argument payload is not valid JSON.
        """
        if raw is None:
            return {}
        if isinstance(raw, str):
            if not raw.strip():
                return {}
            try:
                return json.loads(raw)
            except json.JSONDecodeError as e:
                raise TransformError(f"Invalid function arguments: {e}") from e
        return raw

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by calling function from LLM response.

        Args:
            data: Input data containing LLM response with function call.

        Returns:
            ``data`` unchanged when the response names no function. Otherwise
            a copy with the result under ``result_field`` and the name under
            ``function_called``.

        Raises:
            TransformError: If the named function is not registered, its
                arguments are malformed, or the call raises.
        """
        response = data.get(self.response_field)
        if not response:
            return data

        # Parse function call from response
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                # Not a JSON response, no function to call
                return data

        # JSON arrays/scalars/null (or other non-mapping objects) cannot
        # describe a function call.
        if not hasattr(response, "get"):
            return data

        function_name = response.get("function")
        if not function_name:
            return data

        if function_name not in self.function_registry:
            raise TransformError(f"Function not found: {function_name}")

        func = self.function_registry[function_name]
        function_args = self._parse_arguments(response.get("arguments"))

        try:
            result = func(**function_args)
            # Covers ``async def`` functions as well as sync callables that
            # return awaitables (partials, wrappers, ...).
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            raise TransformError(f"Function call failed: {e}") from e

        return {
            **data,
            self.result_field: result,
            "function_called": function_name,
        }

333 

334 

class ConversationManager(ITransformFunction):
    """Manage conversation history for multi-turn interactions."""

    def __init__(
        self,
        max_history: int = 10,
        history_field: str = "conversation_history",
        role_field: str = "role",
        content_field: str = "content",
        prompt_field: str = "prompt",
        response_field: str = "llm_response",
    ):
        """Initialize the conversation manager.

        Args:
            max_history: Maximum number of messages to keep (the most recent
                ones). A value <= 0 keeps no messages.
            history_field: Field to store conversation history.
            role_field: Key used for the message role in each history entry.
            content_field: Key used for the message content in each entry.
            prompt_field: Field read as the new user message.
            response_field: Field read as the new assistant message (matches
                ``LLMCaller``'s default ``response_field``).
        """
        self.max_history = max_history
        self.history_field = history_field
        self.role_field = role_field
        self.content_field = content_field
        self.prompt_field = prompt_field
        self.response_field = response_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by managing conversation history.

        Appends a "user" entry for ``prompt_field`` and an "assistant" entry
        for ``response_field`` when those fields are present. The history is
        then trimmed to the last ``max_history`` entries.

        Args:
            data: Input data with new message.

        Returns:
            A copy of ``data`` with the updated history. The input's history
            list is not modified.
        """
        # Copy, so the list still referenced by ``data`` is not mutated.
        history = list(data.get(self.history_field) or [])

        if self.prompt_field in data:
            history.append({
                self.role_field: "user",
                self.content_field: data[self.prompt_field],
            })

        if self.response_field in data:
            response = data[self.response_field]
            if isinstance(response, dict):
                content = response.get("text", response.get("content", str(response)))
            else:
                content = str(response)

            history.append({
                self.role_field: "assistant",
                self.content_field: content,
            })

        # ``history[-0:]`` would keep everything, so handle <= 0 explicitly.
        history = history[-self.max_history:] if self.max_history > 0 else []

        return {
            **data,
            self.history_field: history,
        }

397 } 

398 

399 

class EmbeddingGenerator(ITransformFunction):
    """Produce embedding vectors for text through an LLM resource."""

    def __init__(
        self,
        resource_name: str,
        text_field: str = "text",
        embedding_field: str = "embedding",
        model: str | None = None,
        batch_size: int = 100,
    ):
        """Configure the generator.

        Args:
            resource_name: Key of the LLM resource inside ``data["_resources"]``.
            text_field: Input field holding a single text or a list of texts.
            embedding_field: Output field that receives the embedding result.
            model: Embedding model forwarded to ``resource.embed``.
            batch_size: Number of texts per ``embed`` call when the input is
                a list.
        """
        self.resource_name = resource_name
        self.text_field = text_field
        self.embedding_field = embedding_field
        self.model = model
        self.batch_size = batch_size

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the configured text field and attach the result.

        A list input is embedded in slices of ``batch_size`` and the per-slice
        results are concatenated. Any other non-empty value is passed to
        ``embed`` in one call. Empty or missing input returns ``data``
        unchanged.

        Raises:
            TransformError: If the resource is missing or embedding fails.
        """
        resource = data.get("_resources", {}).get(self.resource_name)
        if not (resource and isinstance(resource, LLMResource)):
            raise TransformError(f"LLM resource '{self.resource_name}' not found")

        payload = data.get(self.text_field)
        if not payload:
            return data

        try:
            if not isinstance(payload, list):
                vectors = await resource.embed(payload, model=self.model)
            else:
                vectors = []
                for start in range(0, len(payload), self.batch_size):
                    chunk = payload[start:start + self.batch_size]
                    vectors.extend(await resource.embed(chunk, model=self.model))
            return {**data, self.embedding_field: vectors}
        except Exception as e:
            raise TransformError(f"Embedding generation failed: {e}") from e

465 

466 

467# Convenience functions for creating LLM functions 

def build_prompt(template: str, **kwargs) -> PromptBuilder:
    """Shorthand for ``PromptBuilder(template, **kwargs)``.

    Extra keyword arguments are forwarded to the constructor unchanged.
    """
    builder = PromptBuilder(template, **kwargs)
    return builder

471 

472 

def call_llm(resource: str, **kwargs) -> LLMCaller:
    """Shorthand for ``LLMCaller(resource, **kwargs)``.

    *resource* becomes the caller's ``resource_name``. Extra keyword arguments
    are forwarded unchanged.
    """
    caller = LLMCaller(resource, **kwargs)
    return caller

476 

477 

def validate_response(**kwargs) -> ResponseValidator:
    """Shorthand for ``ResponseValidator(**kwargs)``; keywords pass through."""
    validator = ResponseValidator(**kwargs)
    return validator

481 

482 

def call_function(**kwargs) -> FunctionCaller:
    """Shorthand for ``FunctionCaller(**kwargs)``; keywords pass through."""
    caller = FunctionCaller(**kwargs)
    return caller

486 

487 

def manage_conversation(**kwargs) -> ConversationManager:
    """Shorthand for ``ConversationManager(**kwargs)``; keywords pass through."""
    manager = ConversationManager(**kwargs)
    return manager

491 

492 

def generate_embeddings(resource: str, **kwargs) -> EmbeddingGenerator:
    """Shorthand for ``EmbeddingGenerator(resource, **kwargs)``.

    *resource* becomes the generator's ``resource_name``. Extra keyword
    arguments are forwarded unchanged.
    """
    generator = EmbeddingGenerator(resource, **kwargs)
    return generator