Coverage for src / dataknobs_llm / conversations / middleware.py: 21%

168 statements  

coverage.py v7.12.0, created at 2025-12-15 10:28 -0700

1"""Middleware system for conversation processing. 

2 

3This module provides middleware capabilities for processing messages before they 

4are sent to the LLM and processing responses after they come back from the LLM. 

5Middleware can be used for logging, validation, content filtering, rate limiting, 

6metadata injection, and more. 

7 

8Execution Model (Onion Pattern): 

9 Middleware wraps around LLM calls in an "onion" pattern: 

10 

11 Request Flow: MW0 → MW1 → MW2 → LLM 

12 Response Flow: LLM → MW2 → MW1 → MW0 

13 

14 Example with 3 middleware [Logging, RateLimit, Validation]: 

15 

16 ``` 

17 1. Logging.process_request() # Log incoming messages 

18 2. RateLimit.process_request() # Check rate limits 

19 3. Validation.process_request() # Validate request 

20 4. → LLM Call → # Actual LLM API call 

21 5. Validation.process_response() # Validate LLM response 

22 6. RateLimit.process_response() # Add rate limit info to response 

23 7. Logging.process_response() # Log response details 

24 ``` 

25 

26 This ensures middleware can: 

27 - Time the full LLM call (start timer in process_request, stop in process_response) 

28 - Wrap operations symmetrically (open resources → LLM → close resources) 

29 - See the final state after inner middleware modifications 

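    The ordering can be sketched without the library. The `run_pipeline()`
    driver and `Recorder` middleware below are hypothetical stand-ins (the
    real driver is ConversationManager, whose code is not shown here); each
    layer appends to a shared trace so the onion order is visible:

    ```python
    import asyncio
    from typing import Any, Awaitable, Callable, Dict, List


    async def run_pipeline(
        middleware: List[Any],
        messages: List[str],
        state: Dict[str, Any],
        call_llm: Callable[[List[str]], Awaitable[str]],
    ) -> str:
        # Hypothetical driver, not ConversationManager's actual code.
        for mw in middleware:  # request flow: MW0 -> MW1 -> MW2
            messages = await mw.process_request(messages, state)
        response = await call_llm(messages)  # innermost layer: the LLM call
        for mw in reversed(middleware):  # response flow: MW2 -> MW1 -> MW0
            response = await mw.process_response(response, state)
        return response


    class Recorder:
        # Toy middleware that records each call in a shared trace list.
        def __init__(self, name: str, trace: List[str]) -> None:
            self.name = name
            self.trace = trace

        async def process_request(self, messages, state):
            self.trace.append(f"{self.name}.request")
            return messages

        async def process_response(self, response, state):
            self.trace.append(f"{self.name}.response")
            return response


    trace: List[str] = []


    async def fake_llm(messages: List[str]) -> str:
        trace.append("LLM")
        return "ok"


    stack = [Recorder(name, trace) for name in ("Logging", "RateLimit", "Validation")]
    result = asyncio.run(run_pipeline(stack, ["Hello"], {}, fake_llm))
    print(trace)
    ```

    The trace lists the three requests in list order, then `LLM`, then the
    three responses in reverse order.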

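Performance Considerations:
    - **Middleware adds latency**: Each middleware's `process_request()` and
      `process_response()` adds to total response time. Keep middleware logic fast.

    - **Async is key**: All middleware methods are async. Use `await` for I/O
      operations (DB calls, network requests); a synchronous blocking call
      stalls the event loop for every conversation it is serving.

    - **Order matters**: Requests pass through middleware in list order, and a
      middleware that raises in `process_request()` stops the pipeline before
      any later middleware or the LLM call runs. Put cheap checks that may
      reject a request (such as RateLimitMiddleware) early in the list and
      middleware with expensive request-side work late, so rejected requests
      waste as little work as possible. Responses pass through in reverse, so
      ValidationMiddleware placed at the end of the list makes its extra LLM
      call on the raw LLM response, before any other middleware modifies it.

    - **Memory usage**: RateLimitMiddleware keeps request history in process
      memory, so limits are not shared between processes or hosts. For
      high-traffic or multi-process applications, consider external rate
      limiting (Redis, etc.).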

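    A self-contained sketch of that short-circuit, covering the request phase
    only. `run_requests()`, `AlwaysReject`, `Expensive`, and `Rejected` are
    hypothetical; `Rejected` stands in for an error such as RateLimitError:

    ```python
    import asyncio
    from typing import Any, List


    class Rejected(Exception):
        # Hypothetical stand-in for an error such as RateLimitError.
        pass


    class AlwaysReject:
        async def process_request(self, messages, state):
            raise Rejected("limit exceeded")


    class Expensive:
        def __init__(self) -> None:
            self.calls = 0

        async def process_request(self, messages, state):
            self.calls += 1  # stands in for costly request-side work
            return messages


    async def run_requests(middleware: List[Any], messages: List[str]) -> List[str]:
        # Request phase only: layers run in list order, so an exception from
        # one layer means no later layer (and no LLM call) runs.
        for mw in middleware:
            messages = await mw.process_request(messages, {})
        return messages


    def try_run(middleware: List[Any]) -> None:
        try:
            asyncio.run(run_requests(middleware, ["Hello"]))
        except Rejected:
            pass


    late = Expensive()
    try_run([AlwaysReject(), late])   # rejection happens first
    early = Expensive()
    try_run([early, AlwaysReject()])  # costly work runs, then is discarded
    print(late.calls, early.calls)    # 0 1
    ```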

Available Middleware:
    - **LoggingMiddleware**: Log requests and responses for debugging
    - **ContentFilterMiddleware**: Filter inappropriate content from responses
    - **ValidationMiddleware**: Validate responses with an additional LLM call
    - **MetadataMiddleware**: Inject custom metadata into messages/responses
    - **RateLimitMiddleware**: Enforce rate limits with a sliding window

Example:
    ```python
    import logging

    from dataknobs_llm.conversations import (
        ConversationManager,
        LoggingMiddleware,
        RateLimitMiddleware,
        ContentFilterMiddleware,
    )

    # Create middleware instances (order matters!)
    logger = logging.getLogger(__name__)
    logging_mw = LoggingMiddleware(logger)
    rate_limit_mw = RateLimitMiddleware(max_requests=10, window_seconds=60)
    filter_mw = ContentFilterMiddleware(
        filter_words=["inappropriate"],
        replacement="[FILTERED]"
    )

    # Create conversation with middleware stack. Run this inside an async
    # function; `llm`, `builder`, and `storage` are assumed to exist already.
    # Execution: Logging → RateLimit → Filter → LLM → Filter → RateLimit → Logging
    manager = await ConversationManager.create(
        llm=llm,
        prompt_builder=builder,
        storage=storage,
        middleware=[logging_mw, rate_limit_mw, filter_mw]
    )

    # All requests will go through the middleware pipeline
    await manager.add_message(role="user", content="Hello")
    response = await manager.complete()  # Middleware applied automatically
    ```

See Also:
    ConversationManager: Uses middleware for all LLM interactions
    ConversationMiddleware: Base class for custom middleware
"""

import logging
import re
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List

from dataknobs_llm.llm import LLMMessage, LLMResponse
from dataknobs_llm.llm.providers import AsyncLLMProvider
from dataknobs_llm.conversations.storage import ConversationState
from dataknobs_llm.prompts import AsyncPromptBuilder
from dataknobs_llm.exceptions import RateLimitError


class ConversationMiddleware(ABC):
    """Base class for conversation middleware.

    Middleware can process requests before the LLM call and responses after it.
    Middleware is executed in list order for requests, and in reverse order
    for responses (onion pattern).

    Execution Order:
        Given middleware list [MW0, MW1, MW2]:

        - **Request**: MW0 → MW1 → MW2 → LLM
        - **Response**: LLM → MW2 → MW1 → MW0

        This allows MW0 to:
        1. Start a timer in `process_request()`
        2. See the LLM call complete
        3. Stop the timer in `process_response()` and log the total time

    Use Cases:
        - **Logging**: Track request/response details
        - **Validation**: Verify request/response content
        - **Transformation**: Modify messages or responses
        - **Rate Limiting**: Enforce API usage limits
        - **Caching**: Store/retrieve responses
        - **Monitoring**: Collect metrics and analytics
        - **Security**: Filter sensitive information

    Example:
        ```python
        import time

        from dataknobs_llm.conversations import (
            ConversationManager,
            ConversationMiddleware,
        )

        class TimingMiddleware(ConversationMiddleware):
            '''Measure LLM call duration.'''

            async def process_request(self, messages, state):
                # Store start time (monotonic clock) in state metadata
                state.metadata["request_start"] = time.perf_counter()
                return messages

            async def process_response(self, response, state):
                # Calculate elapsed time
                start = state.metadata.get("request_start")
                if start is not None:
                    elapsed = time.perf_counter() - start
                    if not response.metadata:
                        response.metadata = {}
                    response.metadata["llm_duration_seconds"] = elapsed
                    print(f"LLM call took {elapsed:.2f}s")
                return response

        # Use in conversation (inside an async function)
        manager = await ConversationManager.create(
            llm=llm,
            middleware=[TimingMiddleware()]
        )
        ```

    Note:
        **Performance Tips**:

        - Keep `process_request()` and `process_response()` fast
        - Use async I/O (await) for external calls (DB, network)
        - Don't block the event loop with synchronous operations
        - For expensive operations, consider running them in background tasks
        - Store per-request data in `state.metadata`, not in instance
          variables: one middleware instance can serve several conversations
          concurrently, so data on `self` can be overwritten by another request

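        A sketch of two of these tips, assuming a hypothetical
        `AuditMiddleware`, a stand-in `Conversation` class in place of
        ConversationState, and a hypothetical blocking `blocking_audit_write()`
        call. The blocking call is moved off the event loop with the standard
        library's `asyncio.to_thread()` (Python 3.9+), and per-request data is
        kept on the state object rather than on `self`:

        ```python
        import asyncio
        import time
        from typing import Any, Dict, List


        def blocking_audit_write(record: Dict[str, Any]) -> str:
            # Hypothetical synchronous call, e.g. a legacy database client.
            time.sleep(0.01)
            return f"saved {record['conversation_id']}"


        class Conversation:
            # Minimal stand-in for ConversationState: an ID plus a metadata dict.
            def __init__(self, conversation_id: str) -> None:
                self.conversation_id = conversation_id
                self.metadata: Dict[str, Any] = {}


        class AuditMiddleware:
            async def process_request(self, messages: List[str], state: Conversation):
                # Per-request data goes on the state, not on self: this one
                # instance is shared by every conversation below.
                state.metadata["audit_sent"] = len(messages)
                return messages

            async def process_response(self, response: str, state: Conversation):
                # Run the blocking call in a worker thread so the event loop
                # can keep serving other conversations in the meantime.
                result = await asyncio.to_thread(
                    blocking_audit_write, {"conversation_id": state.conversation_id}
                )
                state.metadata["audit_result"] = result
                return response


        async def main() -> List[Conversation]:
            mw = AuditMiddleware()
            states = [Conversation("a"), Conversation("b")]

            async def one_turn(state: Conversation) -> str:
                await mw.process_request(["Hello"], state)
                return await mw.process_response("ok", state)

            await asyncio.gather(*(one_turn(s) for s in states))
            return states


        states = asyncio.run(main())
        print([s.metadata["audit_result"] for s in states])  # ['saved a', 'saved b']
        ```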

    See Also:
        LoggingMiddleware: Example implementation
        ConversationManager.complete: Where middleware is executed
    """

    @abstractmethod
    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Process messages before sending them to the LLM.

        Args:
            messages: Messages to send to the LLM
            state: Current conversation state

        Returns:
            Processed messages (can modify, add, or remove messages)

        Example:
            >>> from datetime import datetime
            >>> async def process_request(self, messages, state):
            ...     # Add timestamp to metadata
            ...     for msg in messages:
            ...         if not msg.metadata:
            ...             msg.metadata = {}
            ...         msg.metadata["timestamp"] = datetime.now().isoformat()
            ...     return messages
        """
        pass

    @abstractmethod
    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Process the response from the LLM.

        Args:
            response: LLM response
            state: Current conversation state

        Returns:
            Processed response (can modify content, metadata, etc.)

        Example:
            >>> from datetime import datetime
            >>> async def process_response(self, response, state):
            ...     # Add processing metadata
            ...     if not response.metadata:
            ...         response.metadata = {}
            ...     response.metadata["processed_at"] = datetime.now().isoformat()
            ...     return response
        """
        pass


class LoggingMiddleware(ConversationMiddleware):
    """Middleware that logs all requests and responses.

    This middleware is useful for debugging and monitoring conversations.
    It logs message counts, conversation IDs, and response details (length,
    model, finish reason) at INFO level, and message roles and token usage
    at DEBUG level.

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> logging.basicConfig(level=logging.INFO)
        >>>
        >>> middleware = LoggingMiddleware(logger)
        >>> manager = await ConversationManager.create(
        ...     llm=llm,
        ...     prompt_builder=builder,
        ...     storage=storage,
        ...     middleware=[middleware]
        ... )
    """

    def __init__(self, logger: logging.Logger | None = None):
        """Initialize logging middleware.

        Args:
            logger: Logger instance to use (defaults to this module's logger)
        """
        self.logger = logger or logging.getLogger(__name__)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Log request details before sending them to the LLM."""
        self.logger.info(
            "Conversation %s - Sending %d messages to LLM",
            state.conversation_id, len(messages)
        )
        # Only build the role list when DEBUG output is actually enabled
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(
                "Conversation %s - Message roles: %s",
                state.conversation_id, [msg.role for msg in messages]
            )
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Log response details after receiving them from the LLM."""
        content_length = len(response.content) if response.content else 0
        self.logger.info(
            "Conversation %s - Received response: %d chars, "
            "model=%s, finish_reason=%s",
            state.conversation_id, content_length,
            response.model, response.finish_reason
        )
        if response.usage:
            self.logger.debug(
                "Conversation %s - Token usage: %s",
                state.conversation_id, response.usage
            )
        return response


class ContentFilterMiddleware(ConversationMiddleware):
    """Middleware that filters inappropriate content from responses.

    This middleware redacts or replaces specific words or phrases in LLM
    responses, which is useful for content moderation and compliance. Each
    filter word is matched literally (not as a regular expression) and as a
    substring, so "bad" also matches inside "badge". Requests pass through
    unchanged.

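    A standalone sketch of this matching behavior, using a hypothetical
    `filter_text()` helper rather than the class itself. Its default path
    matches the way the class does (literal, substring, optionally
    case-insensitive); the `whole_words` flag is a hypothetical extension that
    uses regex lookarounds to skip matches embedded in longer words:

    ```python
    import re
    from typing import List


    def filter_text(
        text: str,
        words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True,
        whole_words: bool = False,  # hypothetical option, not a class parameter
    ) -> str:
        flags = 0 if case_sensitive else re.IGNORECASE
        for word in words:
            pattern = re.escape(word)  # match the phrase literally
            if whole_words:
                # Skip matches glued to letters, digits, or underscores.
                pattern = f"(?<![A-Za-z0-9_]){pattern}(?![A-Za-z0-9_])"
            # The lambda inserts the replacement verbatim (no group references).
            text = re.sub(pattern, lambda _m: replacement, text, flags=flags)
        return text


    sample = "Bad idea: the badge says BAD."
    print(filter_text(sample, ["bad"]))
    # Bad idea: the [FILTERED]ge says BAD.
    print(filter_text(sample, ["bad"], case_sensitive=False))
    # [FILTERED] idea: the [FILTERED]ge says [FILTERED].
    print(filter_text(sample, ["bad"], case_sensitive=False, whole_words=True))
    # [FILTERED] idea: the badge says [FILTERED].
    ```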

    Example:
        >>> # Filter specific words
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["badword1", "badword2"],
        ...     replacement="[FILTERED]"
        ... )
        >>>
        >>> # Case-insensitive filtering
        >>> middleware = ContentFilterMiddleware(
        ...     filter_words=["sensitive"],
        ...     case_sensitive=False
        ... )
    """

    def __init__(
        self,
        filter_words: List[str],
        replacement: str = "[FILTERED]",
        case_sensitive: bool = True
    ):
        """Initialize content filter middleware.

        Args:
            filter_words: List of words/phrases to filter (matched literally)
            replacement: String to replace filtered content with
            case_sensitive: Whether filtering should be case-sensitive
        """
        self.filter_words = filter_words
        self.replacement = replacement
        self.case_sensitive = case_sensitive

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without filtering."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Filter inappropriate content from the response."""
        if not response.content:
            # Nothing to filter (empty or missing content)
            return response

        content = response.content

        for word in self.filter_words:
            if self.case_sensitive:
                content = content.replace(word, self.replacement)
            else:
                # Case-insensitive replacement; the lambda inserts the
                # replacement verbatim, so backslashes in it are not
                # treated as regex escapes
                pattern = re.compile(re.escape(word), re.IGNORECASE)
                content = pattern.sub(lambda _match: self.replacement, content)

        # Track if any filtering occurred
        if content != response.content:
            if not response.metadata:
                response.metadata = {}
            response.metadata["content_filtered"] = True
            response.content = content

        return response


class ValidationMiddleware(ConversationMiddleware):
    """Middleware that validates LLM responses using another LLM call.

    This middleware renders a validation prompt with the response content and
    sends it to a separate LLM call to check whether the response meets
    certain criteria. On failure it raises ValueError by default; with
    `auto_retry=True` it instead marks the response with `validation_failed`
    and `retry_requested` metadata. It does not re-send the request itself,
    so any retry has to be driven by the caller.

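    A caller-side retry loop could look like the following sketch. Everything
    in it is hypothetical: `complete_with_retry()` is not part of this module
    or of ConversationManager, and plain dicts stand in for LLMResponse
    objects. It calls the completion again while `retry_requested` is set, for
    at most `retry_limit` extra attempts:

    ```python
    import asyncio
    from typing import Any, Awaitable, Callable, Dict


    async def complete_with_retry(
        complete: Callable[[], Awaitable[Dict[str, Any]]],
        retry_limit: int = 3,
    ) -> Dict[str, Any]:
        # Re-run the completion while the middleware has flagged the response,
        # giving up after retry_limit extra attempts.
        response = await complete()
        retries = 0
        while response.get("metadata", {}).get("retry_requested") and retries < retry_limit:
            retries += 1
            response = await complete()
        return response


    attempts = {"n": 0}


    async def fake_complete() -> Dict[str, Any]:
        # First two attempts "fail validation", the third passes.
        attempts["n"] += 1
        failed = attempts["n"] < 3
        metadata = {"validation_failed": True, "retry_requested": True} if failed else {}
        return {"content": f"answer #{attempts['n']}", "metadata": metadata}


    final = asyncio.run(complete_with_retry(fake_complete, retry_limit=3))
    print(final["content"], attempts["n"])  # answer #3 3
    ```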

    Example:
        >>> from dataknobs_llm.llm.providers import OpenAIProvider
        >>> from dataknobs_llm.llm.base import LLMConfig
        >>>
        >>> # Create validation middleware
        >>> config = LLMConfig(provider="openai", model="gpt-4")
        >>> validation_llm = OpenAIProvider(config)
        >>> middleware = ValidationMiddleware(
        ...     llm=validation_llm,
        ...     prompt_builder=builder,
        ...     validation_prompt="validate_response",
        ...     auto_retry=False  # Raise an error instead of flagging a retry
        ... )
        >>>
        >>> # The validation prompt receives the response text as its
        >>> # `response` parameter and should ask the LLM to reply with
        >>> # "VALID" or "INVALID"
    """

    def __init__(
        self,
        llm: AsyncLLMProvider,
        prompt_builder: AsyncPromptBuilder,
        validation_prompt: str,
        auto_retry: bool = False,
        retry_limit: int = 3
    ):
        """Initialize validation middleware.

        Args:
            llm: LLM provider to use for validation (required)
            prompt_builder: Prompt builder for rendering the validation prompt
            validation_prompt: Name of the validation prompt template
            auto_retry: If True, flag failed responses with `retry_requested`
                instead of raising ValueError
            retry_limit: Maximum number of retries a caller should attempt
                when auto_retry is True (stored for callers; this class does
                not read it)
        """
        self.llm: AsyncLLMProvider = llm
        self.builder: AsyncPromptBuilder = prompt_builder
        self.validation_prompt = validation_prompt
        self.auto_retry = auto_retry
        self.retry_limit = retry_limit

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Pass through requests without validation."""
        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Validate the response by calling the LLM with the validation prompt.

        Raises:
            ValueError: If validation fails and auto_retry is False
        """
        # Render validation prompt with response content
        validation_prompt_result = await self.builder.render_user_prompt(
            self.validation_prompt,
            index=0,
            params={"response": response.content},
            include_rag=False  # Don't need RAG for validation
        )

        # Create message and call LLM to get validation judgment
        validation_message = LLMMessage(
            role="user",
            content=validation_prompt_result.content
        )
        validation_response = await self.llm.complete([validation_message])

        # Check if LLM says response is valid
        is_valid = self._check_validity(validation_response.content)

        if not is_valid:
            # Track validation failure
            if not response.metadata:
                response.metadata = {}
            response.metadata["validation_failed"] = True
            response.metadata["validation_response"] = validation_response.content

            if self.auto_retry:
                # Note: Actual retry logic would need to be implemented
                # at the ConversationManager level. This just marks the failure.
                response.metadata["retry_requested"] = True
            else:
                raise ValueError(
                    f"Response failed validation: {validation_response.content}"
                )

        return response

    def _check_validity(self, validation_response: str) -> bool:
        """Check if validation response indicates success.

        A plain `"VALID" in text` test is not enough, because "INVALID" also
        contains "VALID"; an explicit "INVALID" verdict is ruled out first.

        Args:
            validation_response: Content from validation prompt response

        Returns:
            True if valid, False otherwise

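        A stricter alternative, sketched as standalone hypothetical functions:
        assuming the validation prompt tells the model to begin its reply with
        the verdict, reading only the first word avoids substring collisions
        altogether. `naive_check()` reproduces a bare substring test for
        comparison:

        ```python
        def naive_check(text: str) -> bool:
            # A bare substring test, for comparison.
            return "VALID" in text.upper()


        def check_verdict(text: str) -> bool:
            # Read only the first word, so "INVALID" can never count as "VALID".
            words = (text or "").strip().upper().split()
            first = words[0].strip(".,:;!") if words else ""
            return first == "VALID"


        reply = "INVALID: the answer cites no source"
        print(naive_check(reply))       # True (wrong)
        print(check_verdict(reply))     # False
        print(check_verdict("Valid."))  # True
        print(check_verdict(""))        # False
        ```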
467 True if valid, False otherwise 

468 """ 

469 # Simple implementation: look for "VALID" in response 

470 # This can be customized based on validation prompt design 

471 return "VALID" in validation_response.upper() 

472 

473 

474class MetadataMiddleware(ConversationMiddleware): 

475 """Middleware that adds custom metadata to messages and responses. 

476 

477 This middleware can inject metadata into both requests and responses, 

478 which is useful for tracking, analytics, and debugging. 

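    The merge order can be sketched with plain dictionaries.
    `build_metadata()` is a hypothetical helper that mirrors what
    `process_request()` and `process_response()` do before writing to
    `metadata`:

    ```python
    from datetime import datetime, timezone
    from typing import Any, Callable, Dict, Optional


    def build_metadata(
        static: Dict[str, Any],
        dynamic_fn: Optional[Callable[[], Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        # Same merge order as MetadataMiddleware: copy the static dict, then
        # let the callable's result overwrite any keys the two share.
        merged = dict(static)
        if dynamic_fn:
            merged.update(dynamic_fn())
        return merged


    def dynamic() -> Dict[str, Any]:
        return {
            "version": "1.1.0-canary",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }


    static = {"environment": "production", "version": "1.0.0"}
    merged = build_metadata(static, dynamic)
    print(merged["environment"], merged["version"])  # production 1.1.0-canary
    print(static["version"])  # 1.0.0 (the static dict itself is not modified)
    ```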

    Example:
        >>> from datetime import datetime
        >>>
        >>> # Add environment info to all messages
        >>> middleware = MetadataMiddleware(
        ...     request_metadata={"environment": "production"},
        ...     response_metadata={"version": "1.0.0"}
        ... )
        >>>
        >>> # Add dynamic metadata via callback
        >>> def get_request_meta():
        ...     return {"timestamp": datetime.now().isoformat()}
        >>>
        >>> middleware = MetadataMiddleware(
        ...     request_metadata_fn=get_request_meta
        ... )
    """

    def __init__(
        self,
        request_metadata: Dict[str, Any] | None = None,
        response_metadata: Dict[str, Any] | None = None,
        request_metadata_fn: Callable[[], Dict[str, Any]] | None = None,
        response_metadata_fn: Callable[[], Dict[str, Any]] | None = None
    ):
        """Initialize metadata middleware.

        Args:
            request_metadata: Static metadata to add to requests
            response_metadata: Static metadata to add to responses
            request_metadata_fn: Zero-argument callable that returns metadata
                for requests
            response_metadata_fn: Zero-argument callable that returns metadata
                for responses
        """
        self.request_metadata = request_metadata or {}
        self.response_metadata = response_metadata or {}
        self.request_metadata_fn = request_metadata_fn
        self.response_metadata_fn = response_metadata_fn

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Add metadata to request messages."""
        # Collect metadata to add
        metadata_to_add = dict(self.request_metadata)

        # Add dynamic metadata if function provided
        if self.request_metadata_fn:
            dynamic_metadata = self.request_metadata_fn()
            metadata_to_add.update(dynamic_metadata)

        # Add metadata to each message
        if metadata_to_add:
            for msg in messages:
                if not msg.metadata:
                    msg.metadata = {}
                msg.metadata.update(metadata_to_add)

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Add metadata to response."""
        # Collect metadata to add
        metadata_to_add = dict(self.response_metadata)

        # Add dynamic metadata if function provided
        if self.response_metadata_fn:
            dynamic_metadata = self.response_metadata_fn()
            metadata_to_add.update(dynamic_metadata)

        # Add metadata to response
        if metadata_to_add:
            if not response.metadata:
                response.metadata = {}
            response.metadata.update(metadata_to_add)

        return response


class RateLimitMiddleware(ConversationMiddleware):
    """Middleware that enforces rate limiting on LLM requests.

    This middleware tracks request rates per conversation, per client, or per
    custom key, and raises RateLimitError when the limit is exceeded. Rate
    limits are tracked in memory with a sliding-window log: each key keeps the
    timestamps of its accepted requests, timestamps older than
    `window_seconds` are discarded, and a new request is allowed only while
    fewer than `max_requests` remain.

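    A standalone sketch of the sliding-window log, with explicit timestamps so
    the arithmetic is easy to follow. `SlidingWindowLog` is hypothetical and
    follows the same rules as this class: drop timestamps at or before
    `now - window_seconds`, reject without recording once the count has
    reached `max_requests`, and report `next_reset` as
    `oldest + window_seconds - now`:

    ```python
    import time
    from typing import Dict, List, Optional, Tuple


    class SlidingWindowLog:
        # Per-key list of request timestamps. A request is allowed while fewer
        # than max_requests timestamps fall inside the last window_seconds.
        def __init__(self, max_requests: int, window_seconds: float) -> None:
            self.max_requests = max_requests
            self.window_seconds = window_seconds
            self.history: Dict[str, List[float]] = {}

        def _recent(self, key: str, now: float) -> List[float]:
            # Drop timestamps that have slid out of the window (keep ts > cutoff).
            cutoff = now - self.window_seconds
            self.history[key] = [ts for ts in self.history.get(key, []) if ts > cutoff]
            return self.history[key]

        def try_acquire(self, key: str, now: Optional[float] = None) -> Tuple[bool, int]:
            now = time.time() if now is None else now
            recent = self._recent(key, now)
            if len(recent) >= self.max_requests:
                return False, len(recent)  # rejected; nothing is recorded
            recent.append(now)
            return True, len(recent)

        def next_reset(self, key: str, now: float) -> float:
            # Seconds until the oldest timestamp leaves the window.
            recent = self._recent(key, now)
            return max(0.0, min(recent) + self.window_seconds - now) if recent else 0.0


    limiter = SlidingWindowLog(max_requests=2, window_seconds=60)
    results = [
        limiter.try_acquire("client-abc", now=100.0),  # (True, 1)
        limiter.try_acquire("client-abc", now=130.0),  # (True, 2)
        limiter.try_acquire("client-abc", now=150.0),  # (False, 2)
    ]
    wait = limiter.next_reset("client-abc", now=150.0)  # 10.0 = 100 + 60 - 150
    results.append(limiter.try_acquire("client-abc", now=161.0))  # (True, 2): 100 expired
    print(results, wait)
    ```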

    Example:
        >>> # Limit to 10 requests per minute
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=10,
        ...     window_seconds=60
        ... )
        >>>
        >>> # Per-client rate limiting
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=100,
        ...     window_seconds=3600,
        ...     scope="client_id"  # Key on state.metadata["client_id"]
        ... )
        >>>
        >>> # With custom key function; fall back to the conversation ID so
        >>> # states without a user_id don't all share one bucket
        >>> def get_user_id(state):
        ...     return state.metadata.get("user_id", state.conversation_id)
        >>>
        >>> middleware = RateLimitMiddleware(
        ...     max_requests=50,
        ...     window_seconds=60,
        ...     key_fn=get_user_id
        ... )
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: int = 60,
        scope: str = "conversation",  # "conversation" or "client_id"
        key_fn: Callable[[ConversationState], str] | None = None
    ):
        """Initialize rate limiting middleware.

        Args:
            max_requests: Maximum number of requests allowed in the window
            window_seconds: Time window in seconds for rate limiting
            scope: Scope for rate limiting ("conversation" or "client_id").
                "client_id" reads `state.metadata["client_id"]` and falls
                back to the conversation ID when it is missing.
            key_fn: Optional custom function to extract the rate limit key
                from state; when given, it takes precedence over `scope`
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.scope = scope
        self.key_fn = key_fn

        # In-memory storage: key -> list of request timestamps
        self._request_history: Dict[str, List[float]] = {}

    def _get_rate_limit_key(self, state: ConversationState) -> str:
        """Get the key to use for rate limiting.

        Args:
            state: Conversation state

        Returns:
            Rate limit key
        """
        if self.key_fn:
            return self.key_fn(state)
        elif self.scope == "client_id":
            return state.metadata.get("client_id", state.conversation_id)
        else:
            return state.conversation_id

    def _clean_old_requests(self, key: str, current_time: float) -> None:
        """Remove requests outside the time window.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        if key in self._request_history:
            cutoff_time = current_time - self.window_seconds
            self._request_history[key] = [
                ts for ts in self._request_history[key]
                if ts > cutoff_time
            ]

    def _check_rate_limit(self, key: str, current_time: float) -> tuple[bool, int]:
        """Check if request is within rate limit.

        Args:
            key: Rate limit key
            current_time: Current timestamp

        Returns:
            Tuple of (is_allowed, current_count)
        """
        # Clean old requests
        self._clean_old_requests(key, current_time)

        # Check current count
        if key not in self._request_history:
            self._request_history[key] = []

        current_count = len(self._request_history[key])
        is_allowed = current_count < self.max_requests

        return is_allowed, current_count

    def _record_request(self, key: str, current_time: float) -> None:
        """Record a new request.

        Args:
            key: Rate limit key
            current_time: Current timestamp
        """
        if key not in self._request_history:
            self._request_history[key] = []

        self._request_history[key].append(current_time)

    async def process_request(
        self,
        messages: List[LLMMessage],
        state: ConversationState
    ) -> List[LLMMessage]:
        """Check rate limit before allowing request through.

        Raises:
            RateLimitError: If the key has already made `max_requests`
                requests within the current window
        """
        current_time = time.time()
        key = self._get_rate_limit_key(state)

        # Check rate limit
        is_allowed, current_count = self._check_rate_limit(key, current_time)

        if not is_allowed:
            # Add rate limit info to state metadata for debugging
            if not state.metadata:
                state.metadata = {}
            state.metadata["rate_limit_exceeded"] = True
            state.metadata["rate_limit_count"] = current_count
            state.metadata["rate_limit_max"] = self.max_requests
            state.metadata["rate_limit_window"] = self.window_seconds

            raise RateLimitError(
                f"Rate limit exceeded: {current_count}/{self.max_requests} "
                f"requests in {self.window_seconds}s window"
            )

        # Record this request
        self._record_request(key, current_time)

        # Add rate limit info to messages metadata
        for msg in messages:
            if not msg.metadata:
                msg.metadata = {}
            msg.metadata["rate_limit_count"] = current_count + 1
            msg.metadata["rate_limit_max"] = self.max_requests

        return messages

    async def process_response(
        self,
        response: LLMResponse,
        state: ConversationState
    ) -> LLMResponse:
        """Add rate limit info to response metadata."""
        key = self._get_rate_limit_key(state)

        if key in self._request_history:
            current_count = len(self._request_history[key])

            if not response.metadata:
                response.metadata = {}

            response.metadata["rate_limit_count"] = current_count
            response.metadata["rate_limit_max"] = self.max_requests
            response.metadata["rate_limit_remaining"] = max(
                0, self.max_requests - current_count
            )

        return response

    def get_rate_limit_status(self, key: str) -> Dict[str, Any]:
        """Get current rate limit status for a key.

        Args:
            key: Rate limit key

        Returns:
            Dictionary with `current_count`, `max_requests`, `remaining`,
            `window_seconds`, and `next_reset` (seconds until the oldest
            request in the window expires, or 0 if there are none)

        Example:
            >>> status = middleware.get_rate_limit_status("client-abc")
            >>> # status is a dict shaped like:
            >>> # {'current_count': 5, 'max_requests': 10, 'remaining': 5,
            >>> #  'window_seconds': 60, 'next_reset': 45.2}
        """
        current_time = time.time()
        self._clean_old_requests(key, current_time)

        if key not in self._request_history or not self._request_history[key]:
            return {
                'current_count': 0,
                'max_requests': self.max_requests,
                'remaining': self.max_requests,
                'window_seconds': self.window_seconds,
                'next_reset': 0
            }

        current_count = len(self._request_history[key])
        oldest_request = min(self._request_history[key])
        next_reset = max(0, (oldest_request + self.window_seconds) - current_time)

        return {
            'current_count': current_count,
            'max_requests': self.max_requests,
            'remaining': max(0, self.max_requests - current_count),
            'window_seconds': self.window_seconds,
            'next_reset': next_reset
        }

    def reset(self, key: str | None = None) -> None:
        """Reset rate limit for a specific key or all keys.

        Args:
            key: Key to reset. If None, resets all keys.

        Example:
            >>> # Reset specific client
            >>> middleware.reset("client-abc")
            >>>
            >>> # Reset all
            >>> middleware.reset()
        """
        if key is None:
            self._request_history.clear()
        elif key in self._request_history:
            del self._request_history[key]