Coverage for src/dataknobs_llm/conversations/middleware.py: 61%
169 statements
coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

"""Middleware system for conversation processing.

This module provides middleware capabilities for processing messages before they
are sent to the LLM and processing responses after they come back from the LLM.
Middleware can be used for logging, validation, content filtering, rate limiting,
metadata injection, and more.

Execution Model (Onion Pattern):
    Middleware wraps around LLM calls in an "onion" pattern:

    Request Flow:  MW0 → MW1 → MW2 → LLM
    Response Flow: LLM → MW2 → MW1 → MW0

    Example with 3 middleware [Logging, RateLimit, Validation]:

    ```
    1. Logging.process_request()      # Log incoming messages
    2. RateLimit.process_request()    # Check rate limits
    3. Validation.process_request()   # Pass through (ValidationMiddleware checks responses only)
    4. → LLM Call →                   # Actual LLM API call
    5. Validation.process_response()  # Validate LLM response
    6. RateLimit.process_response()   # Add rate limit info to response
    7. Logging.process_response()     # Log response details
    ```

    This ensures middleware can:
    - Time the full LLM call (start timer in process_request, stop in process_response)
    - Wrap operations symmetrically (open resources → LLM → close resources)
    - See the final state after inner middleware modifications
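
The ordering above can be sketched as a small standalone pipeline runner. This is a minimal illustration under stated assumptions, not the `ConversationManager` implementation: `run_pipeline`, `Recorder`, and `fake_llm` are hypothetical names, and plain strings stand in for `LLMMessage`, `LLMResponse`, and `ConversationState`.

```python
import asyncio


class Recorder:
    # Hypothetical middleware that only records when its hooks run.
    def __init__(self, name, log):
        self.name = name
        self.log = log

    async def process_request(self, messages, state):
        self.log.append(f"{self.name}.request")
        return messages

    async def process_response(self, response, state):
        self.log.append(f"{self.name}.response")
        return response


async def fake_llm(messages, log):
    # Stand-in for the real LLM call.
    log.append("LLM")
    return f"reply to {len(messages)} message(s)"


async def run_pipeline(middleware, messages, state, log):
    # Requests flow through the list in order...
    for mw in middleware:
        messages = await mw.process_request(messages, state)
    response = await fake_llm(messages, log)
    # ...and responses flow back through it in reverse order.
    for mw in reversed(middleware):
        response = await mw.process_response(response, state)
    return response


log = []
stack = [Recorder(name, log) for name in ("Logging", "RateLimit", "Validation")]
result = asyncio.run(run_pipeline(stack, ["Hello"], state=None, log=log))
print(log)
```

The printed log lists the three request hooks in list order, then `LLM`, then the three response hooks in reverse order, which is exactly the seven-step sequence shown above.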

Performance Considerations:
    - **Middleware adds latency**: Each middleware's `process_request()` and
      `process_response()` adds to total response time. Keep middleware logic fast.

    - **Async is key**: All middleware methods are async. Use `await` with async
      libraries for I/O (DB calls, network requests); a synchronous call blocks
      the event loop for every conversation.

    - **Order matters**: Place middleware that can cheaply reject a request
      (like RateLimitMiddleware) early in the list, so a rejected request never
      reaches later middleware or the LLM. ValidationMiddleware makes its extra
      LLM call in `process_response()`, so it only costs anything after the main
      LLM call has actually happened.

    - **Memory usage**: RateLimitMiddleware keeps a per-key list of request
      timestamps in process memory; it is not shared across processes, and keys
      are only removed by `reset()`. For high-traffic applications, consider an
      external rate limiter (e.g. Redis).
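
To see why a cheap rejecting middleware belongs early in the list, here is a minimal sketch with hypothetical names (`Step`, `Rejected`, `run_requests`). It assumes, as with RateLimitMiddleware, that rejection is signalled by raising from `process_request()` and that the exception propagates to the caller. Middleware after the rejecting one never runs, and neither does the LLM call.

```python
import asyncio


class Rejected(Exception):
    # Hypothetical stand-in for an error such as RateLimitError.
    pass


class Step:
    # Hypothetical middleware: records its name and optionally rejects.
    def __init__(self, name, log, reject=False):
        self.name, self.log, self.reject = name, log, reject

    async def process_request(self, messages, state):
        self.log.append(self.name)
        if self.reject:
            raise Rejected(self.name)
        return messages


async def run_requests(middleware, messages, log):
    for mw in middleware:
        messages = await mw.process_request(messages, None)
    log.append("LLM")  # only reached if nothing rejected the request
    return messages


log = []
stack = [Step("Limiter", log, reject=True), Step("ExpensiveCheck", log)]
try:
    asyncio.run(run_requests(stack, ["Hello"], log))
except Rejected:
    pass
print(log)  # ['Limiter']

ok_log = []
asyncio.run(run_requests([Step("Limiter", ok_log), Step("ExpensiveCheck", ok_log)], ["Hello"], ok_log))
print(ok_log)  # ['Limiter', 'ExpensiveCheck', 'LLM']
```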

Available Middleware:
    - **LoggingMiddleware**: Log requests and responses for debugging
    - **ContentFilterMiddleware**: Filter inappropriate content from responses
    - **ValidationMiddleware**: Validate responses with an additional LLM call
    - **MetadataMiddleware**: Inject custom metadata into messages/responses
    - **RateLimitMiddleware**: Enforce rate limits with a sliding window

Example:
    ```python
    from dataknobs_llm.conversations import (
        ConversationManager,
        LoggingMiddleware,
        RateLimitMiddleware,
        ContentFilterMiddleware
    )
    import logging

    # Assumes llm, builder, and storage were created earlier and that this
    # code runs inside an async function.

    # Create middleware instances (order matters!)
    logger = logging.getLogger(__name__)
    logging_mw = LoggingMiddleware(logger)
    rate_limit_mw = RateLimitMiddleware(max_requests=10, window_seconds=60)
    filter_mw = ContentFilterMiddleware(
        filter_words=["inappropriate"],
        replacement="[FILTERED]"
    )

    # Create conversation with middleware stack
    # Execution: Logging → RateLimit → Filter → LLM → Filter → RateLimit → Logging
    manager = await ConversationManager.create(
        llm=llm,
        prompt_builder=builder,
        storage=storage,
        middleware=[logging_mw, rate_limit_mw, filter_mw]
    )

    # All requests will go through the middleware pipeline
    await manager.add_message(role="user", content="Hello")
    response = await manager.complete()  # Middleware applied automatically
    ```

See Also:
    ConversationManager: Uses middleware for all LLM interactions
    ConversationMiddleware: Base class for custom middleware
"""

from abc import ABC, abstractmethod
from typing import List, Any, Dict, Callable
import logging

from dataknobs_llm.llm import LLMMessage, LLMResponse
from dataknobs_llm.llm.providers import AsyncLLMProvider
from dataknobs_llm.conversations.storage import ConversationState
from dataknobs_llm.prompts import AsyncPromptBuilder


class ConversationMiddleware(ABC):
    """Base class for conversation middleware.

    Middleware can process requests before they reach the LLM and responses
    after they come back. Middleware is executed in order for requests, and
    in reverse order for responses (onion pattern).

    Execution Order:
        Given middleware list [MW0, MW1, MW2]:

        - **Request**: MW0 → MW1 → MW2 → LLM
        - **Response**: LLM → MW2 → MW1 → MW0

        This allows MW0 to:
        1. Start a timer in `process_request()`
        2. See the LLM call complete
        3. Stop the timer in `process_response()` and log total time

    Use Cases:
        - **Logging**: Track request/response details
        - **Validation**: Verify request/response content
        - **Transformation**: Modify messages or responses
        - **Rate Limiting**: Enforce API usage limits
        - **Caching**: Store/retrieve responses
        - **Monitoring**: Collect metrics and analytics
        - **Security**: Filter sensitive information

    Example:
        ```python
        from dataknobs_llm.conversations import (
            ConversationManager,
            ConversationMiddleware
        )
        import time

        class TimingMiddleware(ConversationMiddleware):
            '''Measure LLM call duration.'''

            async def process_request(self, messages, state):
                # Store start time in state metadata
                state.metadata["request_start"] = time.time()
                return messages

            async def process_response(self, response, state):
                # Calculate elapsed time
                start = state.metadata.get("request_start")
                if start is not None:
                    elapsed = time.time() - start
                    if not response.metadata:
                        response.metadata = {}
                    response.metadata["llm_duration_seconds"] = elapsed
                    print(f"LLM call took {elapsed:.2f}s")
                return response

        # Use in conversation (assumes llm, builder, and storage already exist)
        manager = await ConversationManager.create(
            llm=llm,
            prompt_builder=builder,
            storage=storage,
            middleware=[TimingMiddleware()]
        )
        ```

    Note:
        **Performance Tips**:

        - Keep `process_request()` and `process_response()` fast
        - Use async I/O (`await`) for external calls (DB, network)
        - Don't block the event loop with synchronous operations
        - For expensive operations, consider running them in background tasks
        - Store per-request state in `state.metadata`, not in instance
          variables: a single middleware instance may serve several
          conversations concurrently, so an instance variable can be
          overwritten between a request and its response
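
    The last tip matters because hooks for different conversations can
    interleave. A minimal sketch of the failure, using a hypothetical
    `FakeState` in place of `ConversationState` and hooks that take an
    explicit fake clock value `now` (an assumption made so the output is
    deterministic; the real hooks take no such argument):

    ```python
    import asyncio
    from dataclasses import dataclass, field


    @dataclass
    class FakeState:
        # Hypothetical stand-in for ConversationState.
        conversation_id: str
        metadata: dict = field(default_factory=dict)


    class SharedAttrTimer:
        # Anti-pattern: one attribute shared by every conversation.
        def __init__(self):
            self.start = None

        async def process_request(self, state, now):
            self.start = now

        async def process_response(self, state, now):
            return now - self.start


    class PerStateTimer:
        # Recommended pattern: keep per-request data on the state.
        async def process_request(self, state, now):
            state.metadata["request_start"] = now

        async def process_response(self, state, now):
            return now - state.metadata["request_start"]


    async def interleave(timer):
        # Two conversations overlap: A starts at t=0, B at t=5,
        # A finishes at t=10, B at t=12 (fake clock values).
        a, b = FakeState("A"), FakeState("B")
        await timer.process_request(a, now=0)
        await timer.process_request(b, now=5)
        elapsed_a = await timer.process_response(a, now=10)
        elapsed_b = await timer.process_response(b, now=12)
        return elapsed_a, elapsed_b


    shared = asyncio.run(interleave(SharedAttrTimer()))
    per_state = asyncio.run(interleave(PerStateTimer()))
    print(shared)     # (5, 7): A's duration is wrong
    print(per_state)  # (10, 7): both durations are correct
    ```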

    See Also:
        LoggingMiddleware: Example implementation
        ConversationManager.complete: Where middleware is executed
    """

    @abstractmethod
    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Process messages before sending them to the LLM.

        Args:
            messages: Messages to send to LLM
            state: Current conversation state

        Returns:
            Processed messages (can modify, add, or remove messages)

        Example:
            >>> from datetime import datetime
            >>> async def process_request(self, messages, state):
            ...     # Add timestamp to metadata
            ...     for msg in messages:
            ...         if not msg.metadata:
            ...             msg.metadata = {}
            ...         msg.metadata["timestamp"] = datetime.now().isoformat()
            ...     return messages
        """
        pass

    @abstractmethod
    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Process the response from the LLM.

        Args:
            response: LLM response
            state: Current conversation state

        Returns:
            Processed response (can modify content, metadata, etc.)

        Example:
            >>> from datetime import datetime
            >>> async def process_response(self, response, state):
            ...     # Add processing metadata
            ...     if not response.metadata:
            ...         response.metadata = {}
            ...     response.metadata["processed_at"] = datetime.now().isoformat()
            ...     return response
        """
        pass


class LoggingMiddleware(ConversationMiddleware):
    """Middleware that logs all requests and responses.

    This middleware is useful for debugging and monitoring conversations.
    At INFO level it logs the conversation ID, the number of messages sent,
    and the response length, model, and finish reason; at DEBUG level it
    also logs message roles and token usage.
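
    Because message roles and token usage are logged at DEBUG, a logger left
    at INFO (as in the example below) shows only the summary lines. A
    standard-library sketch of the difference; the logger name
    `conversation_demo` and the in-memory `ListHandler` are made up for
    illustration, and the two log calls copy the message format used by
    `process_request()`:

    ```python
    import logging


    class ListHandler(logging.Handler):
        # Collects "LEVEL message" strings in memory (for this sketch only).
        def __init__(self):
            super().__init__()
            self.lines = []

        def emit(self, record):
            self.lines.append(f"{record.levelname} {record.getMessage()}")


    logger = logging.getLogger("conversation_demo")  # arbitrary logger name
    logger.propagate = False  # keep the sketch's output away from the root logger
    handler = ListHandler()
    logger.addHandler(handler)

    results = {}
    for level in (logging.INFO, logging.DEBUG):
        handler.lines = []
        logger.setLevel(level)
        # Same message format as LoggingMiddleware.process_request()
        logger.info("Conversation c1 - Sending 2 messages to LLM")
        logger.debug("Conversation c1 - Message roles: ['system', 'user']")
        results[logging.getLevelName(level)] = handler.lines

    print(results["INFO"])   # only the INFO summary line
    print(results["DEBUG"])  # the summary line plus the message roles
    ```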

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> logging.basicConfig(level=logging.INFO)
        >>>
        >>> middleware = LoggingMiddleware(logger)
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     middleware=[middleware]
        ... )
    """

    def __init__(self, logger: logging.Logger | None = None):
        """Initialize logging middleware.

        Args:
            logger: Logger instance to use (defaults to module logger)
        """
        self.logger = logger or logging.getLogger(__name__)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Log request details before sending to LLM."""
        self.logger.info(
            f"Conversation {state.conversation_id} - "
            f"Sending {len(messages)} messages to LLM"
        )
        self.logger.debug(
            f"Conversation {state.conversation_id} - "
            f"Message roles: {[msg.role for msg in messages]}"
        )
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Log response details after receiving from LLM."""
        content_length = len(response.content) if response.content else 0
        self.logger.info(
            f"Conversation {state.conversation_id} - "
            f"Received response: {content_length} chars, "
            f"model={response.model}, finish_reason={response.finish_reason}"
        )
        if response.usage:
            self.logger.debug(
                f"Conversation {state.conversation_id} - "
                f"Token usage: {response.usage}"
            )
        return response


class ContentFilterMiddleware(ConversationMiddleware):
    """Middleware that filters inappropriate content from responses.

    This middleware can be used to redact or replace specific words or
    phrases in LLM responses; requests pass through unchanged. Matching is
    literal substring matching (not regular expressions), so a filter word
    also matches inside longer words. Useful for content moderation and
    compliance.
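
    The replacement loop can be tried on its own. This sketch follows the
    same steps as `process_response()` but is a hypothetical helper
    (`filter_text`), not part of this module. It shows that
    case-insensitive matching goes through `re.escape`, so characters such
    as "." match literally, and that matches are substrings, not whole words:

    ```python
    import re


    def filter_text(content, filter_words, replacement="[FILTERED]", case_sensitive=True):
        # Hypothetical helper following ContentFilterMiddleware's loop.
        for word in filter_words:
            if case_sensitive:
                content = content.replace(word, replacement)
            else:
                # re.escape makes characters such as "." or "+" match literally.
                pattern = re.compile(re.escape(word), re.IGNORECASE)
                content = pattern.sub(replacement, content)
        return content


    print(filter_text("Top SECRET and secretive", ["secret"]))
    # Top SECRET and [FILTERED]ive
    print(filter_text("Top SECRET and secretive", ["secret"], case_sensitive=False))
    # Top [FILTERED] and [FILTERED]ive
    print(filter_text("Costs 3.5 or 3x5", ["3.5"], case_sensitive=False))
    # Costs [FILTERED] or 3x5
    ```

    If whole-word matching is needed, the filter words must be chosen with
    that substring behavior in mind.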

    Example:
        >>> # Filter specific words
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["badword1", "badword2"],
        ...     replacement="[FILTERED]"
        ... )
        >>>
        >>> # Case-insensitive filtering
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["sensitive"],
        ...     case_sensitive=False
        ... )
    """

    def __init__(
        self,
        filter_words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True
    ):
        """Initialize content filter middleware.

        Args:
            filter_words: List of words/phrases to filter (matched as
                literal substrings)
            replacement: String to replace filtered content with
            case_sensitive: Whether filtering should be case-sensitive
        """
        self.filter_words = filter_words
        self.replacement = replacement
        self.case_sensitive = case_sensitive

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without filtering."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Filter inappropriate content from response."""
        import re

        content = response.content
        if not content:
            # Nothing to filter (empty or missing completion)
            return response

        for word in self.filter_words:
            if self.case_sensitive:
                content = content.replace(word, self.replacement)
            else:
                # Case-insensitive replacement; re.escape matches the word
                # literally, and a function replacement inserts the text
                # verbatim (no backslash-escape processing by re.sub)
                pattern = re.compile(re.escape(word), re.IGNORECASE)
                content = pattern.sub(lambda _match: self.replacement, content)

        # Track if any filtering occurred
        if content != response.content:
            if not response.metadata:
                response.metadata = {}
            response.metadata["content_filtered"] = True
            response.content = content

        return response


class ValidationMiddleware(ConversationMiddleware):
    """Middleware that validates LLM responses using another LLM call.

    This middleware renders a validation prompt with the response text and
    sends it to a separate LLM to check whether the response meets certain
    criteria. On failure it raises `ValueError` or, with `auto_retry=True`,
    marks the response metadata with `retry_requested` so the caller can
    retry.
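
    Because the word "INVALID" contains "VALID", a plain substring test for
    "VALID" would accept a rejection. A minimal sketch of a stricter verdict
    check, assuming the validation prompt asks for a reply containing exactly
    one of the two words (`parse_verdict` is a hypothetical helper):

    ```python
    def parse_verdict(text):
        # Hypothetical helper: True for a VALID verdict, False otherwise.
        normalized = (text or "").strip().upper()
        # Check the longer word first, since "INVALID" contains "VALID".
        if "INVALID" in normalized:
            return False
        return "VALID" in normalized


    naive = "VALID" in "INVALID: the answer cites no sources".upper()
    print(naive)  # True: the naive check accepts a rejection
    print(parse_verdict("INVALID: the answer cites no sources"))  # False
    print(parse_verdict("valid"))  # True
    print(parse_verdict(""))  # False
    ```

    Replies such as "not valid" would still be accepted, so the prompt
    should constrain the wording of the verdict.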

    Example:
        >>> from dataknobs_llm.llm.providers import OpenAIProvider
        >>> from dataknobs_llm.llm.base import LLMConfig
        >>>
        >>> # Create validation middleware
        >>> config = LLMConfig(provider="openai", model="gpt-4")
        >>> validation_llm = OpenAIProvider(config)
        >>> middleware = ValidationMiddleware(
        ...     llm=validation_llm,
        ...     prompt_builder=builder,
        ...     validation_prompt="validate_response",
        ...     auto_retry=False  # Raise ValueError on failure
        ... )
        >>>
        >>> # Validation prompt should ask the LLM to respond with
        >>> # "VALID" or "INVALID" based on the response content
    """

    def __init__(
        self,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        validation_prompt: str,
        auto_retry: bool = False,
        retry_limit: int = 3
    ):
        """Initialize validation middleware.

        Args:
            llm: LLM provider to use for validation (required)
            prompt_builder: Prompt builder for rendering validation prompt
            validation_prompt: Name of the validation prompt template; it
                receives the response text as the `response` parameter
            auto_retry: If True, mark failed responses with `retry_requested`
                in their metadata instead of raising `ValueError`. The retry
                itself must be performed by the caller.
            retry_limit: Maximum number of retries, stored for use by the
                caller; this middleware does not enforce it
        """
        self.llm: AsyncLLMProvider = llm
        self.builder: AsyncPromptBuilder = prompt_builder
        self.validation_prompt = validation_prompt
        self.auto_retry = auto_retry
        self.retry_limit = retry_limit

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without validation."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Validate response by calling LLM with validation prompt."""
        # Render validation prompt with response content
        validation_prompt_result = await self.builder.render_user_prompt(
            self.validation_prompt,
            index=0,
            params={"response": response.content},
            include_rag=False  # Don't need RAG for validation
        )

        # Create message and call LLM to get validation judgment
        validation_message = LLMMessage(
            role="user",
            content=validation_prompt_result.content
        )
        validation_response = await self.llm.complete([validation_message])

        # Check if LLM says response is valid
        is_valid = self._check_validity(validation_response.content)

        if not is_valid:
            # Track validation failure
            if not response.metadata:
                response.metadata = {}
            response.metadata["validation_failed"] = True
            response.metadata["validation_response"] = validation_response.content

            if self.auto_retry:
                # Note: Actual retry logic would need to be implemented
                # at the ConversationManager level. This just marks the failure.
                response.metadata["retry_requested"] = True
            else:
                raise ValueError(
                    f"Response failed validation: {validation_response.content}"
                )

        return response

    def _check_validity(self, validation_response: str) -> bool:
        """Check if validation response indicates success.

        Args:
            validation_response: Content from validation prompt response

        Returns:
            True if valid, False otherwise
        """
        # Simple implementation: look for "VALID" in the response, but check
        # for "INVALID" first, since it contains "VALID" as a substring.
        # This can be customized based on validation prompt design.
        normalized = (validation_response or "").upper()
        if "INVALID" in normalized:
            return False
        return "VALID" in normalized


class MetadataMiddleware(ConversationMiddleware):
    """Middleware that adds custom metadata to messages and responses.

    This middleware can inject metadata into both requests and responses,
    which is useful for tracking, analytics, and debugging. Static metadata
    is applied first and callback results second, so dynamic values override
    static ones that share a key.
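
    The merge order can be checked in isolation. This sketch uses a
    hypothetical helper, `collect_metadata`, that builds the dictionary the
    middleware merges into each message or response; it copies the static
    dict first, so repeated calls never mutate the configured metadata:

    ```python
    def collect_metadata(static, metadata_fn=None):
        # Hypothetical helper mirroring how MetadataMiddleware builds the
        # dict it merges into each message or response.
        metadata_to_add = dict(static)  # copy, so the static dict is not mutated
        if metadata_fn:
            # Dynamic values are applied second, so they win on key clashes.
            metadata_to_add.update(metadata_fn())
        return metadata_to_add


    static = {"environment": "production", "source": "static"}
    merged = collect_metadata(static, lambda: {"source": "dynamic"})
    print(merged)  # {'environment': 'production', 'source': 'dynamic'}
    print(static)  # {'environment': 'production', 'source': 'static'}
    ```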

    Example:
        >>> from datetime import datetime
        >>>
        >>> # Add environment info to requests and version info to responses
        >>> middleware = MetadataMiddleware(
        ...     request_metadata={"environment": "production"},
        ...     response_metadata={"version": "1.0.0"}
        ... )
        >>>
        >>> # Add dynamic metadata via callback
        >>> def get_request_meta():
        ...     return {"timestamp": datetime.now().isoformat()}
        >>>
        >>> middleware = MetadataMiddleware(
        ...     request_metadata_fn=get_request_meta
        ... )
    """

    def __init__(
        self,
        request_metadata: Dict[str, Any] | None = None,
        response_metadata: Dict[str, Any] | None = None,
        request_metadata_fn: Callable[[], Dict[str, Any]] | None = None,
        response_metadata_fn: Callable[[], Dict[str, Any]] | None = None
    ):
        """Initialize metadata middleware.

        Args:
            request_metadata: Static metadata to add to requests
            response_metadata: Static metadata to add to responses
            request_metadata_fn: Zero-argument callable, called on every
                request, that returns metadata for the request messages
            response_metadata_fn: Zero-argument callable, called on every
                response, that returns metadata for the response
        """
        self.request_metadata = request_metadata or {}
        self.response_metadata = response_metadata or {}
        self.request_metadata_fn = request_metadata_fn
        self.response_metadata_fn = response_metadata_fn

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Add metadata to request messages."""
        # Collect metadata to add
        metadata_to_add = dict(self.request_metadata)

        # Add dynamic metadata if function provided
        if self.request_metadata_fn:
            dynamic_metadata = self.request_metadata_fn()
            metadata_to_add.update(dynamic_metadata)

        # Add metadata to each message
        if metadata_to_add:
            for msg in messages:
                if not msg.metadata:
                    msg.metadata = {}
                msg.metadata.update(metadata_to_add)

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Add metadata to response."""
        # Collect metadata to add
        metadata_to_add = dict(self.response_metadata)

        # Add dynamic metadata if function provided
        if self.response_metadata_fn:
            dynamic_metadata = self.response_metadata_fn()
            metadata_to_add.update(dynamic_metadata)

        # Add metadata to response
        if metadata_to_add:
            if not response.metadata:
                response.metadata = {}
            response.metadata.update(metadata_to_add)

        return response


class RateLimitMiddleware(ConversationMiddleware):
    """Middleware that enforces rate limiting on LLM requests.

    This middleware tracks request rates per conversation or per client
    and raises `RateLimitError` when the rate limit is exceeded. Rate limits
    are tracked in memory using a sliding window log: each accepted request's
    timestamp is stored, and only timestamps newer than `window_seconds`
    count toward the limit.
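
    The sliding window log can be sketched on its own with an explicit clock.
    This is an illustrative reimplementation, not the class below:
    `SlidingWindowLog` and `allow` are hypothetical names, and it evicts from
    the left of a `deque`, which assumes timestamps arrive in non-decreasing
    order, whereas the middleware filters a plain list. It uses the same
    eviction rule as `_clean_old_requests()` (keep `ts > now - window_seconds`):

    ```python
    from collections import deque


    class SlidingWindowLog:
        # Hypothetical sliding-window-log limiter with an explicit clock.
        def __init__(self, max_requests, window_seconds):
            self.max_requests = max_requests
            self.window_seconds = window_seconds
            self.history = {}  # key -> deque of accepted request timestamps

        def allow(self, key, now):
            timestamps = self.history.setdefault(key, deque())
            cutoff = now - self.window_seconds
            # Drop timestamps that have left the window.
            while timestamps and timestamps[0] <= cutoff:
                timestamps.popleft()
            if len(timestamps) >= self.max_requests:
                return False  # RateLimitMiddleware raises RateLimitError here
            timestamps.append(now)
            return True


    limiter = SlidingWindowLog(max_requests=2, window_seconds=60)
    decisions = [limiter.allow("conv-1", t) for t in (0, 10, 20, 61, 70)]
    print(decisions)  # [True, True, False, True, True]
    ```

    The third request is refused because two requests already sit in the
    60-second window; at t=61 the request from t=0 has expired, freeing a slot.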

    Example:
        >>> # Limit to 10 requests per minute
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=10,
        ...     window_seconds=60
        ... )
        >>>
        >>> # Per-client rate limiting
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=100,
        ...     window_seconds=3600,
        ...     scope="client_id"  # Rate limit per client
        ... )
        >>>
        >>> # With custom key function (fall back to the conversation ID so
        >>> # users without a user_id don't all share a single key)
        >>> def get_user_id(state):
        ...     return state.metadata.get("user_id", state.conversation_id)
        >>>
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=50,
        ...     window_seconds=60,
        ...     key_fn=get_user_id
        ... )
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: int = 60,
        scope: str = "conversation",  # "conversation" or "client_id"
        key_fn: Callable[[ConversationState], str] | None = None
    ):
        """Initialize rate limiting middleware.

        Args:
            max_requests: Maximum number of requests allowed in window
            window_seconds: Time window in seconds for rate limiting
            scope: Scope for rate limiting: "conversation" keys on
                `state.conversation_id`; "client_id" keys on
                `state.metadata["client_id"]`, falling back to the
                conversation ID
            key_fn: Optional custom function to extract the rate limit key
                from state; takes precedence over `scope`
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.scope = scope
        self.key_fn = key_fn

        # In-memory storage: key -> list of request timestamps
        self._request_history: Dict[str, List[float]] = {}

    def _get_rate_limit_key(self, state: ConversationState) -> str:
        """Get the key to use for rate limiting.

        Args:
            state: Conversation state

        Returns:
            Rate limit key
        """
        if self.key_fn:
            return self.key_fn(state)
        elif self.scope == "client_id":
            # state.metadata may be unset, so fall back to an empty dict
            return (state.metadata or {}).get("client_id", state.conversation_id)
        else:
            return state.conversation_id

    def _clean_old_requests(self, key: str, current_time: float) -> None:
        """Remove requests outside the time window.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        if key in self._request_history:
            cutoff_time = current_time - self.window_seconds
            self._request_history[key] = [
                ts for ts in self._request_history[key]
                if ts > cutoff_time
            ]

    def _check_rate_limit(self, key: str, current_time: float) -> tuple[bool, int]:
        """Check if request is within rate limit.

        Args:
            key: Rate limit key
            current_time: Current timestamp

        Returns:
            Tuple of (is_allowed, current_count)
        """
        # Clean old requests
        self._clean_old_requests(key, current_time)

        # Check current count
        if key not in self._request_history:
            self._request_history[key] = []

        current_count = len(self._request_history[key])
        is_allowed = current_count < self.max_requests

        return is_allowed, current_count

    def _record_request(self, key: str, current_time: float) -> None:
        """Record a new request.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        if key not in self._request_history:
            self._request_history[key] = []

        self._request_history[key].append(current_time)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Check rate limit before allowing request through."""
        import time

        current_time = time.time()
        key = self._get_rate_limit_key(state)

        # Check rate limit
        is_allowed, current_count = self._check_rate_limit(key, current_time)

        if not is_allowed:
            # Add rate limit info to state metadata for debugging
            if not state.metadata:
                state.metadata = {}
            state.metadata["rate_limit_exceeded"] = True
            state.metadata["rate_limit_count"] = current_count
            state.metadata["rate_limit_max"] = self.max_requests
            state.metadata["rate_limit_window"] = self.window_seconds

            raise RateLimitError(
                f"Rate limit exceeded: {current_count}/{self.max_requests} "
                f"requests in {self.window_seconds}s window"
            )

        # Record this request
        self._record_request(key, current_time)

        # Add rate limit info to messages metadata
        for msg in messages:
            if not msg.metadata:
                msg.metadata = {}
            msg.metadata["rate_limit_count"] = current_count + 1
            msg.metadata["rate_limit_max"] = self.max_requests

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Add rate limit info to response metadata."""
        key = self._get_rate_limit_key(state)

        if key in self._request_history:
            current_count = len(self._request_history[key])

            if not response.metadata:
                response.metadata = {}

            response.metadata["rate_limit_count"] = current_count
            response.metadata["rate_limit_max"] = self.max_requests
            response.metadata["rate_limit_remaining"] = self.max_requests - current_count

        return response

    def get_rate_limit_status(self, key: str) -> Dict[str, Any]:
        """Get current rate limit status for a key.

        Args:
            key: Rate limit key

        Returns:
            Dictionary with `current_count`, `max_requests`, `remaining`,
            `window_seconds`, and `next_reset` (seconds until the oldest
            recorded request leaves the window, or 0 if there are none)

        Example:
            >>> status = middleware.get_rate_limit_status("client-abc")
            >>> print(status)
            {
                'current_count': 5,
                'max_requests': 10,
                'remaining': 5,
                'window_seconds': 60,
                'next_reset': 45.2  # seconds until oldest request expires
            }
        """
        import time

        current_time = time.time()
        self._clean_old_requests(key, current_time)

        if key not in self._request_history or not self._request_history[key]:
            return {
                'current_count': 0,
                'max_requests': self.max_requests,
                'remaining': self.max_requests,
                'window_seconds': self.window_seconds,
                'next_reset': 0
            }

        current_count = len(self._request_history[key])
        oldest_request = min(self._request_history[key])
        next_reset = max(0, (oldest_request + self.window_seconds) - current_time)

        return {
            'current_count': current_count,
            'max_requests': self.max_requests,
            'remaining': max(0, self.max_requests - current_count),
            'window_seconds': self.window_seconds,
            'next_reset': next_reset
        }

    def reset(self, key: str | None = None) -> None:
        """Reset rate limit for a specific key or all keys.

        Args:
            key: Key to reset. If None, resets all keys.

        Example:
            >>> # Reset specific client
            >>> middleware.reset("client-abc")
            >>>
            >>> # Reset all
            >>> middleware.reset()
        """
        if key is None:
            self._request_history.clear()
        elif key in self._request_history:
            del self._request_history[key]


class RateLimitError(Exception):
    """Exception raised when rate limit is exceeded."""
    pass