Coverage for src / utils / rate_limiter.py: 95%
43 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 20:29 +0800
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 20:29 +0800
1"""异步令牌桶限流器模块.
3提供基于令牌桶算法的异步速率限制功能,
4支持全局单例模式和异步上下文管理器。
5"""
7import asyncio
8import time
9from dataclasses import dataclass
10from typing import Any, Optional
13@dataclass
14class RateLimitConfig:
15 """限流配置数据类.
17 Attributes:
18 requests: 时间窗口内允许的最大请求数
19 window_seconds: 时间窗口秒数
20 """
22 requests: int = 100
23 window_seconds: int = 60
26class TokenBucketRateLimiter:
27 """异步令牌桶限流器.
29 使用令牌桶算法实现速率限制,支持异步上下文管理器。
30 全局单例模式确保整个应用共享同一限流器实例。
32 Example:
33 >>> limiter = TokenBucketRateLimiter()
34 >>> async with limiter:
35 ... await make_api_call()
37 >>> # 批量获取令牌
38 >>> await limiter.acquire(tokens=5)
39 """
41 _instance: Optional["TokenBucketRateLimiter"] = None
42 _lock: asyncio.Lock = asyncio.Lock()
44 def __new__(cls, config: RateLimitConfig | None = None) -> "TokenBucketRateLimiter":
45 if cls._instance is None:
46 cls._instance = super().__new__(cls)
47 return cls._instance
49 def __init__(self, config: RateLimitConfig | None = None) -> None:
50 """初始化限流器.
52 Args:
53 config: 限流配置,使用默认配置 if None
54 """
55 # 避免重复初始化
56 if hasattr(self, "_initialized"):
57 return
59 self.config = config or RateLimitConfig()
60 self.tokens: float = float(self.config.requests)
61 self.last_update: float = time.monotonic()
62 self._mutex: asyncio.Lock = asyncio.Lock()
63 self._initialized: bool = True
65 async def acquire(self, tokens: int = 1) -> None:
66 """获取指定数量的令牌.
68 如果令牌不足,会异步等待直到令牌补充足够。
70 Args:
71 tokens: 需要获取的令牌数量,默认为 1
73 Raises:
74 ValueError: tokens 小于等于 0
75 """
76 if tokens <= 0:
77 raise ValueError("tokens must be positive")
79 async with self._mutex:
80 now = time.monotonic()
81 elapsed = now - self.last_update
83 # 补充令牌
84 self.tokens = min(
85 self.config.requests,
86 self.tokens + elapsed * (self.config.requests / self.config.window_seconds),
87 )
88 self.last_update = now
90 if self.tokens < tokens:
91 # 计算需要等待的时间
92 wait_time = (tokens - self.tokens) * (
93 self.config.window_seconds / self.config.requests
94 )
95 await asyncio.sleep(wait_time)
96 self.tokens = 0
97 else:
98 self.tokens -= tokens
100 async def __aenter__(self) -> "TokenBucketRateLimiter":
101 """异步上下文管理器入口."""
102 await self.acquire()
103 return self
105 async def __aexit__(
106 self,
107 exc_type: type | None,
108 exc_val: BaseException | None,
109 exc_tb: object | None,
110 ) -> None:
111 """异步上下文管理器出口."""
112 pass
114 def get_status(self) -> dict[str, Any]:
115 """获取限流器当前状态.
117 Returns:
118 包含当前令牌数和配置信息的字典
119 """
120 return {
121 "tokens": self.tokens,
122 "max_tokens": self.config.requests,
123 "window_seconds": self.config.window_seconds,
124 "available": self.tokens / self.config.requests,
125 }