Coverage for src/dataknobs_llm/conversations/middleware.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-31 16:04 -0600

"""Middleware system for conversation processing.

This module provides middleware for processing messages before they are sent
to the LLM and for processing responses after they come back from the LLM.
Middleware can be used for logging, validation, content filtering, and more.

Example:
    >>> from dataknobs_llm.conversations import (
    ...     ConversationManager,
    ...     LoggingMiddleware,
    ...     ValidationMiddleware
    ... )
    >>> import logging
    >>>
    >>> # Create middleware instances
    >>> logger = logging.getLogger(__name__)
    >>> logging_mw = LoggingMiddleware(logger)
    >>> validation_mw = ValidationMiddleware(
    ...     llm=llm,
    ...     prompt_builder=builder,
    ...     validation_prompt="validate_response"
    ... )
    >>>
    >>> # Create conversation with middleware
    >>> manager = await ConversationManager.create(
    ...     llm=llm,
    ...     prompt_builder=builder,
    ...     storage=storage,
    ...     middleware=[logging_mw, validation_mw]
    ... )
"""

31 

32from abc import ABC, abstractmethod 

33from typing import List, Optional, Any, Dict 

34import logging 

35 

36from dataknobs_llm.llm import LLMMessage, LLMResponse 

37from dataknobs_llm.conversations.storage import ConversationState 

38 

39 

class ConversationMiddleware(ABC):
    """Abstract interface for hooks that wrap a call to the LLM.

    A middleware exposes two coroutines: one that sees the outgoing
    messages before the LLM is called and one that sees the response
    afterwards. Both must be implemented by concrete subclasses.

    By design, a stack of middleware is applied in list order for
    requests and in reverse order for responses (like the layers of an
    onion). That ordering is applied by whoever drives the middleware,
    not by this class.

    Example:
        >>> class CustomMiddleware(ConversationMiddleware):
        ...     async def process_request(self, messages, state):
        ...         # Custom processing before the LLM call
        ...         return messages
        ...
        ...     async def process_response(self, response, state):
        ...         # Custom processing after the LLM call
        ...         return response
    """

    @abstractmethod
    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState,
    ) -> List[LLMMessage]:
        """Transform the outgoing messages before they reach the LLM.

        Args:
            messages: Messages about to be sent to the LLM.
            state: Current conversation state.

        Returns:
            The messages to send onward; implementations may modify,
            add, or remove entries.

        Example:
            >>> async def process_request(self, messages, state):
            ...     # Stamp every message with the current time
            ...     for msg in messages:
            ...         if not msg.metadata:
            ...             msg.metadata = {}
            ...         msg.metadata["timestamp"] = datetime.now().isoformat()
            ...     return messages
        """

    @abstractmethod
    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState,
    ) -> LLMResponse:
        """Transform the LLM response before it is handed back.

        Args:
            response: Response returned by the LLM.
            state: Current conversation state.

        Returns:
            The response to pass onward; implementations may modify its
            content, metadata, etc.

        Example:
            >>> async def process_response(self, response, state):
            ...     # Record when the response was processed
            ...     if not response.metadata:
            ...         response.metadata = {}
            ...     response.metadata["processed_at"] = datetime.now().isoformat()
            ...     return response
        """

108 

109 

class LoggingMiddleware(ConversationMiddleware):
    """Pass-through middleware that logs LLM traffic.

    Requests and responses are returned unchanged. At INFO level it
    records the conversation ID together with the message count (for
    requests) or the content length, model and finish reason (for
    responses). At DEBUG level it additionally records message roles and
    token usage.

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> logging.basicConfig(level=logging.INFO)
        >>>
        >>> middleware = LoggingMiddleware(logger)
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     middleware=[middleware]
        ... )
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Store the logger to write to.

        Args:
            logger: Logger to use; when omitted, the logger named after
                this module is used.
        """
        self.logger = logger if logger else logging.getLogger(__name__)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState,
    ) -> List[LLMMessage]:
        """Log the outgoing message count and roles, then pass through."""
        summary = (
            f"Conversation {state.conversation_id} - "
            f"Sending {len(messages)} messages to LLM"
        )
        self.logger.info(summary)

        roles_line = (
            f"Conversation {state.conversation_id} - "
            f"Message roles: {[msg.role for msg in messages]}"
        )
        self.logger.debug(roles_line)

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState,
    ) -> LLMResponse:
        """Log the response size, model and usage, then pass through."""
        # Missing/empty content is reported as zero characters.
        if response.content:
            content_length = len(response.content)
        else:
            content_length = 0

        summary = (
            f"Conversation {state.conversation_id} - "
            f"Received response: {content_length} chars, "
            f"model={response.model}, finish_reason={response.finish_reason}"
        )
        self.logger.info(summary)

        # Token usage is only reported when the response carries it.
        if not response.usage:
            return response

        usage_line = (
            f"Conversation {state.conversation_id} - "
            f"Token usage: {response.usage}"
        )
        self.logger.debug(usage_line)
        return response

172 

173 

class ContentFilterMiddleware(ConversationMiddleware):
    """Middleware that redacts configured words/phrases from responses.

    Every occurrence of each entry in ``filter_words`` found in the
    response content is replaced by ``replacement``. Requests are passed
    through untouched. When at least one replacement happened, the
    response metadata gets ``"content_filtered": True``.

    Matching is plain substring matching (no word boundaries); the
    replacement text is always inserted literally, in both case-sensitive
    and case-insensitive mode.

    Example:
        >>> # Filter specific words
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["badword1", "badword2"],
        ...     replacement="[FILTERED]"
        ... )
        >>>
        >>> # Case-insensitive filtering
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["sensitive"],
        ...     case_sensitive=False
        ... )
    """

    def __init__(
        self,
        filter_words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True
    ):
        """Initialize content filter middleware.

        Args:
            filter_words: List of words/phrases to filter
            replacement: String to replace filtered content with
            case_sensitive: Whether filtering should be case-sensitive
        """
        self.filter_words = filter_words
        self.replacement = replacement
        self.case_sensitive = case_sensitive

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without filtering."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Replace every filtered word in the response content.

        Args:
            response: LLM response whose ``content`` is filtered in place.
            state: Current conversation state (not used by this filter).

        Returns:
            The same response object. If anything was replaced, its
            ``content`` is updated and ``metadata["content_filtered"]``
            is set to ``True``.
        """
        # Local import: the module does not import ``re`` at file level.
        import re

        original = response.content
        # Nothing to filter when the response has no text (None or "").
        if not original:
            return response

        replacement = self.replacement
        content = original

        for word in self.filter_words:
            # An empty needle would match between every character and
            # flood the text with the replacement; skip it.
            if not word:
                continue

            if self.case_sensitive:
                content = content.replace(word, replacement)
            else:
                # A callable replacement keeps the text literal. Passing
                # the string directly would make ``re`` interpret
                # backslashes and group references in it, unlike the
                # case-sensitive branch above.
                content = re.sub(
                    re.escape(word),
                    lambda _match: replacement,
                    content,
                    flags=re.IGNORECASE,
                )

        # Track if any filtering occurred
        if content != original:
            if not response.metadata:
                response.metadata = {}
            response.metadata["content_filtered"] = True
            response.content = content

        return response

244 

245 

class ValidationMiddleware(ConversationMiddleware):
    """Middleware that validates LLM responses using another LLM call.

    The response content is rendered into a validation prompt, sent to a
    separate LLM, and the validator's reply is inspected for a verdict.
    On failure the response metadata is annotated and, unless
    ``auto_retry`` is set, a ``ValueError`` is raised.

    Example:
        >>> # Create validation middleware
        >>> validation_llm = OpenAIProvider(config)
        >>> middleware = ValidationMiddleware(
        ...     llm=validation_llm,
        ...     prompt_builder=builder,
        ...     validation_prompt="validate_response",
        ...     auto_retry=False  # Raise error instead of retrying
        ... )
        >>>
        >>> # Validation prompt should ask the LLM to respond with
        >>> # "VALID" or "INVALID" based on the response content
    """

    def __init__(
        self,
        llm: "AsyncLLMProvider",
        prompt_builder: "AsyncPromptBuilder",
        validation_prompt: str,
        auto_retry: bool = False,
        retry_limit: int = 3
    ):
        """Initialize validation middleware.

        Args:
            llm: LLM provider to use for validation (required)
            prompt_builder: Prompt builder for rendering validation prompt
            validation_prompt: Name of validation prompt template
            auto_retry: When True, a failed validation only marks the
                response (``retry_requested``) instead of raising
            retry_limit: Maximum number of retries if auto_retry is True.
                Only stored here; this class does not perform retries.
        """
        # Imported locally, as in the original, rather than at module level.
        from dataknobs_llm.prompts import AsyncPromptBuilder
        from dataknobs_llm.llm import AsyncLLMProvider

        self.llm: AsyncLLMProvider = llm
        self.builder: AsyncPromptBuilder = prompt_builder
        self.validation_prompt = validation_prompt
        self.auto_retry = auto_retry
        self.retry_limit = retry_limit

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without validation."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Validate response by calling LLM with validation prompt.

        Args:
            response: LLM response to validate.
            state: Current conversation state (not used by this check).

        Returns:
            The same response object. On a failed validation with
            ``auto_retry`` enabled, its metadata carries
            ``validation_failed``, ``validation_response`` and
            ``retry_requested``.

        Raises:
            ValueError: If validation fails and ``auto_retry`` is False.
        """
        # Render validation prompt with response content
        validation_prompt_result = await self.builder.render_user_prompt(
            self.validation_prompt,
            index=0,
            params={"response": response.content},
            include_rag=False  # Don't need RAG for validation
        )

        # Create message and call LLM to get validation judgment
        validation_message = LLMMessage(
            role="user",
            content=validation_prompt_result.content
        )
        validation_response = await self.llm.complete([validation_message])
        verdict_text = validation_response.content

        if self._check_validity(verdict_text):
            return response

        # Track validation failure
        if not response.metadata:
            response.metadata = {}
        response.metadata["validation_failed"] = True
        response.metadata["validation_response"] = verdict_text

        if not self.auto_retry:
            raise ValueError(
                f"Response failed validation: {verdict_text}"
            )

        # Note: Actual retry logic would need to be implemented
        # at the ConversationManager level. This just marks the failure.
        response.metadata["retry_requested"] = True
        return response

    def _check_validity(self, validation_response: str) -> bool:
        """Check if validation response indicates success.

        The check is case-insensitive. Negative verdicts are tested
        first because the word "INVALID" itself contains "VALID"; a bare
        substring test for "VALID" would accept every rejection.

        Args:
            validation_response: Content from validation prompt response

        Returns:
            True if the text contains "VALID" and contains neither
            "INVALID" nor "NOT VALID"; False otherwise, including for
            empty or missing content.
        """
        # Local import: the module does not import ``re`` at file level.
        import re

        # No verdict at all is treated as a failed validation.
        if not validation_response:
            return False

        verdict = validation_response.upper()

        if "INVALID" in verdict or re.search(r"\bNOT\s+VALID", verdict):
            return False

        # This can be customized based on validation prompt design
        return "VALID" in verdict

354 

355 

class MetadataMiddleware(ConversationMiddleware):
    """Middleware that stamps metadata onto messages and responses.

    Two sources are combined for each direction: a static dict given at
    construction time and an optional zero-argument callback evaluated on
    every call. Callback values override static values for the same key.

    Example:
        >>> # Add environment info to all messages
        >>> middleware = MetadataMiddleware(
        ...     request_metadata={"environment": "production"},
        ...     response_metadata={"version": "1.0.0"}
        ... )
        >>>
        >>> # Add dynamic metadata via callback
        >>> def get_request_meta():
        ...     return {"timestamp": datetime.now().isoformat()}
        >>>
        >>> middleware = MetadataMiddleware(
        ...     request_metadata_fn=get_request_meta
        ... )
    """

    def __init__(
        self,
        request_metadata: Optional[Dict[str, Any]] = None,
        response_metadata: Optional[Dict[str, Any]] = None,
        request_metadata_fn: Optional[callable] = None,
        response_metadata_fn: Optional[callable] = None
    ):
        """Remember the static metadata and the optional callbacks.

        Args:
            request_metadata: Fixed metadata merged into every request
                message; defaults to an empty dict.
            response_metadata: Fixed metadata merged into every
                response; defaults to an empty dict.
            request_metadata_fn: Zero-argument callable returning extra
                metadata for requests.
            response_metadata_fn: Zero-argument callable returning extra
                metadata for responses.
        """
        self.request_metadata_fn = request_metadata_fn
        self.response_metadata_fn = response_metadata_fn
        self.request_metadata = request_metadata if request_metadata else {}
        self.response_metadata = response_metadata if response_metadata else {}

    @staticmethod
    def _combine(static_values, provider):
        """Return a fresh dict of static values overlaid with provider output."""
        combined = dict(static_values)
        if provider:
            combined.update(provider())
        return combined

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Merge the request metadata into every outgoing message."""
        extra = self._combine(self.request_metadata, self.request_metadata_fn)

        # Nothing to add: leave the messages exactly as they are.
        if not extra:
            return messages

        for message in messages:
            if not message.metadata:
                message.metadata = {}
            message.metadata.update(extra)

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Merge the response metadata into the LLM response."""
        extra = self._combine(self.response_metadata, self.response_metadata_fn)

        # Nothing to add: leave the response exactly as it is.
        if not extra:
            return response

        if not response.metadata:
            response.metadata = {}
        response.metadata.update(extra)

        return response