Coverage for src/dataknobs_llm/conversations/middleware.py: 61%

169 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""Middleware system for conversation processing. 

2 

3This module provides middleware capabilities for processing messages before they 

4are sent to the LLM and processing responses after they come back from the LLM. 

5Middleware can be used for logging, validation, content filtering, rate limiting, 

6metadata injection, and more. 

7 

8Execution Model (Onion Pattern): 

9 Middleware wraps around LLM calls in an "onion" pattern: 

10 

11 Request Flow: MW0 → MW1 → MW2 → LLM 

12 Response Flow: LLM → MW2 → MW1 → MW0 

13 

14 Example with 3 middleware [Logging, RateLimit, Validation]: 

15 

16 ``` 

17 1. Logging.process_request() # Log incoming messages 

18 2. RateLimit.process_request() # Check rate limits 

19 3. Validation.process_request() # Validate request 

20 4. → LLM Call → # Actual LLM API call 

21 5. Validation.process_response() # Validate LLM response 

22 6. RateLimit.process_response() # Add rate limit info to response 

23 7. Logging.process_response() # Log response details 

24 ``` 

25 

26 This ensures middleware can: 

27 - Time the full LLM call (start timer in process_request, stop in process_response) 

28 - Wrap operations symmetrically (open resources → LLM → close resources) 

29 - See the final state after inner middleware modifications 

30 

31Performance Considerations: 

32 - **Middleware adds latency**: Each middleware's `process_request()` and 

33 `process_response()` adds to total response time. Keep middleware logic fast. 

34 

35 - **Async is key**: All middleware methods are async. Use `await` for I/O 

36 operations (DB calls, network requests) to avoid blocking. 

37 

38 - **Order matters**: Place expensive middleware (like ValidationMiddleware 

39 that makes additional LLM calls) at the end of the list to minimize 

40 wasted work if earlier middleware rejects the request. 

41 

42 - **Memory usage**: RateLimitMiddleware keeps request history in memory. 

43 For high-traffic applications, consider external rate limiting (Redis, etc.). 

44 

45Available Middleware: 

46 - **LoggingMiddleware**: Log requests and responses for debugging 

47 - **ContentFilterMiddleware**: Filter inappropriate content from responses 

48 - **ValidationMiddleware**: Validate responses with additional LLM call 

49 - **MetadataMiddleware**: Inject custom metadata into messages/responses 

50 - **RateLimitMiddleware**: Enforce rate limits with sliding window 

51 

52Example: 

53 ```python 

54 from dataknobs_llm.conversations import ( 

55 ConversationManager, 

56 LoggingMiddleware, 

57 RateLimitMiddleware, 

58 ContentFilterMiddleware 

59 ) 

60 import logging 

61 

62 # Create middleware instances (order matters!) 

63 logger = logging.getLogger(__name__) 

64 logging_mw = LoggingMiddleware(logger) 

65 rate_limit_mw = RateLimitMiddleware(max_requests=10, window_seconds=60) 

66 filter_mw = ContentFilterMiddleware( 

67 filter_words=["inappropriate"], 

68 replacement="[FILTERED]" 

69 ) 

70 

71 # Create conversation with middleware stack 

72 # Execution: Logging → RateLimit → Filter → LLM → Filter → RateLimit → Logging 

73 manager = await ConversationManager.create( 

74 llm=llm, 

75 prompt_builder=builder, 

76 storage=storage, 

77 middleware=[logging_mw, rate_limit_mw, filter_mw] 

78 ) 

79 

80 # All requests will go through middleware pipeline 

81 await manager.add_message(role="user", content="Hello") 

82 response = await manager.complete() # Middleware applied automatically 

83 ``` 

84 

85See Also: 

86 ConversationManager: Uses middleware for all LLM interactions 

87 ConversationMiddleware: Base class for custom middleware 

88""" 

89 

90from abc import ABC, abstractmethod 

91from typing import List, Any, Dict, Callable 

92import logging 

93 

94from dataknobs_llm.llm import LLMMessage, LLMResponse 

95from dataknobs_llm.llm.providers import AsyncLLMProvider 

96from dataknobs_llm.conversations.storage import ConversationState 

97from dataknobs_llm.prompts import AsyncPromptBuilder 

98 

99 

class ConversationMiddleware(ABC):
    """Abstract hook wrapped around every LLM call a conversation makes.

    A concrete middleware implements two coroutines:

    * ``process_request`` receives the outgoing message list before the
      provider is called and returns the (possibly modified) list.
    * ``process_response`` receives the provider's reply and returns the
      (possibly modified) response.

    Execution Order:
        The middleware list is applied as an "onion". For ``[MW0, MW1, MW2]``:

        - **Request**: MW0 → MW1 → MW2 → LLM
        - **Response**: LLM → MW2 → MW1 → MW0

        The first middleware in the list sees the request first and the
        response last. That lets it bracket the whole call: for example, it
        can start a timer in ``process_request()`` and stop it in
        ``process_response()``.

    Use Cases:
        - **Logging**: record request/response details
        - **Validation**: check request or response content
        - **Transformation**: rewrite messages or responses
        - **Rate Limiting**: enforce API usage limits
        - **Caching**: store or look up responses
        - **Monitoring**: gather metrics and analytics
        - **Security**: strip sensitive information

    Example:
        ```python
        from dataknobs_llm.conversations import ConversationMiddleware
        import time

        class TimingMiddleware(ConversationMiddleware):
            '''Report how long the LLM call took.'''

            async def process_request(self, messages, state):
                # Remember when the request left, keyed on the conversation state
                state.metadata["request_start"] = time.time()
                return messages

            async def process_response(self, response, state):
                started = state.metadata.get("request_start")
                if started:
                    elapsed = time.time() - started
                    if not response.metadata:
                        response.metadata = {}
                    response.metadata["llm_duration_seconds"] = elapsed
                    print(f"LLM call took {elapsed:.2f}s")
                return response

        manager = await ConversationManager.create(
            llm=llm,
            middleware=[TimingMiddleware()]
        )
        ```

    Note:
        **Performance Tips**:

        - Both hooks run on every call, so keep them quick.
        - Await asynchronous I/O (databases, network) rather than blocking.
        - Never run blocking synchronous work on the event loop.
        - Push expensive work into background tasks where possible.
        - Prefer per-call data in ``state.metadata`` over instance attributes,
          because one middleware instance may serve many concurrent
          conversations.

    See Also:
        LoggingMiddleware: Example implementation
        ConversationManager.complete: Where middleware is executed
    """

    @abstractmethod
    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Transform the outgoing messages before the LLM sees them.

        Args:
            messages: Messages about to be sent to the LLM
            state: Current conversation state

        Returns:
            The message list to forward. Implementations may edit, append,
            or drop messages.

        Example:
            >>> from datetime import datetime
            >>> async def process_request(self, messages, state):
            ...     # Stamp every message with the send time
            ...     for msg in messages:
            ...         if not msg.metadata:
            ...             msg.metadata = {}
            ...         msg.metadata["timestamp"] = datetime.now().isoformat()
            ...     return messages
        """

    @abstractmethod
    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Transform the LLM's reply before it is returned to the caller.

        Args:
            response: Response produced by the LLM
            state: Current conversation state

        Returns:
            The response to forward. Implementations may edit content,
            metadata, and other fields.

        Example:
            >>> from datetime import datetime
            >>> async def process_response(self, response, state):
            ...     # Record when the response was post-processed
            ...     if not response.metadata:
            ...         response.metadata = {}
            ...     response.metadata["processed_at"] = datetime.now().isoformat()
            ...     return response
        """

224 

225 

class LoggingMiddleware(ConversationMiddleware):
    """Middleware that logs all requests and responses.

    This middleware is useful for debugging and monitoring conversations.

    - At INFO level it logs how many messages are sent, plus each response's
      length, model, and finish reason.
    - At DEBUG level it also logs the message roles and the token usage.

    Log calls pass lazy ``%``-style arguments instead of pre-formatted
    f-strings. As a result, no string building happens for records the
    logger discards.

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> logging.basicConfig(level=logging.INFO)
        >>>
        >>> middleware = LoggingMiddleware(logger)
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     middleware=[middleware]
        ... )
    """

    def __init__(self, logger: logging.Logger | None = None):
        """Initialize logging middleware.

        Args:
            logger: Logger instance to use (defaults to module logger)
        """
        self.logger = logger or logging.getLogger(__name__)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Log request details before sending to LLM.

        Args:
            messages: Messages about to be sent to the LLM
            state: Current conversation state (its ``conversation_id`` is logged)

        Returns:
            ``messages``, unchanged
        """
        self.logger.info(
            "Conversation %s - Sending %d messages to LLM",
            state.conversation_id,
            len(messages),
        )
        # Building the role list costs a pass over the messages, so only
        # do it when DEBUG records will actually be emitted.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(
                "Conversation %s - Message roles: %s",
                state.conversation_id,
                [msg.role for msg in messages],
            )
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Log response details after receiving from LLM.

        Args:
            response: Response returned by the LLM
            state: Current conversation state (its ``conversation_id`` is logged)

        Returns:
            ``response``, unchanged
        """
        # A missing/empty content is reported as 0 chars rather than failing
        content_length = len(response.content) if response.content else 0
        self.logger.info(
            "Conversation %s - Received response: %d chars, "
            "model=%s, finish_reason=%s",
            state.conversation_id,
            content_length,
            response.model,
            response.finish_reason,
        )
        if response.usage:
            self.logger.debug(
                "Conversation %s - Token usage: %s",
                state.conversation_id,
                response.usage,
            )
        return response

288 

289 

class ContentFilterMiddleware(ConversationMiddleware):
    """Middleware that filters inappropriate content from responses.

    This middleware can be used to redact or replace specific words or
    phrases in LLM responses. Useful for content moderation and compliance.

    Filtering rules:

    - All filter words are matched in a single pass.
    - At any position, the longest matching word wins.
    - The replacement text is inserted literally. It is never re-scanned, so a
      filter word that happens to appear inside the replacement (e.g.
      ``"FILTER"`` inside ``"[FILTERED]"``) cannot corrupt earlier replacements.
    - Empty filter words are ignored.
    - Responses with empty or missing content are passed through untouched.

    Example:
        >>> # Filter specific words
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["badword1", "badword2"],
        ...     replacement="[FILTERED]"
        ... )
        >>>
        >>> # Case-insensitive filtering
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["sensitive"],
        ...     case_sensitive=False
        ... )
    """

    def __init__(
        self,
        filter_words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True
    ):
        """Initialize content filter middleware.

        Args:
            filter_words: List of words/phrases to filter (empty strings ignored)
            replacement: String to replace filtered content with (used literally)
            case_sensitive: Whether filtering should be case-sensitive
        """
        self.filter_words = filter_words
        self.replacement = replacement
        self.case_sensitive = case_sensitive

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without filtering."""
        return messages

    def _build_pattern(self):
        """Compile one alternation regex matching any non-empty filter word.

        The pattern is rebuilt from the current attributes on each call, so
        later changes to ``filter_words`` or ``case_sensitive`` take effect.

        Returns:
            A compiled ``re.Pattern``, or ``None`` if there is nothing to filter.
        """
        import re

        # Deduplicate while keeping first-seen order, drop empty strings (they
        # would match between every character), and put longer words first
        # so that e.g. "badword" is preferred over "bad" at the same position.
        words = sorted(
            dict.fromkeys(word for word in (self.filter_words or ()) if word),
            key=len,
            reverse=True,
        )
        if not words:
            return None
        flags = 0 if self.case_sensitive else re.IGNORECASE
        return re.compile("|".join(re.escape(word) for word in words), flags)

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Filter inappropriate content from response.

        Args:
            response: Response returned by the LLM (modified in place)
            state: Current conversation state (unused)

        Returns:
            The same ``response``. If anything was replaced, its ``content``
            is updated and ``metadata["content_filtered"]`` is set to True.
        """
        content = response.content
        if not content:
            return response

        pattern = self._build_pattern()
        if pattern is None:
            return response

        replacement = self.replacement
        # A callable repl inserts the text verbatim; a string repl would be
        # parsed as a regex template (backslashes, \g<...> groups).
        filtered = pattern.sub(lambda _match: replacement, content)

        # Track if any filtering occurred
        if filtered != content:
            if not response.metadata:
                response.metadata = {}
            response.metadata["content_filtered"] = True
            response.content = filtered

        return response

360 

361 

class ValidationMiddleware(ConversationMiddleware):
    """Middleware that validates LLM responses using another LLM call.

    This middleware renders a validation prompt, sends it to a separate
    LLM, and checks whether the response meets certain criteria. Failures
    are recorded in the response metadata. The middleware then either
    raises or, when ``auto_retry`` is set, flags the response for a retry.

    Example:
        >>> from dataknobs_llm.llm.providers import OpenAIProvider
        >>> from dataknobs_llm.llm.base import LLMConfig
        >>>
        >>> # Create validation middleware
        >>> config = LLMConfig(provider="openai", model="gpt-4")
        >>> validation_llm = OpenAIProvider(config)
        >>> middleware = ValidationMiddleware(
        ...     llm=validation_llm,
        ...     prompt_builder=builder,
        ...     validation_prompt="validate_response",
        ...     auto_retry=False  # Raise error instead of retrying
        ... )
        >>>
        >>> # Validation prompt should ask the LLM to respond with
        >>> # "VALID" or "INVALID" based on the response content
    """

    def __init__(
        self,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        validation_prompt: str,
        auto_retry: bool = False,
        retry_limit: int = 3
    ):
        """Initialize validation middleware.

        Args:
            llm: LLM provider to use for validation (required)
            prompt_builder: Prompt builder for rendering validation prompt
            validation_prompt: Name of validation prompt template
            auto_retry: If True, a failed validation only marks the response
                (``retry_requested``) instead of raising
            retry_limit: Maximum number of retries if auto_retry is True.
                NOTE(review): stored here but not read by this class; any
                retry loop must live in the caller.
        """
        self.llm: AsyncLLMProvider = llm
        self.builder: AsyncPromptBuilder = prompt_builder
        self.validation_prompt = validation_prompt
        self.auto_retry = auto_retry
        self.retry_limit = retry_limit

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without validation."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Validate response by calling LLM with validation prompt.

        Args:
            response: Response returned by the main LLM
            state: Current conversation state (unused)

        Returns:
            The same ``response``. On failure, ``metadata`` gains
            ``validation_failed`` and ``validation_response`` and, with
            ``auto_retry``, ``retry_requested``.

        Raises:
            ValueError: If validation fails and ``auto_retry`` is False.
        """
        # Render validation prompt with response content
        validation_prompt_result = await self.builder.render_user_prompt(
            self.validation_prompt,
            index=0,
            params={"response": response.content},
            include_rag=False  # Don't need RAG for validation
        )

        # Create message and call LLM to get validation judgment
        validation_message = LLMMessage(
            role="user",
            content=validation_prompt_result.content
        )
        validation_response = await self.llm.complete([validation_message])

        # Check if LLM says response is valid
        is_valid = self._check_validity(validation_response.content)

        if not is_valid:
            # Track validation failure
            if not response.metadata:
                response.metadata = {}
            response.metadata["validation_failed"] = True
            response.metadata["validation_response"] = validation_response.content

            if self.auto_retry:
                # Note: Actual retry logic would need to be implemented
                # at the ConversationManager level. This just marks the failure.
                response.metadata["retry_requested"] = True
            else:
                raise ValueError(
                    f"Response failed validation: {validation_response.content}"
                )

        return response

    def _check_validity(self, validation_response: str | None) -> bool:
        """Check if validation response indicates success.

        The check is case-insensitive. A verdict containing "INVALID" is
        always a failure. Otherwise the response passes if it contains "VALID".
        An empty or missing verdict counts as a failure.

        Args:
            validation_response: Content from validation prompt response

        Returns:
            True if valid, False otherwise
        """
        if not validation_response:
            return False
        verdict = validation_response.upper()
        # "INVALID" contains "VALID" as a substring, so the negative verdict
        # must be ruled out before looking for the positive one.
        if "INVALID" in verdict:
            return False
        return "VALID" in verdict

471 

472 

class MetadataMiddleware(ConversationMiddleware):
    """Middleware that stamps custom metadata onto messages and responses.

    Metadata can come from fixed dictionaries, from zero-argument callables
    evaluated on every call, or both. Callable output is merged on top of the
    fixed dictionary. This is handy for tracking, analytics, and debugging.

    Example:
        >>> from datetime import datetime
        >>>
        >>> # Add environment info to all messages
        >>> middleware = MetadataMiddleware(
        ...     request_metadata={"environment": "production"},
        ...     response_metadata={"version": "1.0.0"}
        ... )
        >>>
        >>> # Add dynamic metadata via callback
        >>> def get_request_meta():
        ...     return {"timestamp": datetime.now().isoformat()}
        >>>
        >>> middleware = MetadataMiddleware(
        ...     request_metadata_fn=get_request_meta
        ... )
    """

    def __init__(
        self,
        request_metadata: Dict[str, Any] | None = None,
        response_metadata: Dict[str, Any] | None = None,
        request_metadata_fn: Callable[..., Dict[str, Any]] | None = None,
        response_metadata_fn: Callable[..., Dict[str, Any]] | None = None
    ):
        """Configure the metadata sources.

        Args:
            request_metadata: Fixed entries merged into every outgoing message
            response_metadata: Fixed entries merged into every response
            request_metadata_fn: Called with no arguments per request; its
                result overrides matching fixed request entries
            response_metadata_fn: Called with no arguments per response; its
                result overrides matching fixed response entries
        """
        self.request_metadata = request_metadata or {}
        self.response_metadata = response_metadata or {}
        self.request_metadata_fn = request_metadata_fn
        self.response_metadata_fn = response_metadata_fn

    @staticmethod
    def _gather(static_entries, entries_fn):
        """Return a fresh dict of the static entries overlaid with entries_fn()."""
        combined = dict(static_entries)
        if entries_fn:
            combined.update(entries_fn())
        return combined

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Merge the request metadata into each outgoing message."""
        extra = self._gather(self.request_metadata, self.request_metadata_fn)
        if not extra:
            return messages

        for message in messages:
            if not message.metadata:
                message.metadata = {}
            message.metadata.update(extra)
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Merge the response metadata into the LLM response."""
        extra = self._gather(self.response_metadata, self.response_metadata_fn)
        if not extra:
            return response

        if not response.metadata:
            response.metadata = {}
        response.metadata.update(extra)
        return response

561 

562 

class RateLimitMiddleware(ConversationMiddleware):
    """Middleware that enforces rate limiting on LLM requests.

    This middleware tracks request rates per conversation or per client
    and raises an exception when the rate limit is exceeded. Rate limits
    are tracked in-memory using a sliding window of request timestamps.

    To keep memory bounded in long-running processes, keys whose requests
    have all left the window are evicted. This sweep runs at most once per
    ``window_seconds``.

    Example:
        >>> # Limit to 10 requests per minute
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=10,
        ...     window_seconds=60
        ... )
        >>>
        >>> # Per-client rate limiting
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=100,
        ...     window_seconds=3600,
        ...     scope="client_id"  # Rate limit per client
        ... )
        >>>
        >>> # With custom key function
        >>> def get_user_id(state):
        ...     return state.metadata.get("user_id")
        >>>
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=50,
        ...     window_seconds=60,
        ...     key_fn=get_user_id
        ... )
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: int = 60,
        scope: str = "conversation",  # "conversation" or "client_id"
        key_fn: Callable[[ConversationState], str] | None = None
    ):
        """Initialize rate limiting middleware.

        Args:
            max_requests: Maximum number of requests allowed in window
            window_seconds: Time window in seconds for rate limiting
            scope: Scope for rate limiting ("conversation" or "client_id")
            key_fn: Optional custom function to extract rate limit key from
                state; takes precedence over ``scope``
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.scope = scope
        self.key_fn = key_fn

        # In-memory storage: key -> list of request timestamps (append order)
        self._request_history: Dict[str, List[float]] = {}
        # Timestamp of the last idle-key sweep (see _prune_idle_keys)
        self._last_prune: float = 0.0

    def _get_rate_limit_key(self, state: ConversationState) -> str:
        """Get the key to use for rate limiting.

        Args:
            state: Conversation state

        Returns:
            Rate limit key. With ``scope="client_id"`` this is
            ``state.metadata["client_id"]``, falling back to the conversation
            id when it is absent or metadata is unset.
        """
        if self.key_fn:
            return self.key_fn(state)
        if self.scope == "client_id":
            # metadata may be None on a fresh state; treat it as empty
            metadata = state.metadata or {}
            return metadata.get("client_id", state.conversation_id)
        return state.conversation_id

    def _clean_old_requests(self, key: str, current_time: float) -> None:
        """Remove requests outside the time window.

        A key left with no requests is removed entirely, not kept as an
        empty list.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        history = self._request_history.get(key)
        if history is None:
            return
        cutoff_time = current_time - self.window_seconds
        recent = [ts for ts in history if ts > cutoff_time]
        if recent:
            self._request_history[key] = recent
        else:
            del self._request_history[key]

    def _prune_idle_keys(self, current_time: float) -> None:
        """Evict every key whose newest request has left the window.

        Without this, each conversation/client key ever seen would keep its
        timestamp list forever. The sweep is O(total stored timestamps), so
        it is throttled to at most once per ``window_seconds``.

        Args:
            current_time: Current timestamp
        """
        if current_time - self._last_prune < self.window_seconds:
            return
        self._last_prune = current_time
        cutoff_time = current_time - self.window_seconds
        idle_keys = [
            key for key, history in self._request_history.items()
            if not history or max(history) <= cutoff_time
        ]
        for key in idle_keys:
            del self._request_history[key]

    def _check_rate_limit(self, key: str, current_time: float) -> tuple[bool, int]:
        """Check if request is within rate limit.

        Args:
            key: Rate limit key
            current_time: Current timestamp

        Returns:
            Tuple of (is_allowed, current_count)
        """
        # Clean old requests
        self._clean_old_requests(key, current_time)

        current_count = len(self._request_history.get(key, ()))
        return current_count < self.max_requests, current_count

    def _record_request(self, key: str, current_time: float) -> None:
        """Record a new request.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        self._request_history.setdefault(key, []).append(current_time)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Check rate limit before allowing request through.

        Args:
            messages: Messages about to be sent to the LLM
            state: Current conversation state (used to derive the key)

        Returns:
            ``messages`` with ``rate_limit_count``/``rate_limit_max`` metadata

        Raises:
            RateLimitError: If the key has already used ``max_requests``
                requests within the window.
        """
        import time

        current_time = time.time()
        # Periodically drop keys nobody has used within the window
        self._prune_idle_keys(current_time)
        key = self._get_rate_limit_key(state)

        # Check rate limit
        is_allowed, current_count = self._check_rate_limit(key, current_time)

        if not is_allowed:
            # Add rate limit info to state metadata for debugging
            if not state.metadata:
                state.metadata = {}
            state.metadata["rate_limit_exceeded"] = True
            state.metadata["rate_limit_count"] = current_count
            state.metadata["rate_limit_max"] = self.max_requests
            state.metadata["rate_limit_window"] = self.window_seconds

            raise RateLimitError(
                f"Rate limit exceeded: {current_count}/{self.max_requests} "
                f"requests in {self.window_seconds}s window"
            )

        # Record this request (no await between check and record, so the
        # check-then-record pair is not interleaved with other coroutines)
        self._record_request(key, current_time)

        # Add rate limit info to messages metadata
        for msg in messages:
            if not msg.metadata:
                msg.metadata = {}
            msg.metadata["rate_limit_count"] = current_count + 1
            msg.metadata["rate_limit_max"] = self.max_requests

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Add rate limit info to response metadata.

        Args:
            response: Response returned by the LLM
            state: Current conversation state (used to derive the key)

        Returns:
            ``response`` with ``rate_limit_count``/``rate_limit_max``/
            ``rate_limit_remaining`` metadata when the key has history
        """
        key = self._get_rate_limit_key(state)
        history = self._request_history.get(key)

        if history is not None:
            current_count = len(history)

            if not response.metadata:
                response.metadata = {}

            response.metadata["rate_limit_count"] = current_count
            response.metadata["rate_limit_max"] = self.max_requests
            response.metadata["rate_limit_remaining"] = self.max_requests - current_count

        return response

    def get_rate_limit_status(self, key: str) -> Dict[str, Any]:
        """Get current rate limit status for a key.

        Args:
            key: Rate limit key

        Returns:
            Dictionary with rate limit status

        Example:
            >>> status = middleware.get_rate_limit_status("client-abc")
            >>> print(status)
            {
                'current_count': 5,
                'max_requests': 10,
                'remaining': 5,
                'window_seconds': 60,
                'next_reset': 45.2  # seconds until oldest request expires
            }
        """
        import time

        current_time = time.time()
        self._clean_old_requests(key, current_time)

        history = self._request_history.get(key)
        if not history:
            return {
                'current_count': 0,
                'max_requests': self.max_requests,
                'remaining': self.max_requests,
                'window_seconds': self.window_seconds,
                'next_reset': 0
            }

        current_count = len(history)
        oldest_request = min(history)
        next_reset = max(0, (oldest_request + self.window_seconds) - current_time)

        return {
            'current_count': current_count,
            'max_requests': self.max_requests,
            'remaining': max(0, self.max_requests - current_count),
            'window_seconds': self.window_seconds,
            'next_reset': next_reset
        }

    def reset(self, key: str | None = None) -> None:
        """Reset rate limit for a specific key or all keys.

        Args:
            key: Key to reset. If None, resets all keys.

        Example:
            >>> # Reset specific client
            >>> middleware.reset("client-abc")
            >>>
            >>> # Reset all
            >>> middleware.reset()
        """
        if key is None:
            self._request_history.clear()
        else:
            self._request_history.pop(key, None)

805 

806 

class RateLimitError(Exception):
    """Signals that a rate-limit key has used up its request budget.

    ``RateLimitMiddleware.process_request`` raises this before the LLM is
    called. The message states the observed count, the maximum, and the
    window length.
    """