Coverage for src / beautyspot / cache.py: 89%
207 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-03-11 19:10 +0900
« prev ^ index » next coverage.py v7.13.2, created at 2026-03-11 19:10 +0900
1# src/beautyspot/cache.py
3import hashlib
4import logging
5import threading
6import time
7import asyncio
8import contextlib
9from datetime import datetime, timezone
10from typing import Any, Callable, Optional, NamedTuple, Generator, AsyncGenerator
12from beautyspot.db import TaskDBMaintenable
13from beautyspot.storage import BlobStorageMaintenable, StoragePolicyProtocol
14from beautyspot.serializer import SerializerProtocol
15from beautyspot.lifecycle import (
16 LifecyclePolicy,
17 RetentionSpec,
18 parse_retention,
19 _ForeverSentinel,
20 _FOREVER,
21)
22from beautyspot.cachekey import KeyGen
23from beautyspot.exceptions import CacheCorruptedError
24from beautyspot.content_types import ContentType
# Module-level logger. A NullHandler is attached so the logger always has a
# handler of its own, even when the host application configures none.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# --- Sentinel object representing a cache miss ---
# CacheManager.get returns this exact object when no usable entry exists;
# callers compare against it with `is` / `is not`.
CACHE_MISS = object()
class HerdWaitResult(NamedTuple):
    """Outcome of the thundering-herd wait phase for a single caller."""

    # True when this caller became the executor for the key.
    is_executor: bool
    # When is_executor is False: the value obtained from the executor (or from
    # the cache), or the exception that was raised.
    result: Any
    # When is_executor is True: the event registered for this execution;
    # None for waiters.
    event: Optional[threading.Event]
    # When is_executor is True: the list shared with waiters, who read its
    # first element as a (success, value) pair; an empty list for waiters.
    result_box: list
    # True when `result` holds an exception instead of a value.
    is_error: bool
class CacheManager:
    """
    Component responsible for reading and writing cache entries, generating
    cache keys, and controlling concurrent execution of the same key
    (thundering-herd protection).

    Class attributes:
        HERD_POLL: Seconds a blocked waiter sleeps on the executor's event
            between deadline checks.
        HERD_TIMEOUT: Seconds a waiter waits in one round before the round
            counts as a timed-out retry.
        HERD_MAX_RETRIES: Number of timed-out rounds tolerated; one more
            raises ``TimeoutError``.
    """

    HERD_POLL: float = 5.0
    HERD_TIMEOUT: float = 300.0
    HERD_MAX_RETRIES: int = 3

    # Private marker delivered to async waiters when the executor released the
    # key without publishing any outcome. Waiters compare against it with `is`.
    _HERD_ABANDONED = object()

    def __init__(
        self,
        db: TaskDBMaintenable,
        storage: BlobStorageMaintenable,
        serializer: SerializerProtocol,
        storage_policy: StoragePolicyProtocol,
        lifecycle_policy: Optional[LifecyclePolicy] = None,
    ):
        """Store the collaborators and initialise the in-flight registry.

        Args:
            db: Metadata store; this class calls ``get`` and ``save`` on it.
            storage: Blob store; this class calls ``save``, ``load`` and
                ``delete`` on it.
            serializer: Default serializer (``dumps`` / ``loads``).
            storage_policy: Decides via ``should_save_as_blob`` whether a
                payload goes to blob storage when the caller does not say.
            lifecycle_policy: Retention policy; ``LifecyclePolicy.default()``
                is used when omitted.
        """
        self.db = db
        self.storage = storage
        self.serializer = serializer
        self.storage_policy = storage_policy

        if lifecycle_policy is not None:
            self.lifecycle_policy = lifecycle_policy
        else:
            self.lifecycle_policy = LifecyclePolicy.default()

        # Thundering-herd protection: serialise concurrent executions of the
        # same key. Each value is
        # (executor's Event, futures of async waiters, shared result list).
        self._inflight: dict[
            str, tuple[threading.Event, list[asyncio.Future], list]
        ] = {}
        self._inflight_lock = threading.Lock()

    def make_cache_key(
        self,
        func_identifier: str,
        args: tuple,
        kwargs: dict,
        resolved_key_fn: Optional[Callable],
        version: str | None,
    ) -> tuple[str, str]:
        """Generate the input ID and the cache key for one call.

        Args:
            func_identifier: Identifier of the cached function.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.
            resolved_key_fn: Optional callable producing the input ID from
                ``*args, **kwargs``; ``KeyGen._default`` is used when falsy.
            version: Optional version tag; appended to the key source only
                when truthy.

        Returns:
            ``(input_id, cache_key)`` in that order, where ``cache_key`` is
            the SHA-256 hex digest of
            ``"{func_identifier}:{input_id}[:{version}]"``.
        """
        if resolved_key_fn:
            iid = resolved_key_fn(*args, **kwargs)
        else:
            iid = KeyGen._default(args, kwargs)

        key_source = f"{func_identifier}:{iid}"
        if version:
            key_source += f":{version}"

        ck = hashlib.sha256(key_source.encode()).hexdigest()
        return iid, ck

    def calculate_expires_at(
        self,
        func_identifier: str,
        func_name: str,
        local_retention: RetentionSpec,
    ) -> Optional[datetime]:
        """Compute the expiry timestamp for a new cache entry.

        Args:
            func_identifier: Passed to the lifecycle policy for lookup.
            func_name: Passed to the lifecycle policy for lookup.
            local_retention: Retention requested at the call site.

        Returns:
            A timezone-aware UTC ``datetime``, or ``None`` when the entry
            never expires.

        Raises:
            RuntimeError: If a ``_ForeverSentinel`` other than ``_FOREVER``
                reaches this method.
        """
        if local_retention is _FOREVER:
            return None

        # Any other sentinel instance is unexpected here; fail loudly rather
        # than hand it to parse_retention.
        if isinstance(local_retention, _ForeverSentinel):
            raise RuntimeError(
                "Internal Error: _ForeverSentinel reached calculate_expires_at."
            )

        retention = parse_retention(local_retention)

        # No call-site retention: fall back to the lifecycle policy.
        if retention is None:
            retention = self.lifecycle_policy.resolve_with_fallback(
                func_identifier, func_name
            )

        if retention is None:
            return None

        return datetime.now(timezone.utc) + retention

    def get(
        self, cache_key: str, serializer: Optional[SerializerProtocol] = None
    ) -> Any:
        """Synchronously fetch a value from the cache.

        Args:
            cache_key: Key of the entry to read.
            serializer: Serializer to use instead of the instance default.

        Returns:
            The deserialized value, or ``CACHE_MISS`` when there is no entry,
            the entry is malformed, or loading/deserializing fails. Failures
            are logged and never raised from the load/deserialize step.
        """
        use_serializer = serializer or self.serializer
        entry = self.db.get(cache_key)

        if not entry:
            return CACHE_MISS

        r_type = entry["result_type"]
        r_val = entry["result_value"]
        r_data = entry.get("result_data")

        try:
            if r_type == "DIRECT_BLOB":
                # Payload is stored inline in the DB record.
                if r_data is None:
                    return CACHE_MISS
                return use_serializer.loads(r_data)

            elif r_type == "FILE":
                # Payload lives in blob storage; the record holds its location.
                if r_val is None:
                    logger.warning(
                        f"Data corruption: 'FILE' record has no path for key `{cache_key}`"
                    )
                    return CACHE_MISS
                data_bytes = self.storage.load(r_val)
                return use_serializer.loads(data_bytes)

            else:
                logger.warning(
                    f"Unknown result_type '{r_type}' for cache_key `{cache_key}`"
                )
                return CACHE_MISS

        except CacheCorruptedError as e:
            logger.debug(f"Cache corrupted for {cache_key}: {e}")
            return CACHE_MISS
        except Exception as e:
            # Deliberate best-effort: an unreadable entry is treated as a miss.
            logger.error(
                f"Failed to deserialize cache for `{cache_key}`: {e}", exc_info=True
            )
            return CACHE_MISS

    def set(
        self,
        cache_key: str,
        func_name: str,
        func_identifier: str,
        input_id: str,
        version: str | None,
        result: Any,
        content_type: str | ContentType | None,
        save_blob: bool | None,
        expires_at: Optional[datetime] = None,
        serializer: Optional[SerializerProtocol] = None,
    ) -> None:
        """Synchronously store a value in the cache.

        Args:
            cache_key: Key under which the entry is saved.
            func_name: Recorded with the entry.
            func_identifier: Recorded with the entry.
            input_id: Recorded with the entry.
            version: Recorded with the entry.
            result: Value to serialize and store.
            content_type: Recorded with the entry.
            save_blob: ``True``/``False`` forces blob/inline storage;
                ``None`` defers to the storage policy.
            expires_at: Expiry timestamp recorded with the entry.
            serializer: Serializer to use instead of the instance default.

        Raises:
            Exception: Whatever the serializer, storage, or DB raises. When
                the DB write fails after a blob was written, the blob is
                deleted (best effort) before the error is re-raised.
        """
        use_serializer = serializer or self.serializer
        data_bytes = use_serializer.dumps(result)

        should_use_blob = save_blob
        if should_use_blob is None:
            should_use_blob = self.storage_policy.should_save_as_blob(data_bytes)

        # Fields shared by both storage modes.
        record = dict(
            cache_key=cache_key,
            func_name=func_name,
            func_identifier=func_identifier,
            input_id=input_id,
            version=version,
            content_type=content_type,
            expires_at=expires_at,
        )

        if should_use_blob:
            r_val = self.storage.save(cache_key, data_bytes)
            try:
                self.db.save(
                    result_type="FILE",
                    result_value=r_val,
                    result_data=None,
                    **record,
                )
            except Exception:
                # Roll back the orphaned blob; a rollback failure is only
                # logged so the original DB error is what propagates.
                try:
                    self.storage.delete(r_val)
                except Exception as rollback_err:
                    logger.warning(f"Failed to rollback blob '{r_val}': {rollback_err}")
                raise
        else:
            self.db.save(
                result_type="DIRECT_BLOB",
                result_value=None,
                result_data=data_bytes,
                **record,
            )

    # --- Thundering Herd Protection ---

    @contextlib.contextmanager
    def herd_sync(
        self, cache_key: str, serializer: Optional[SerializerProtocol] = None
    ) -> Generator[HerdWaitResult, None, None]:
        """Context manager providing thundering-herd protection on the sync path.

        Yields the ``HerdWaitResult`` from ``wait_herd_sync``. If the caller
        is the executor, waiters are notified and the in-flight entry is
        removed on exit, whether or not the body raised.
        """
        herd = self.wait_herd_sync(cache_key, serializer)
        try:
            yield herd
        finally:
            if herd.is_executor:
                self.notify_and_cleanup_inflight(cache_key, herd.event, herd.result_box)

    @contextlib.asynccontextmanager
    async def herd_async(
        self,
        cache_key: str,
        serializer: Optional[SerializerProtocol],
        loop: asyncio.AbstractEventLoop,
        executor: Any,
    ) -> AsyncGenerator[HerdWaitResult, None]:
        """Context manager providing thundering-herd protection on the async path.

        Yields the ``HerdWaitResult`` from ``wait_herd_async``. If the caller
        is the executor, waiters are notified and the in-flight entry is
        removed on exit, whether or not the body raised.
        """
        herd = await self.wait_herd_async(cache_key, serializer, loop, executor)
        try:
            yield herd
        finally:
            if herd.is_executor:
                self.notify_and_cleanup_inflight(cache_key, herd.event, herd.result_box)

    def wait_herd_sync(
        self, cache_key: str, serializer: Optional[SerializerProtocol] = None
    ) -> HerdWaitResult:
        """Thundering-herd wait on the sync path.

        Registers the caller as executor when no execution is in flight for
        ``cache_key``; otherwise blocks until the current executor signals.

        Returns:
            ``HerdWaitResult`` with ``is_executor=True`` (caller must run the
            work), or with ``is_executor=False`` and the executor's outcome
            or a value found in the cache.

        Raises:
            TimeoutError: After more than ``HERD_MAX_RETRIES`` timed-out
                rounds of ``HERD_TIMEOUT`` seconds each.
        """
        retries = 0
        while True:
            with self._inflight_lock:
                if cache_key not in self._inflight:
                    event = threading.Event()
                    result_box: list = []
                    self._inflight[cache_key] = (event, [], result_box)
                    return HerdWaitResult(True, None, event, result_box, False)

                wait_event, _, wait_box = self._inflight[cache_key]

            # Wait outside the lock, polling so the deadline can be checked.
            deadline = time.monotonic() + self.HERD_TIMEOUT
            while not wait_event.wait(timeout=self.HERD_POLL):
                if time.monotonic() >= deadline:
                    retries += 1
                    if retries > self.HERD_MAX_RETRIES:
                        raise TimeoutError(f"Herd wait timeout for {cache_key} exceeded max retries ({self.HERD_MAX_RETRIES})")
                    logger.warning(f"Herd wait timeout for {cache_key} (retry {retries}/{self.HERD_MAX_RETRIES})")
                    break

            if wait_box:
                success, val = wait_box[0]
                return HerdWaitResult(False, val, None, [], not success)

            # Re-check the cache in case the outcome was not published.
            cached = self.get(cache_key, serializer)
            if cached is not CACHE_MISS:
                return HerdWaitResult(False, cached, None, [], False)
            # Still nothing: loop and try to become the executor.

    async def wait_herd_async(
        self,
        cache_key: str,
        serializer: Optional[SerializerProtocol],
        loop: asyncio.AbstractEventLoop,
        executor: Any,
    ) -> HerdWaitResult:
        """Thundering-herd wait on the async path.

        Registers the caller as executor when no execution is in flight for
        ``cache_key``; otherwise awaits a future that the executor resolves.

        Args:
            cache_key: Key being computed.
            serializer: Accepted for signature parity with the sync path;
                not used by this method.
            loop: Event loop on which the waiter's future is created.
            executor: Forwarded to ``_await_herd_signal_async``.

        Returns:
            ``HerdWaitResult`` with ``is_executor=True`` (caller must run the
            work), or with ``is_executor=False`` and the executor's outcome.

        Raises:
            TimeoutError: After more than ``HERD_MAX_RETRIES`` timed-out
                rounds of ``HERD_TIMEOUT`` seconds each.
        """
        retries = 0
        while True:
            fut = None
            with self._inflight_lock:
                if cache_key not in self._inflight:
                    event = threading.Event()
                    result_box: list = []
                    self._inflight[cache_key] = (event, [], result_box)
                    return HerdWaitResult(True, None, event, result_box, False)

                wait_event, futs, wait_box = self._inflight[cache_key]

                # An outcome is already available: return it immediately.
                if wait_box:
                    success, val = wait_box[0]
                    return HerdWaitResult(False, val, None, [], not success)

                # Reuse a pending future of the same loop, if any, so the
                # list does not grow with every waiter.
                for f in futs:
                    if f.get_loop() is loop and not f.done():
                        fut = f
                        break
                if fut is None:
                    fut = loop.create_future()
                    futs.append(fut)

            signal = await self._await_herd_signal_async(
                fut, wait_event, wait_box, cache_key, loop, executor
            )
            if signal is None:
                retries += 1
                if retries > self.HERD_MAX_RETRIES:
                    raise TimeoutError(f"Herd wait timeout for {cache_key} exceeded max retries ({self.HERD_MAX_RETRIES})")
                logger.warning(f"Herd wait timeout for {cache_key} (retry {retries}/{self.HERD_MAX_RETRIES})")
                continue

            success, val = signal
            if success and val is self._HERD_ABANDONED:
                # The executor released the key without publishing an outcome.
                # Re-enter the election right away instead of counting this
                # as a timeout; the in-flight entry has already been removed.
                continue
            return HerdWaitResult(False, val, None, [], not success)

    async def _await_herd_signal_async(
        self,
        fut: Optional[asyncio.Future],
        wait_event: threading.Event,
        wait_box: list,
        cache_key: str,
        loop: asyncio.AbstractEventLoop,
        executor: Any,
    ) -> Optional[tuple[bool, Any]]:
        """Wait for the executor's outcome for up to ``HERD_TIMEOUT`` seconds.

        Returns:
            ``(True, value)`` on success, ``(False, exception)`` when the
            executor failed, or ``None`` when the wait timed out (or, on the
            event-based path, when the event fired without an outcome).
        """
        if fut is not None:
            try:
                # shield() keeps the shared future alive when this waiter's
                # wait_for times out; other waiters may still be awaiting it.
                val = await asyncio.wait_for(
                    asyncio.shield(fut), timeout=self.HERD_TIMEOUT
                )
                return (True, val)
            except asyncio.TimeoutError:
                return None
            except Exception as e:
                return (False, e)

        # Fallback without a future: poll the threading.Event via the executor
        # so the event loop is not blocked.
        if wait_box:
            return wait_box[0]

        deadline = time.monotonic() + self.HERD_TIMEOUT
        while not await loop.run_in_executor(executor, wait_event.wait, self.HERD_POLL):
            if time.monotonic() >= deadline:
                return None

        return wait_box[0] if wait_box else None

    def notify_and_cleanup_inflight(
        self,
        cache_key: str,
        event: Optional[threading.Event],
        result_box: list,
    ) -> None:
        """Notify waiting threads/tasks and remove the in-flight entry.

        The entry is removed only if it is still the one registered with
        ``event``. Sync waiters are woken through ``event``. Async waiters'
        futures receive the first ``(success, value)`` pair of ``result_box``,
        or, when ``result_box`` is empty, an internal marker that tells them
        to retry immediately instead of waiting out ``HERD_TIMEOUT``.

        Args:
            cache_key: Key whose execution finished.
            event: The executor's event, as returned in ``HerdWaitResult``.
            result_box: The executor's shared result list.
        """
        futs_to_notify: list = []
        with self._inflight_lock:
            val = self._inflight.get(cache_key)
            if val is not None and val[0] is event:
                _, futs_to_notify, _ = val
                del self._inflight[cache_key]

        if event is None:
            return
        event.set()

        if not futs_to_notify:
            return

        if result_box:
            success, res_val = result_box[0]
        else:
            # No outcome was published (for example, the executor was
            # interrupted). Without this, async waiters would stay blocked on
            # an unresolved future until their timeout.
            success, res_val = True, self._HERD_ABANDONED

        for fut in futs_to_notify:
            if not fut.done():
                self._notify_future(fut, success, res_val)

    def _notify_future(self, fut: asyncio.Future, success: bool, val: Any) -> None:
        """Resolve ``fut`` on its own event loop, from any thread.

        A closed loop is logged and skipped so that one dead loop neither
        prevents the remaining waiters from being notified nor replaces an
        exception already propagating from the executor.
        """

        def _set():
            # Re-check on the loop thread: the future may have been resolved
            # or cancelled since the callback was scheduled.
            if not fut.done():
                if success:
                    fut.set_result(val)
                elif isinstance(val, BaseException):
                    fut.set_exception(val)
                else:
                    fut.set_exception(RuntimeError(f"Non-Exception error: {repr(val)}"))

        try:
            fut.get_loop().call_soon_threadsafe(_set)
        except RuntimeError as exc:
            # call_soon_threadsafe raises RuntimeError when the loop is closed.
            logger.warning(f"Could not notify herd waiter: {exc}")