Coverage for src / dataknobs_llm / conversations / middleware.py: 21%
168 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:28 -0700
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:28 -0700
1"""Middleware system for conversation processing.
3This module provides middleware capabilities for processing messages before they
4are sent to the LLM and processing responses after they come back from the LLM.
5Middleware can be used for logging, validation, content filtering, rate limiting,
6metadata injection, and more.
8Execution Model (Onion Pattern):
9 Middleware wraps around LLM calls in an "onion" pattern:
11 Request Flow: MW0 → MW1 → MW2 → LLM
12 Response Flow: LLM → MW2 → MW1 → MW0
14 Example with 3 middleware [Logging, RateLimit, Validation]:
16 ```
17 1. Logging.process_request() # Log incoming messages
18 2. RateLimit.process_request() # Check rate limits
19 3. Validation.process_request() # Validate request
20 4. → LLM Call → # Actual LLM API call
21 5. Validation.process_response() # Validate LLM response
22 6. RateLimit.process_response() # Add rate limit info to response
23 7. Logging.process_response() # Log response details
24 ```
26 This ensures middleware can:
27 - Time the full LLM call (start timer in process_request, stop in process_response)
28 - Wrap operations symmetrically (open resources → LLM → close resources)
29 - See the final state after inner middleware modifications
31Performance Considerations:
32 - **Middleware adds latency**: Each middleware's `process_request()` and
33 `process_response()` adds to total response time. Keep middleware logic fast.
35 - **Async is key**: All middleware methods are async. Use `await` for I/O
36 operations (DB calls, network requests) to avoid blocking.
38 - **Order matters**: Place expensive middleware (like ValidationMiddleware
39 that makes additional LLM calls) at the end of the list to minimize
40 wasted work if earlier middleware rejects the request.
42 - **Memory usage**: RateLimitMiddleware keeps request history in memory.
43 For high-traffic applications, consider external rate limiting (Redis, etc.).
45Available Middleware:
46 - **LoggingMiddleware**: Log requests and responses for debugging
47 - **ContentFilterMiddleware**: Filter inappropriate content from responses
48 - **ValidationMiddleware**: Validate responses with additional LLM call
49 - **MetadataMiddleware**: Inject custom metadata into messages/responses
50 - **RateLimitMiddleware**: Enforce rate limits with sliding window
52Example:
53 ```python
54 from dataknobs_llm.conversations import (
55 ConversationManager,
56 LoggingMiddleware,
57 RateLimitMiddleware,
58 ContentFilterMiddleware
59 )
60 import logging
62 # Create middleware instances (order matters!)
63 logger = logging.getLogger(__name__)
64 logging_mw = LoggingMiddleware(logger)
65 rate_limit_mw = RateLimitMiddleware(max_requests=10, window_seconds=60)
66 filter_mw = ContentFilterMiddleware(
67 filter_words=["inappropriate"],
68 replacement="[FILTERED]"
69 )
71 # Create conversation with middleware stack
72 # Execution: Logging → RateLimit → Filter → LLM → Filter → RateLimit → Logging
73 manager = await ConversationManager.create(
74 llm=llm,
75 prompt_builder=builder,
76 storage=storage,
77 middleware=[logging_mw, rate_limit_mw, filter_mw]
78 )
80 # All requests will go through middleware pipeline
81 await manager.add_message(role="user", content="Hello")
82 response = await manager.complete() # Middleware applied automatically
83 ```
85See Also:
86 ConversationManager: Uses middleware for all LLM interactions
87 ConversationMiddleware: Base class for custom middleware
88"""
90from abc import ABC, abstractmethod
91from typing import List, Any, Dict, Callable
92import logging
94from dataknobs_llm.llm import LLMMessage, LLMResponse
95from dataknobs_llm.llm.providers import AsyncLLMProvider
96from dataknobs_llm.conversations.storage import ConversationState
97from dataknobs_llm.prompts import AsyncPromptBuilder
98from dataknobs_llm.exceptions import RateLimitError
class ConversationMiddleware(ABC):
    """Abstract interface for hooks that run around each LLM call.

    A middleware gets two chances to act: once on the outgoing messages
    (``process_request``) and once on the incoming response
    (``process_response``). Both hooks are abstract and must be implemented
    by every concrete middleware.

    Execution Order (onion pattern):
        For a middleware list ``[MW0, MW1, MW2]`` the conversation manager is
        documented to call the hooks as follows:

        - **Request**: MW0 → MW1 → MW2 → LLM
        - **Response**: LLM → MW2 → MW1 → MW0

        The outermost middleware (MW0) therefore sees the request first and
        the response last, which lets it:

        1. Start a timer in ``process_request()``
        2. Let the LLM call (and all inner middleware) complete
        3. Stop the timer in ``process_response()`` and log the total time

    Use Cases:
        - **Logging**: Track request/response details
        - **Validation**: Verify request/response content
        - **Transformation**: Modify messages or responses
        - **Rate Limiting**: Enforce API usage limits
        - **Caching**: Store/retrieve responses
        - **Monitoring**: Collect metrics and analytics
        - **Security**: Filter sensitive information

    Example:
        ```python
        from dataknobs_llm.conversations import ConversationMiddleware
        import time

        class TimingMiddleware(ConversationMiddleware):
            '''Measure LLM call duration.'''

            async def process_request(self, messages, state):
                # Store start time in state metadata
                state.metadata["request_start"] = time.time()
                return messages

            async def process_response(self, response, state):
                # Calculate elapsed time
                start = state.metadata.get("request_start")
                if start:
                    elapsed = time.time() - start
                    if not response.metadata:
                        response.metadata = {}
                    response.metadata["llm_duration_seconds"] = elapsed
                    print(f"LLM call took {elapsed:.2f}s")
                return response

        # Use in conversation
        manager = await ConversationManager.create(
            llm=llm,
            middleware=[TimingMiddleware()]
        )
        ```

    Note:
        **Performance Tips**:

        - Keep ``process_request()`` and ``process_response()`` fast; their
          run time is added to every LLM call
        - Use async I/O (``await``) for external calls (DB, network)
        - Don't block the event loop with synchronous operations
        - For expensive operations, consider running them in background tasks
        - Prefer per-conversation data in ``state.metadata`` over instance
          attributes, so that data does not leak between conversations

    See Also:
        LoggingMiddleware: Example implementation
        ConversationManager.complete: Where middleware is executed
    """

    @abstractmethod
    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Inspect or rewrite the messages that are about to go to the LLM.

        Args:
            messages: Messages to send to LLM
            state: Current conversation state

        Returns:
            The messages to pass on (implementations may modify, add, or
            remove messages)

        Example:
            >>> from datetime import datetime
            >>> async def process_request(self, messages, state):
            ...     # Add timestamp to metadata
            ...     for msg in messages:
            ...         if not msg.metadata:
            ...             msg.metadata = {}
            ...         msg.metadata["timestamp"] = datetime.now().isoformat()
            ...     return messages
        """

    @abstractmethod
    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Inspect or rewrite the response that came back from the LLM.

        Args:
            response: LLM response
            state: Current conversation state

        Returns:
            The response to pass on (implementations may modify content,
            metadata, etc.)

        Example:
            >>> from datetime import datetime
            >>> async def process_response(self, response, state):
            ...     # Add processing metadata
            ...     if not response.metadata:
            ...         response.metadata = {}
            ...     response.metadata["processed_at"] = datetime.now().isoformat()
            ...     return response
        """
class LoggingMiddleware(ConversationMiddleware):
    """Pass-through middleware that writes a log line for each request/response.

    Nothing is modified: messages and responses are returned exactly as they
    were received. Summary lines go to INFO; the list of message roles and
    the token usage go to DEBUG.

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> logging.basicConfig(level=logging.INFO)
        >>>
        >>> middleware = LoggingMiddleware(logger)
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     middleware=[middleware]
        ... )
    """

    def __init__(self, logger: logging.Logger | None = None):
        """Store the logger to write to.

        Args:
            logger: Logger instance to use; when omitted, the logger named
                after this module is used
        """
        self.logger = logger if logger else logging.getLogger(__name__)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Log how many messages are being sent, then return them unchanged.

        Args:
            messages: Messages about to be sent to the LLM
            state: Current conversation state (its ``conversation_id`` is
                included in each log line)

        Returns:
            The same ``messages`` object that was passed in
        """
        summary = (
            f"Conversation {state.conversation_id} - "
            f"Sending {len(messages)} messages to LLM"
        )
        self.logger.info(summary)

        role_line = (
            f"Conversation {state.conversation_id} - "
            f"Message roles: {[msg.role for msg in messages]}"
        )
        self.logger.debug(role_line)
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Log size, model and finish reason of the response; return it unchanged.

        Args:
            response: Response received from the LLM
            state: Current conversation state (its ``conversation_id`` is
                included in each log line)

        Returns:
            The same ``response`` object that was passed in
        """
        # Empty or missing content is reported as zero characters.
        if response.content:
            content_length = len(response.content)
        else:
            content_length = 0

        summary = (
            f"Conversation {state.conversation_id} - "
            f"Received response: {content_length} chars, "
            f"model={response.model}, finish_reason={response.finish_reason}"
        )
        self.logger.info(summary)

        # Usage details are only logged when the response carries them.
        if response.usage:
            usage_line = (
                f"Conversation {state.conversation_id} - "
                f"Token usage: {response.usage}"
            )
            self.logger.debug(usage_line)
        return response
class ContentFilterMiddleware(ConversationMiddleware):
    """Middleware that replaces configured words/phrases in LLM responses.

    Requests pass through untouched. In responses, every occurrence of each
    configured word is replaced with ``replacement``; words are applied one
    after another in list order. The replacement text is always inserted
    literally, in both case-sensitive and case-insensitive mode.

    When at least one replacement changed the content, the response metadata
    gets ``"content_filtered": True``.

    Example:
        >>> # Filter specific words
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["badword1", "badword2"],
        ...     replacement="[FILTERED]"
        ... )
        >>>
        >>> # Case-insensitive filtering
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["sensitive"],
        ...     case_sensitive=False
        ... )
    """

    def __init__(
        self,
        filter_words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True
    ):
        """Initialize content filter middleware.

        Args:
            filter_words: List of words/phrases to filter (matched as plain
                substrings, not as regular expressions)
            replacement: String to replace filtered content with
            case_sensitive: Whether filtering should be case-sensitive
        """
        self.filter_words = filter_words
        self.replacement = replacement
        self.case_sensitive = case_sensitive

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without filtering.

        Args:
            messages: Messages about to be sent to the LLM
            state: Current conversation state (unused)

        Returns:
            The same ``messages`` object that was passed in
        """
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Replace every configured word in the response content.

        Args:
            response: LLM response whose ``content`` is filtered in place
            state: Current conversation state (unused)

        Returns:
            The same ``response`` object; ``content`` is rewritten and
            ``metadata["content_filtered"]`` is set to ``True`` only when
            filtering actually changed the text
        """
        # Kept function-scoped, as in the original implementation, but
        # imported once instead of once per filter word.
        import re

        original = response.content
        if not original:
            # Nothing to filter (None or empty content); avoid calling
            # string methods on None.
            return response

        replacement = self.replacement
        filtered = original
        for word in self.filter_words:
            if not word:
                # An empty search string matches between every character
                # and would scatter the replacement through the whole text.
                continue
            if self.case_sensitive:
                filtered = filtered.replace(word, replacement)
            else:
                pattern = re.compile(re.escape(word), re.IGNORECASE)
                # A callable keeps the replacement literal; a plain string
                # would be parsed as a regex template (backslashes, \1,
                # \g<name>), unlike the case-sensitive branch.
                filtered = pattern.sub(lambda _match: replacement, filtered)

        # Track if any filtering occurred
        if filtered != original:
            if not response.metadata:
                response.metadata = {}
            response.metadata["content_filtered"] = True
            response.content = filtered

        return response
class ValidationMiddleware(ConversationMiddleware):
    """Middleware that validates LLM responses using another LLM call.

    For every response, a validation prompt is rendered with the response
    content and sent to a separate LLM. The reply is interpreted by
    ``_check_validity``: it must contain "VALID" and must not contain
    "INVALID" (case-insensitive) for the response to be accepted.

    On failure the response metadata is annotated, and then either a
    ``ValueError`` is raised (``auto_retry=False``) or a retry is merely
    requested via metadata (``auto_retry=True``). No retry is performed
    here, and ``retry_limit`` is only stored, not enforced, by this class.

    Example:
        >>> from dataknobs_llm.llm.providers import OpenAIProvider
        >>> from dataknobs_llm.llm.base import LLMConfig
        >>>
        >>> # Create validation middleware
        >>> config = LLMConfig(provider="openai", model="gpt-4")
        >>> validation_llm = OpenAIProvider(config)
        >>> middleware = ValidationMiddleware(
        ...     llm=validation_llm,
        ...     prompt_builder=builder,
        ...     validation_prompt="validate_response",
        ...     auto_retry=False  # Raise error instead of retrying
        ... )
        >>>
        >>> # Validation prompt should ask the LLM to respond with
        >>> # "VALID" or "INVALID" based on the response content
    """

    def __init__(
        self,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        validation_prompt: str,
        auto_retry: bool = False,
        retry_limit: int = 3
    ):
        """Initialize validation middleware.

        Args:
            llm: LLM provider to use for validation (required)
            prompt_builder: Prompt builder for rendering validation prompt
            validation_prompt: Name of validation prompt template
            auto_retry: If True, a failed validation sets
                ``metadata["retry_requested"]`` instead of raising
            retry_limit: Maximum number of retries if auto_retry is True
                (stored for the component that performs the retry)
        """
        self.llm: AsyncLLMProvider = llm
        self.builder: AsyncPromptBuilder = prompt_builder
        self.validation_prompt = validation_prompt
        self.auto_retry = auto_retry
        self.retry_limit = retry_limit

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without validation.

        Args:
            messages: Messages about to be sent to the LLM
            state: Current conversation state (unused)

        Returns:
            The same ``messages`` object that was passed in
        """
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Validate response by calling LLM with validation prompt.

        Args:
            response: LLM response to validate
            state: Current conversation state (unused)

        Returns:
            The same ``response`` object. After a failed validation with
            ``auto_retry=True`` its metadata contains ``validation_failed``,
            ``validation_response`` and ``retry_requested``.

        Raises:
            ValueError: If validation fails and ``auto_retry`` is False
        """
        # Render validation prompt with response content
        validation_prompt_result = await self.builder.render_user_prompt(
            self.validation_prompt,
            index=0,
            params={"response": response.content},
            include_rag=False  # Don't need RAG for validation
        )

        # Create message and call LLM to get validation judgment
        validation_message = LLMMessage(
            role="user",
            content=validation_prompt_result.content
        )
        validation_response = await self.llm.complete([validation_message])

        # Check if LLM says response is valid
        is_valid = self._check_validity(validation_response.content)

        if not is_valid:
            # Track validation failure
            if not response.metadata:
                response.metadata = {}
            response.metadata["validation_failed"] = True
            response.metadata["validation_response"] = validation_response.content

            if self.auto_retry:
                # Note: Actual retry logic would need to be implemented
                # at the ConversationManager level. This just marks the failure.
                response.metadata["retry_requested"] = True
            else:
                raise ValueError(
                    f"Response failed validation: {validation_response.content}"
                )

        return response

    def _check_validity(self, validation_response: str) -> bool:
        """Check if validation response indicates success.

        Args:
            validation_response: Content from validation prompt response

        Returns:
            True if the text contains "VALID" and does not contain
            "INVALID" (case-insensitive); False otherwise, including for
            empty or missing text
        """
        if not validation_response:
            # No verdict at all (None or empty reply) is not an approval.
            return False

        verdict = validation_response.upper()
        # "INVALID" contains the substring "VALID", so the rejection marker
        # must be ruled out before looking for the approval marker.
        if "INVALID" in verdict:
            return False
        return "VALID" in verdict
class MetadataMiddleware(ConversationMiddleware):
    """Middleware that stamps extra metadata onto messages and responses.

    Metadata comes from two sources per direction: a static dictionary and
    an optional zero-argument callable evaluated on every call. When both
    supply the same key, the callable's value wins.

    Example:
        >>> from datetime import datetime
        >>>
        >>> # Add environment info to all messages
        >>> middleware = MetadataMiddleware(
        ...     request_metadata={"environment": "production"},
        ...     response_metadata={"version": "1.0.0"}
        ... )
        >>>
        >>> # Add dynamic metadata via callback
        >>> def get_request_meta():
        ...     return {"timestamp": datetime.now().isoformat()}
        >>>
        >>> middleware = MetadataMiddleware(
        ...     request_metadata_fn=get_request_meta
        ... )
    """

    def __init__(
        self,
        request_metadata: Dict[str, Any] | None = None,
        response_metadata: Dict[str, Any] | None = None,
        request_metadata_fn: Callable[..., Dict[str, Any]] | None = None,
        response_metadata_fn: Callable[..., Dict[str, Any]] | None = None
    ):
        """Store the static metadata and the optional metadata callables.

        Args:
            request_metadata: Static metadata to add to requests
            response_metadata: Static metadata to add to responses
            request_metadata_fn: Callable (invoked without arguments) that
                returns metadata for requests
            response_metadata_fn: Callable (invoked without arguments) that
                returns metadata for responses
        """
        self.request_metadata_fn = request_metadata_fn
        self.response_metadata_fn = response_metadata_fn
        # Missing (or empty) static metadata is normalised to an empty dict.
        self.request_metadata = request_metadata if request_metadata else {}
        self.response_metadata = response_metadata if response_metadata else {}

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Merge the request metadata into every outgoing message.

        Args:
            messages: Messages about to be sent to the LLM; their
                ``metadata`` is updated in place
            state: Current conversation state (unused)

        Returns:
            The same ``messages`` object that was passed in
        """
        # Work on a copy so the configured static dict is never mutated.
        extra = dict(self.request_metadata)

        provider = self.request_metadata_fn
        if provider:
            extra.update(provider())

        if not extra:
            return messages

        for message in messages:
            if not message.metadata:
                message.metadata = {}
            message.metadata.update(extra)
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Merge the response metadata into the LLM response.

        Args:
            response: LLM response; its ``metadata`` is updated in place
            state: Current conversation state (unused)

        Returns:
            The same ``response`` object that was passed in
        """
        # Work on a copy so the configured static dict is never mutated.
        extra = dict(self.response_metadata)

        provider = self.response_metadata_fn
        if provider:
            extra.update(provider())

        if not extra:
            return response

        if not response.metadata:
            response.metadata = {}
        response.metadata.update(extra)
        return response
class RateLimitMiddleware(ConversationMiddleware):
    """Middleware that enforces rate limiting on LLM requests.

    Requests are counted per key in a sliding time window. The key is
    chosen by ``key_fn`` when given, otherwise by ``scope``: the
    conversation id, or ``state.metadata["client_id"]`` with the
    conversation id as fallback. When a key already has ``max_requests``
    timestamps inside the window, ``process_request`` raises
    ``RateLimitError``.

    All bookkeeping lives in memory on this instance (a dict of timestamp
    lists). Buckets are dropped once all of their timestamps have expired
    and the key is looked at again.

    Example:
        >>> # Limit to 10 requests per minute
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=10,
        ...     window_seconds=60
        ... )
        >>>
        >>> # Per-client rate limiting
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=100,
        ...     window_seconds=3600,
        ...     scope="client_id"  # Rate limit per client
        ... )
        >>>
        >>> # With custom key function
        >>> def get_user_id(state):
        ...     return state.metadata.get("user_id")
        >>>
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=50,
        ...     window_seconds=60,
        ...     key_fn=get_user_id
        ... )
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: int = 60,
        scope: str = "conversation",  # "conversation" or "client_id"
        key_fn: Callable[[ConversationState], str] | None = None
    ):
        """Initialize rate limiting middleware.

        Args:
            max_requests: Maximum number of requests allowed in window
            window_seconds: Time window in seconds for rate limiting
            scope: Scope for rate limiting ("conversation" or "client_id")
            key_fn: Optional custom function to extract rate limit key from
                state; takes precedence over ``scope``
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.scope = scope
        self.key_fn = key_fn

        # In-memory storage: key -> list of request timestamps
        self._request_history: Dict[str, List[float]] = {}

    def _get_rate_limit_key(self, state: ConversationState) -> str:
        """Get the key to use for rate limiting.

        Args:
            state: Conversation state

        Returns:
            Rate limit key: the ``key_fn`` result, the client id, or the
            conversation id (also used as fallback when ``key_fn`` returns
            None or no client id is available)
        """
        if self.key_fn:
            custom_key = self.key_fn(state)
            if custom_key is not None:
                return custom_key
            # A key function that finds nothing (e.g. a missing "user_id")
            # must not lump every such caller into one shared None bucket.
            return state.conversation_id

        if self.scope == "client_id":
            # Metadata may be empty/None; process_request guards for the
            # same situation before writing to it.
            metadata = state.metadata or {}
            return metadata.get("client_id", state.conversation_id)

        return state.conversation_id

    def _clean_old_requests(self, key: str, current_time: float) -> None:
        """Remove requests outside the time window.

        A bucket that ends up empty is deleted so that keys which are no
        longer active do not stay in memory as empty lists.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        timestamps = self._request_history.get(key)
        if timestamps is None:
            return

        cutoff_time = current_time - self.window_seconds
        recent = [ts for ts in timestamps if ts > cutoff_time]
        if recent:
            self._request_history[key] = recent
        else:
            del self._request_history[key]

    def _check_rate_limit(self, key: str, current_time: float) -> tuple[bool, int]:
        """Check if request is within rate limit.

        Args:
            key: Rate limit key
            current_time: Current timestamp

        Returns:
            Tuple of (is_allowed, current_count), where current_count is
            the number of requests still inside the window
        """
        # Clean old requests
        self._clean_old_requests(key, current_time)

        # Check current count (a missing bucket means no recent requests)
        current_count = len(self._request_history.get(key, ()))
        is_allowed = current_count < self.max_requests

        return is_allowed, current_count

    def _record_request(self, key: str, current_time: float) -> None:
        """Record a new request.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        self._request_history.setdefault(key, []).append(current_time)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Check rate limit before allowing request through.

        Args:
            messages: Messages about to be sent to the LLM; on success each
                gets ``rate_limit_count`` and ``rate_limit_max`` metadata
            state: Current conversation state; on rejection its metadata is
                annotated with ``rate_limit_*`` debugging fields

        Returns:
            The same ``messages`` object that was passed in

        Raises:
            RateLimitError: If the key already used up its requests for
                the current window
        """
        import time

        current_time = time.time()
        key = self._get_rate_limit_key(state)

        # Check rate limit
        is_allowed, current_count = self._check_rate_limit(key, current_time)

        if not is_allowed:
            # Add rate limit info to state metadata for debugging
            if not state.metadata:
                state.metadata = {}
            state.metadata["rate_limit_exceeded"] = True
            state.metadata["rate_limit_count"] = current_count
            state.metadata["rate_limit_max"] = self.max_requests
            state.metadata["rate_limit_window"] = self.window_seconds

            raise RateLimitError(
                f"Rate limit exceeded: {current_count}/{self.max_requests} "
                f"requests in {self.window_seconds}s window"
            )

        # Record this request
        self._record_request(key, current_time)

        # Add rate limit info to messages metadata
        for msg in messages:
            if not msg.metadata:
                msg.metadata = {}
            msg.metadata["rate_limit_count"] = current_count + 1
            msg.metadata["rate_limit_max"] = self.max_requests

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Add rate limit info to response metadata.

        Nothing is added when the key has no recorded requests (for
        example after ``reset()``).

        Args:
            response: LLM response; gets ``rate_limit_count``,
                ``rate_limit_max`` and ``rate_limit_remaining`` metadata
            state: Current conversation state (used to derive the key)

        Returns:
            The same ``response`` object that was passed in
        """
        import time

        key = self._get_rate_limit_key(state)

        if key in self._request_history:
            # Drop expired timestamps first so the reported numbers match
            # what get_rate_limit_status() would report.
            self._clean_old_requests(key, time.time())
            current_count = len(self._request_history.get(key, ()))

            if not response.metadata:
                response.metadata = {}

            response.metadata["rate_limit_count"] = current_count
            response.metadata["rate_limit_max"] = self.max_requests
            # Clamped at zero, as in get_rate_limit_status().
            response.metadata["rate_limit_remaining"] = max(
                0, self.max_requests - current_count
            )

        return response

    def get_rate_limit_status(self, key: str) -> Dict[str, Any]:
        """Get current rate limit status for a key.

        Args:
            key: Rate limit key

        Returns:
            Dictionary with rate limit status

        Example:
            >>> status = middleware.get_rate_limit_status("client-abc")
            >>> print(status)
            {
                'current_count': 5,
                'max_requests': 10,
                'remaining': 5,
                'window_seconds': 60,
                'next_reset': 45.2  # seconds until oldest request expires
            }
        """
        import time

        current_time = time.time()
        self._clean_old_requests(key, current_time)

        timestamps = self._request_history.get(key)
        if not timestamps:
            # Unknown key or no requests left inside the window.
            return {
                'current_count': 0,
                'max_requests': self.max_requests,
                'remaining': self.max_requests,
                'window_seconds': self.window_seconds,
                'next_reset': 0
            }

        current_count = len(timestamps)
        oldest_request = min(timestamps)
        # Time until the oldest counted request leaves the window.
        next_reset = max(0, (oldest_request + self.window_seconds) - current_time)

        return {
            'current_count': current_count,
            'max_requests': self.max_requests,
            'remaining': max(0, self.max_requests - current_count),
            'window_seconds': self.window_seconds,
            'next_reset': next_reset
        }

    def reset(self, key: str | None = None) -> None:
        """Reset rate limit for a specific key or all keys.

        Args:
            key: Key to reset. If None, resets all keys. Unknown keys are
                ignored.

        Example:
            >>> # Reset specific client
            >>> middleware.reset("client-abc")
            >>>
            >>> # Reset all
            >>> middleware.reset()
        """
        if key is None:
            self._request_history.clear()
        else:
            self._request_history.pop(key, None)