Coverage for src / utils / rate_limiter.py: 95%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 20:29 +0800

1"""异步令牌桶限流器模块. 

2 

3提供基于令牌桶算法的异步速率限制功能, 

4支持全局单例模式和异步上下文管理器。 

5""" 

6 

7import asyncio 

8import time 

9from dataclasses import dataclass 

10from typing import Any, Optional 

11 

12 

13@dataclass 

14class RateLimitConfig: 

15 """限流配置数据类. 

16 

17 Attributes: 

18 requests: 时间窗口内允许的最大请求数 

19 window_seconds: 时间窗口秒数 

20 """ 

21 

22 requests: int = 100 

23 window_seconds: int = 60 

24 

25 

26class TokenBucketRateLimiter: 

27 """异步令牌桶限流器. 

28 

29 使用令牌桶算法实现速率限制,支持异步上下文管理器。 

30 全局单例模式确保整个应用共享同一限流器实例。 

31 

32 Example: 

33 >>> limiter = TokenBucketRateLimiter() 

34 >>> async with limiter: 

35 ... await make_api_call() 

36 

37 >>> # 批量获取令牌 

38 >>> await limiter.acquire(tokens=5) 

39 """ 

40 

41 _instance: Optional["TokenBucketRateLimiter"] = None 

42 _lock: asyncio.Lock = asyncio.Lock() 

43 

44 def __new__(cls, config: RateLimitConfig | None = None) -> "TokenBucketRateLimiter": 

45 if cls._instance is None: 

46 cls._instance = super().__new__(cls) 

47 return cls._instance 

48 

49 def __init__(self, config: RateLimitConfig | None = None) -> None: 

50 """初始化限流器. 

51 

52 Args: 

53 config: 限流配置,使用默认配置 if None 

54 """ 

55 # 避免重复初始化 

56 if hasattr(self, "_initialized"): 

57 return 

58 

59 self.config = config or RateLimitConfig() 

60 self.tokens: float = float(self.config.requests) 

61 self.last_update: float = time.monotonic() 

62 self._mutex: asyncio.Lock = asyncio.Lock() 

63 self._initialized: bool = True 

64 

65 async def acquire(self, tokens: int = 1) -> None: 

66 """获取指定数量的令牌. 

67 

68 如果令牌不足,会异步等待直到令牌补充足够。 

69 

70 Args: 

71 tokens: 需要获取的令牌数量,默认为 1 

72 

73 Raises: 

74 ValueError: tokens 小于等于 0 

75 """ 

76 if tokens <= 0: 

77 raise ValueError("tokens must be positive") 

78 

79 async with self._mutex: 

80 now = time.monotonic() 

81 elapsed = now - self.last_update 

82 

83 # 补充令牌 

84 self.tokens = min( 

85 self.config.requests, 

86 self.tokens + elapsed * (self.config.requests / self.config.window_seconds), 

87 ) 

88 self.last_update = now 

89 

90 if self.tokens < tokens: 

91 # 计算需要等待的时间 

92 wait_time = (tokens - self.tokens) * ( 

93 self.config.window_seconds / self.config.requests 

94 ) 

95 await asyncio.sleep(wait_time) 

96 self.tokens = 0 

97 else: 

98 self.tokens -= tokens 

99 

100 async def __aenter__(self) -> "TokenBucketRateLimiter": 

101 """异步上下文管理器入口.""" 

102 await self.acquire() 

103 return self 

104 

105 async def __aexit__( 

106 self, 

107 exc_type: type | None, 

108 exc_val: BaseException | None, 

109 exc_tb: object | None, 

110 ) -> None: 

111 """异步上下文管理器出口.""" 

112 pass 

113 

114 def get_status(self) -> dict[str, Any]: 

115 """获取限流器当前状态. 

116 

117 Returns: 

118 包含当前令牌数和配置信息的字典 

119 """ 

120 return { 

121 "tokens": self.tokens, 

122 "max_tokens": self.config.requests, 

123 "window_seconds": self.config.window_seconds, 

124 "available": self.tokens / self.config.requests, 

125 }