Coverage for src/dataknobs_llm/conversations/middleware.py: 0%
94 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 16:04 -0600
1"""Middleware system for conversation processing.
3This module provides middleware capabilities for processing messages before they
4are sent to the LLM and processing responses after they come back from the LLM.
5Middleware can be used for logging, validation, content filtering, and more.
7Example:
8 >>> from dataknobs_llm.conversations import (
9 ... ConversationManager,
10 ... LoggingMiddleware,
11 ... ValidationMiddleware
12 ... )
13 >>> import logging
14 >>>
15 >>> # Create middleware instances
16 >>> logger = logging.getLogger(__name__)
17 >>> logging_mw = LoggingMiddleware(logger)
18 >>> validation_mw = ValidationMiddleware(
19 ... prompt_builder=builder,
20 ... validation_prompt="validate_response"
21 ... )
22 >>>
23 >>> # Create conversation with middleware
24 >>> manager = await ConversationManager.create(
25 ... llm=llm,
26 ... prompt_builder=builder,
27 ... storage=storage,
28 ... middleware=[logging_mw, validation_mw]
29 ... )
30"""
import logging
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

from dataknobs_llm.llm import LLMMessage, LLMResponse
from dataknobs_llm.conversations.storage import ConversationState
class ConversationMiddleware(ABC):
    """Abstract base for hooks that wrap each LLM round-trip of a conversation.

    A middleware gets two chances to act: once on the outgoing message list
    before the LLM is called, and once on the LLM's response afterwards.
    By this module's design, request hooks run in registration order and
    response hooks run in reverse order, so a stack of middleware nests like
    the layers of an onion.

    Subclasses must implement both hooks. A hook with nothing to do simply
    returns its input unchanged.

    Example:
        >>> class CustomMiddleware(ConversationMiddleware):
        ...     async def process_request(self, messages, state):
        ...         # Inspect or rewrite messages before the LLM call
        ...         return messages
        ...
        ...     async def process_response(self, response, state):
        ...         # Inspect or rewrite the response after the LLM call
        ...         return response
    """

    @abstractmethod
    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Transform the outgoing messages before they reach the LLM.

        Args:
            messages: The messages about to be sent to the LLM.
            state: The conversation's current state.

        Returns:
            The message list to send instead. Implementations may modify,
            add, or drop messages.

        Example:
            >>> from datetime import datetime
            >>> async def process_request(self, messages, state):
            ...     # Stamp every message with the time it was sent
            ...     for msg in messages:
            ...         if not msg.metadata:
            ...             msg.metadata = {}
            ...         msg.metadata["timestamp"] = datetime.now().isoformat()
            ...     return messages
        """
        ...

    @abstractmethod
    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Transform the LLM's response before it is handed back.

        Args:
            response: The response produced by the LLM.
            state: The conversation's current state.

        Returns:
            The response to pass on. Implementations may change its content,
            metadata, or other fields.

        Example:
            >>> from datetime import datetime
            >>> async def process_response(self, response, state):
            ...     # Record when the response was post-processed
            ...     if not response.metadata:
            ...         response.metadata = {}
            ...     response.metadata["processed_at"] = datetime.now().isoformat()
            ...     return response
        """
        ...
class LoggingMiddleware(ConversationMiddleware):
    """Middleware that logs every request and response passing through it.

    Useful for debugging and monitoring conversations. At INFO level it logs
    the conversation ID with the number of outgoing messages, and, for
    responses, the content length, model and finish reason. At DEBUG level it
    additionally logs the outgoing message roles and the token usage.

    Messages are passed as lazy %-style arguments, so no string formatting
    happens unless a handler actually emits the record.

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> logging.basicConfig(level=logging.INFO)
        >>>
        >>> middleware = LoggingMiddleware(logger)
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     middleware=[middleware]
        ... )
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Initialize logging middleware.

        Args:
            logger: Logger instance to use (defaults to this module's logger).
        """
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Log request details before sending to the LLM.

        Args:
            messages: Messages about to be sent; returned unchanged.
            state: Current conversation state (only ``conversation_id`` is read).

        Returns:
            The same ``messages`` list, unmodified.
        """
        self.logger.info(
            "Conversation %s - Sending %d messages to LLM",
            state.conversation_id,
            len(messages),
        )
        # Building the role list walks every message; skip that work
        # entirely when DEBUG records would be discarded anyway.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(
                "Conversation %s - Message roles: %s",
                state.conversation_id,
                [msg.role for msg in messages],
            )
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Log response details after receiving them from the LLM.

        Args:
            response: The LLM response; returned unchanged.
            state: Current conversation state (only ``conversation_id`` is read).

        Returns:
            The same ``response`` object, unmodified.
        """
        # ``content`` may be None/empty; report 0 chars in that case.
        content_length = len(response.content) if response.content else 0
        self.logger.info(
            "Conversation %s - Received response: %d chars, "
            "model=%s, finish_reason=%s",
            state.conversation_id,
            content_length,
            response.model,
            response.finish_reason,
        )
        if response.usage:
            self.logger.debug(
                "Conversation %s - Token usage: %s",
                state.conversation_id,
                response.usage,
            )
        return response
class ContentFilterMiddleware(ConversationMiddleware):
    """Middleware that redacts configured words or phrases from LLM responses.

    This middleware replaces every occurrence of the configured words or
    phrases in LLM responses with a replacement string. It is useful for
    content moderation and compliance. Requests are passed through untouched.

    All filter words are matched in a single pass, longest first. This means:

    * a longer phrase is redacted whole even when a shorter filter word is a
      prefix of it (e.g. ``["bad", "badword"]`` redacts all of "badword");
    * the replacement text is inserted literally and never re-scanned by
      other filter words;
    * empty strings in ``filter_words`` are ignored. They would otherwise
      match between every character.

    Example:
        >>> # Filter specific words
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["badword1", "badword2"],
        ...     replacement="[FILTERED]"
        ... )
        >>>
        >>> # Case-insensitive filtering
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["sensitive"],
        ...     case_sensitive=False
        ... )
    """

    def __init__(
        self,
        filter_words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True
    ):
        """Initialize content filter middleware.

        Args:
            filter_words: List of words/phrases to filter.
            replacement: String to replace filtered content with (inserted
                literally; backslashes are not interpreted).
            case_sensitive: Whether filtering should be case-sensitive.
        """
        self.filter_words = filter_words
        self.replacement = replacement
        self.case_sensitive = case_sensitive

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without filtering."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Redact filter words from the response content.

        When anything was redacted, ``response.content`` is replaced and
        ``response.metadata["content_filtered"]`` is set to True.

        Args:
            response: The LLM response to filter (mutated in place).
            state: Current conversation state (unused).

        Returns:
            The same ``response`` object.
        """
        content = response.content
        if not content:
            # None or empty content: nothing to filter.
            return response

        pattern = self._build_pattern()
        if pattern is None:
            # No non-empty filter words configured.
            return response

        replacement = self.replacement
        # A callable replacement is inserted verbatim. Passing the string
        # directly to ``sub`` would treat "\1", "\g<..>", "\n" etc. as
        # template escapes.
        filtered = pattern.sub(lambda _match: replacement, content)

        # Track if any filtering occurred
        if filtered != content:
            if not response.metadata:
                response.metadata = {}
            response.metadata["content_filtered"] = True
            response.content = filtered

        return response

    def _build_pattern(self) -> Optional["re.Pattern[str]"]:
        """Compile the filter words into one alternation regex, longest first.

        ``filter_words`` is re-read on every call, so changes to it after
        construction take effect on the next response.

        Returns:
            The compiled pattern, or None if there are no non-empty words.
        """
        words = sorted(
            {word for word in (self.filter_words or ()) if word},
            key=len,
            reverse=True,
        )
        if not words:
            return None
        flags = 0 if self.case_sensitive else re.IGNORECASE
        return re.compile("|".join(re.escape(word) for word in words), flags)
class ValidationMiddleware(ConversationMiddleware):
    """Middleware that validates LLM responses using a second LLM call.

    For every response, a validation prompt is rendered with the response
    text and sent to ``llm``. The validator's reply is expected to contain
    the verdict word ``VALID`` or ``INVALID``. On failure, the response
    metadata is annotated with ``validation_failed`` and
    ``validation_response``. Then either a ``ValueError`` is raised
    (``auto_retry=False``) or ``retry_requested`` is set in the metadata
    (``auto_retry=True``). This class does not retry by itself; any retry
    has to be driven by the caller.

    Example:
        >>> # Create validation middleware
        >>> validation_llm = OpenAIProvider(config)
        >>> middleware = ValidationMiddleware(
        ...     llm=validation_llm,
        ...     prompt_builder=builder,
        ...     validation_prompt="validate_response",
        ...     auto_retry=False  # Raise error instead of retrying
        ... )
        >>>
        >>> # Validation prompt should ask the LLM to respond with
        >>> # "VALID" or "INVALID" based on the response content
    """

    # Whole-word, case-insensitive verdict markers. Word boundaries are
    # essential: "INVALID" contains "VALID" as a substring.
    _INVALID_MARKER = re.compile(r"\bINVALID\b", re.IGNORECASE)
    _VALID_MARKER = re.compile(r"\bVALID\b", re.IGNORECASE)

    def __init__(
        self,
        llm: "AsyncLLMProvider",
        prompt_builder: "AsyncPromptBuilder",
        validation_prompt: str,
        auto_retry: bool = False,
        retry_limit: int = 3
    ):
        """Initialize validation middleware.

        Args:
            llm: LLM provider to use for validation (required).
            prompt_builder: Prompt builder for rendering the validation prompt.
            validation_prompt: Name of the validation prompt template.
            auto_retry: If True, mark failed responses with
                ``retry_requested`` instead of raising.
            retry_limit: Stored for callers; not consulted by this class.
        """
        # Imported locally rather than at module top — presumably to avoid
        # an import cycle; TODO confirm.
        from dataknobs_llm.prompts import AsyncPromptBuilder
        from dataknobs_llm.llm import AsyncLLMProvider

        self.llm: AsyncLLMProvider = llm
        self.builder: AsyncPromptBuilder = prompt_builder
        self.validation_prompt = validation_prompt
        self.auto_retry = auto_retry
        self.retry_limit = retry_limit

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without validation."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Validate the response by asking the validation LLM for a verdict.

        Args:
            response: The LLM response to validate (metadata may be mutated).
            state: Current conversation state (unused).

        Returns:
            The same ``response`` object.

        Raises:
            ValueError: If validation fails and ``auto_retry`` is False.
        """
        # Render validation prompt with response content
        validation_prompt_result = await self.builder.render_user_prompt(
            self.validation_prompt,
            index=0,
            params={"response": response.content},
            include_rag=False  # Don't need RAG for validation
        )

        # Create message and call LLM to get validation judgment
        validation_message = LLMMessage(
            role="user",
            content=validation_prompt_result.content
        )
        validation_response = await self.llm.complete([validation_message])

        # Check if LLM says response is valid
        is_valid = self._check_validity(validation_response.content)

        if not is_valid:
            # Track validation failure
            if not response.metadata:
                response.metadata = {}
            response.metadata["validation_failed"] = True
            response.metadata["validation_response"] = validation_response.content

            if self.auto_retry:
                # Note: Actual retry logic would need to be implemented
                # at the ConversationManager level. This just marks the failure.
                response.metadata["retry_requested"] = True
            else:
                raise ValueError(
                    f"Response failed validation: {validation_response.content}"
                )

        return response

    def _check_validity(self, validation_response: Optional[str]) -> bool:
        """Decide whether the validator's reply signals success.

        The reply counts as valid only if it contains the standalone word
        ``VALID`` (any case) and nowhere contains the word ``INVALID``. This
        fails closed when both appear. Empty or None replies are invalid.
        Override this method to match a different validation prompt design.

        Args:
            validation_response: Content of the validation LLM's reply.

        Returns:
            True if valid, False otherwise.
        """
        if not validation_response:
            return False
        if self._INVALID_MARKER.search(validation_response):
            return False
        return self._VALID_MARKER.search(validation_response) is not None
class MetadataMiddleware(ConversationMiddleware):
    """Middleware that adds custom metadata to messages and responses.

    Metadata can be injected into every outgoing message and into each
    response, which is useful for tracking, analytics, and debugging. Each
    side takes a static dict and/or a zero-argument callable whose result is
    merged on top of the static dict, so dynamic values win on key clashes.
    The callable is invoked once per request/response.

    Example:
        >>> # Add environment info to all messages
        >>> middleware = MetadataMiddleware(
        ...     request_metadata={"environment": "production"},
        ...     response_metadata={"version": "1.0.0"}
        ... )
        >>>
        >>> # Add dynamic metadata via callback
        >>> def get_request_meta():
        ...     return {"timestamp": datetime.now().isoformat()}
        >>>
        >>> middleware = MetadataMiddleware(
        ...     request_metadata_fn=get_request_meta
        ... )
    """

    def __init__(
        self,
        request_metadata: Optional[Dict[str, Any]] = None,
        response_metadata: Optional[Dict[str, Any]] = None,
        request_metadata_fn: Optional[Callable[[], Optional[Dict[str, Any]]]] = None,
        response_metadata_fn: Optional[Callable[[], Optional[Dict[str, Any]]]] = None
    ):
        """Initialize metadata middleware.

        Args:
            request_metadata: Static metadata to add to requests.
            response_metadata: Static metadata to add to responses.
            request_metadata_fn: Zero-argument callable returning extra
                request metadata (None/empty adds nothing).
            response_metadata_fn: Zero-argument callable returning extra
                response metadata (None/empty adds nothing).
        """
        self.request_metadata = request_metadata or {}
        self.response_metadata = response_metadata or {}
        self.request_metadata_fn = request_metadata_fn
        self.response_metadata_fn = response_metadata_fn

    @staticmethod
    def _collect_metadata(
        static: Dict[str, Any],
        fn: Optional[Callable[[], Optional[Dict[str, Any]]]]
    ) -> Dict[str, Any]:
        """Return a fresh dict of ``static`` overlaid with ``fn()``'s result."""
        merged = dict(static)
        if fn is not None:
            # Treat a None return as "nothing to add" instead of letting
            # dict.update(None) raise TypeError.
            merged.update(fn() or {})
        return merged

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Merge the request metadata into every outgoing message.

        Args:
            messages: Messages to annotate (mutated in place).
            state: Current conversation state (unused).

        Returns:
            The same ``messages`` list.
        """
        metadata_to_add = self._collect_metadata(
            self.request_metadata, self.request_metadata_fn
        )

        if metadata_to_add:
            for msg in messages:
                if not msg.metadata:
                    msg.metadata = {}
                msg.metadata.update(metadata_to_add)

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Merge the response metadata into the LLM response.

        Args:
            response: Response to annotate (mutated in place).
            state: Current conversation state (unused).

        Returns:
            The same ``response`` object.
        """
        metadata_to_add = self._collect_metadata(
            self.response_metadata, self.response_metadata_fn
        )

        if metadata_to_add:
            if not response.metadata:
                response.metadata = {}
            response.metadata.update(metadata_to_add)

        return response