Coverage for src / beautyspot / storage.py: 62%
220 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-03-11 19:10 +0900
« prev ^ index » next coverage.py v7.13.2, created at 2026-03-11 19:10 +0900
1# src/beautyspot/storage.py
3import os
4import io
5import logging
6import tempfile
7import time
8from pathlib import Path
9from abc import ABC, abstractmethod
10from typing import Any, TypeAlias, Iterator, Protocol, runtime_checkable
11from dataclasses import dataclass, field
12from beautyspot.exceptions import CacheCorruptedError, ValidationError
14try:
15 import boto3
16 from botocore.exceptions import ClientError
17except ImportError:
18 boto3 = None
19 ClientError = Exception
# Any bytes-like object accepted as `data` by the blob storage `save()` methods.
ReadableBuffer: TypeAlias = bytes | bytearray | memoryview
24# --- Storage Policies ---
@runtime_checkable
class StoragePolicyProtocol(Protocol):
    """
    Structural interface that decides where a serialized payload is persisted.

    An implementation inspects the payload (usually only its size) and
    answers whether it should go to blob storage (file/object storage)
    instead of being stored directly in the database.
    """

    def should_save_as_blob(self, data: bytes) -> bool:
        """Return True to store ``data`` as a blob, False to keep it inline."""
        ...
@dataclass
class ThresholdStoragePolicy(StoragePolicyProtocol):
    """
    Send payloads to blob storage once they grow past a size limit.

    A payload whose byte length is strictly greater than ``threshold`` is
    stored as a blob; anything at or below the limit stays inline.
    This is the recommended policy for automatic optimization.
    """

    # Largest payload size (in bytes) that is still stored inline.
    threshold: int

    def should_save_as_blob(self, data: bytes) -> bool:
        payload_size = len(data)
        return self.threshold < payload_size
@dataclass
class WarningOnlyPolicy(StoragePolicyProtocol):
    """
    Backward-compatible policy reproducing the v2.0 behavior.

    It never routes data to blob storage. Instead it logs a warning whenever
    a payload is larger than ``warning_threshold`` bytes, so oversized inline
    values are noticed.
    """

    # Payload size (in bytes) above which a warning is logged.
    warning_threshold: int
    # The logger is excluded from comparison and repr so that the
    # dataclass-generated __eq__ / __repr__ depend only on configuration,
    # not on a Logger instance.
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("beautyspot"),
        compare=False,
        repr=False,
    )

    def should_save_as_blob(self, data: bytes) -> bool:
        """
        Warn about large payloads but always keep them inline.

        Returns:
            Always False.
        """
        size = len(data)
        if size > self.warning_threshold:
            # Lazy %-style arguments: the message is only rendered when the
            # record is actually emitted (the text is unchanged).
            self.logger.warning(
                "⚠️ Large data detected (%d bytes). "
                "Consider using `save_blob=True` or a stricter StoragePolicy.",
                size,
            )
        return False
@dataclass
class AlwaysBlobPolicy(StoragePolicyProtocol):
    """
    Route every payload to blob storage, regardless of its content.

    Behaves the same as configuring `default_save_blob=True`.
    """

    def should_save_as_blob(self, data: bytes) -> bool:
        """Always answer True; ``data`` is never inspected."""
        return True
85# --- Blob Storage Implementations ---
@runtime_checkable
class BlobStorageCore(Protocol):
    """
    Minimal contract a large-object backend must satisfy at execution time:
    persisting, retrieving and removing blobs.
    """

    def save(self, key: str, data: ReadableBuffer) -> str:
        """Persist ``data`` under ``key`` and return its location identifier."""
        ...

    def load(self, location: str) -> bytes:
        """Return the bytes stored at ``location``."""
        ...

    def delete(self, location: str) -> None:
        """Remove the blob stored at ``location``."""
        ...
@runtime_checkable
class Maintenable(Protocol):
    """
    Extended contract for maintenance tasks (garbage collection):
    enumerating stored blobs and reporting their modification time.
    """

    def list_keys(self) -> Iterator[str]:
        """Yield the location identifier of every stored blob."""
        ...

    def get_mtime(self, location: str) -> float:
        """Return the last-modified time of ``location`` as a POSIX timestamp."""
        ...
@runtime_checkable
class BlobStorageMaintenable(BlobStorageCore, Maintenable, Protocol):
    """Blob backend offering both the core I/O API and the maintenance (GC) API."""
class BlobStorageBase(ABC):
    """
    Abstract base class for large-object (BLOB) storage backends.

    Concrete subclasses must implement every abstract method below; at a
    minimum an implementation should fulfil the ``BlobStorageCore`` protocol.
    """

    @abstractmethod
    def save(self, key: str, data: ReadableBuffer) -> str:
        """
        Store ``data`` under ``key``.

        Returns:
            A location identifier for the stored blob.
        """

    @abstractmethod
    def load(self, location: str) -> bytes:
        """Read back the bytes stored at ``location``."""

    @abstractmethod
    def delete(self, location: str) -> None:
        """
        Remove the blob at ``location``.

        Implementations should be idempotent: a missing blob is not an error.
        """

    @abstractmethod
    def list_keys(self) -> Iterator[str]:
        """
        Yield location identifiers for every stored blob (used by garbage
        collection).

        Every yielded value MUST use the same format (path/URI) that
        ``delete`` accepts.
        """

    @abstractmethod
    def get_mtime(self, location: str) -> float:
        """
        Return the blob's last-modified time as a POSIX timestamp.

        Used to avoid race conditions during garbage collection.
        """
class LocalStorage(BlobStorageMaintenable):
    """
    Blob storage backed by a directory on the local filesystem.

    ``save`` writes blobs as ``<key>.bin`` files directly under ``base_dir``.
    Locations are paths relative to ``base_dir``. The location-based methods
    resolve them and refuse to touch anything outside ``base_dir``.
    """

    def __init__(self, base_dir: str | Path):
        # Resolve to an absolute path once, so the containment checks in
        # load/delete/get_mtime compare against a canonical directory.
        self.base_dir = Path(base_dir).resolve()
        self._ensure_cache_dir(self.base_dir)

    @staticmethod
    def _ensure_cache_dir(directory: Path) -> None:
        """
        Create ``directory`` (with parents) and place a ``.gitignore`` in it
        so the cache is never tracked by Git.

        A failure to write the ``.gitignore`` (e.g. a permission problem) is
        logged and otherwise ignored.
        """
        directory.mkdir(parents=True, exist_ok=True)
        gitignore_path = directory / ".gitignore"
        if not gitignore_path.exists():
            try:
                gitignore_path.write_text("*\n")
            except OSError as e:
                # Continue without the .gitignore; only log the problem.
                # Use the package logger rather than the root-level helper,
                # which may configure root logging as a side effect.
                logging.getLogger("beautyspot").warning(
                    "Failed to create .gitignore in %s: %s", directory, e
                )

    def _validate_key(self, key: str) -> None:
        """
        Validate a cache key passed to ``save()``.

        Note:
            This check applies only to ``save()`` arguments, which are
            normally SHA-256 hashes. Location strings returned by
            ``list_keys()`` (e.g. ``'subdir/hash.bin'``) may contain path
            separators for compatibility with legacy data. ``load()`` and
            ``delete()`` perform their own path-traversal checks instead.

        Raises:
            ValidationError: If ``key`` contains ``..``, ``/`` or ``\\``.
        """
        # Prevent path traversal out of base_dir.
        if ".." in key or "/" in key or "\\" in key:
            raise ValidationError(
                f"Invalid key: '{key}'. Keys must not contain path separators."
            )

    def save(self, key: str, data: ReadableBuffer) -> str:
        """
        Write ``data`` to ``<base_dir>/<key>.bin`` and return the file name
        (the location).

        The target is not overwritten in place with ``open(..., 'wb')``.
        Instead the data goes to a unique temporary file created by
        ``tempfile.mkstemp``, which is then renamed over the target with
        ``Path.replace``. This prevents:
          1. Conflicts or corruption when several threads or processes write
             the same cache key concurrently.
          2. Truncated, corrupt files left behind when the process is killed
             in the middle of a write.

        Args:
            key: Cache key to store under.
            data: Bytes to persist.

        Returns:
            The stored file name, relative to ``base_dir``.

        Raises:
            ValidationError: If ``key`` fails validation.
        """
        self._validate_key(key)
        filename = f"{key}.bin"
        filepath = self.base_dir / filename

        # mkstemp yields a unique temp file, so concurrent writers never share
        # it. flush + fsync push the data to disk before the rename, so a
        # crash between write and rename never leaves a corrupt target.
        fd, temp_path_str = tempfile.mkstemp(dir=self.base_dir, suffix=".spot_tmp")
        try:
            with os.fdopen(fd, "wb", closefd=True) as f:
                f.write(data)
                f.flush()
                os.fsync(f.fileno())
            Path(temp_path_str).replace(filepath)
        except BaseException:
            try:
                os.unlink(temp_path_str)
            except OSError:
                # If removal fails (e.g. PermissionError), the temp file
                # lingers; clean_temp_files() reclaims it later.
                pass
            raise

        return filename

    def load(self, location: str) -> bytes:
        """
        Read the blob stored at ``location``.

        ``location`` is resolved relative to ``base_dir``. If it is an
        absolute path (legacy data), ``base_dir / location`` evaluates to
        that absolute path, so legacy locations inside ``base_dir`` keep
        working.

        Raises:
            CacheCorruptedError: If the path escapes ``base_dir``, the file is
                missing, or it cannot be read.
        """
        full_path = (self.base_dir / location).resolve()

        # Security check: the resolved path must stay strictly inside base_dir.
        if not full_path.is_relative_to(self.base_dir):
            raise CacheCorruptedError(
                f"Access denied: {location} resolves to {full_path}, which is outside {self.base_dir}"
            )

        if not full_path.exists():
            raise CacheCorruptedError(f"Local blob lost: {full_path}")

        try:
            with open(full_path, "rb") as f:
                return f.read()
        except OSError as e:
            # Chain the original error so the root cause stays visible.
            raise CacheCorruptedError(f"Failed to read blob: {e}") from e

    def delete(self, location: str) -> None:
        """
        Delete the file at ``location``.

        A missing file is ignored, and locations resolving outside
        ``base_dir`` are skipped. Other OS errors are logged, not raised.

        Note:
            For performance reasons, this method does not synchronously
            remove parent directories that become empty. That cleanup is
            deferred to the maintenance task (``prune_empty_dirs`` /
            ``beautyspot gc``).
        """
        full_path = (self.base_dir / location).resolve()

        if not full_path.is_relative_to(self.base_dir):
            return

        try:
            os.remove(full_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            # Locks held by other processes or permission problems are logged;
            # processing continues and a later GC run can retry.
            logging.getLogger("beautyspot").warning(
                "Failed to delete blob at %s: %s. It will be handled by subsequent GC.",
                full_path,
                e,
            )

    def list_keys(self) -> Iterator[str]:
        """
        Yield the POSIX-style relative path of every ``.bin`` file under
        ``base_dir``, including subdirectories.

        Example: ``'hash.bin'`` or ``'subdir/hash.bin'``.
        """
        if not self.base_dir.exists():
            return

        # rglob searches recursively; is_file() skips directories named *.bin.
        for entry in self.base_dir.rglob("*.bin"):
            if entry.is_file():
                yield entry.relative_to(self.base_dir).as_posix()

    def get_mtime(self, location: str) -> float:
        """
        Return the modification time of the blob at ``location`` as a POSIX
        timestamp.

        Raises:
            ValueError: If ``location`` resolves outside ``base_dir``.
            CacheCorruptedError: If the file cannot be stat'ed.
        """
        full_path = (self.base_dir / location).resolve()
        if not full_path.is_relative_to(self.base_dir):
            raise ValueError(f"Access denied: {location}")
        try:
            return full_path.stat().st_mtime
        except OSError as e:
            raise CacheCorruptedError(f"Failed to get mtime for blob: {e}") from e

    def clean_temp_files(self, max_age_seconds: int = 86400) -> int:
        """
        Remove ``*.spot_tmp`` files older than ``max_age_seconds``.

        This is a fail-safe for temporary files that ``save()`` could not
        delete (e.g. because they were locked).

        Returns:
            The number of temporary files removed.
        """
        if not self.base_dir.exists():
            return 0

        removed_count = 0
        now = time.time()

        for entry in self.base_dir.rglob("*.spot_tmp"):
            if entry.is_file():
                try:
                    # Only remove files past the grace period (default 24 h).
                    if now - entry.stat().st_mtime > max_age_seconds:
                        entry.unlink()
                        removed_count += 1
                except OSError:
                    # Still locked (e.g. by antivirus software): skip it.
                    pass

        return removed_count

    def prune_empty_dirs(self) -> int:
        """
        Recursively remove empty directories under ``base_dir``.

        Directories that contain only system-generated files (``.DS_Store``,
        ``Thumbs.db``, ``desktop.ini``) also count as empty. Those files are
        deleted only when the directory itself can then be removed.

        Returns:
            The number of directories removed.

        Note:
            ``base_dir`` itself is never removed, because a missing
            ``base_dir`` would make subsequent ``save()`` calls fail with
            FileNotFoundError.
        """
        if not self.base_dir.exists():
            return 0

        IGNORED_FILES = {".DS_Store", "Thumbs.db", "desktop.ini"}
        removed_count = 0

        # os.walk(topdown=False) visits the deepest directories first, so
        # children are pruned before their parents are examined.
        for root, dirs, files in os.walk(self.base_dir, topdown=False):
            path = Path(root)

            # Never delete base_dir itself.
            if path == self.base_dir:
                continue

            existing_files = set(files)

            # Real (non-ignored) files present -> the directory must stay.
            if existing_files - IGNORED_FILES:
                continue

            # `dirs` was listed before the children were visited. Any entry
            # that still exists survived pruning, so this directory cannot be
            # removed either. Leave its system files untouched in that case.
            # lexists() also catches symlinks, which rmdir() cannot remove.
            if any(os.path.lexists(path / d) for d in dirs):
                continue

            # Only ignored files remain: delete them to empty the directory.
            for f in existing_files:
                try:
                    (path / f).unlink()
                except OSError:
                    pass

            # Attempt the directory removal; it may still be busy or locked.
            try:
                path.rmdir()
                removed_count += 1
            except OSError:
                pass

        return removed_count
class S3Storage(BlobStorageMaintenable):
    """
    Blob storage backed by an S3 bucket.

    Blobs are written to ``s3://<bucket>/<prefix>/<key>.bin``. The prefix
    defaults to ``"blobs"`` when the URI has no path component. Locations are
    full ``s3://`` URIs.
    """

    def __init__(
        self,
        s3_uri: str,
        s3_opts: dict[str, Any] | None = None,
    ):
        """
        Args:
            s3_uri: ``s3://bucket[/prefix]`` URI.
            s3_opts: Extra keyword arguments forwarded to
                ``boto3.client("s3", ...)``.

        Raises:
            ImportError: If boto3 is not installed.
        """
        if not boto3:
            raise ImportError("Run `pip install beautyspot[s3]` to use S3 storage.")

        # Strip only the leading scheme. str.replace() would also remove any
        # later "s3://" occurrence inside the path.
        parts = s3_uri.removeprefix("s3://").split("/", 1)
        self.bucket_name = parts[0]
        raw_prefix = parts[1].rstrip("/") if len(parts) > 1 else ""
        self.prefix = raw_prefix if raw_prefix else "blobs"

        opts = s3_opts or {}
        self.s3 = boto3.client("s3", **opts)

    @staticmethod
    def _parse_s3_uri(location: str) -> tuple[str, str]:
        """
        Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

        Raises:
            ValidationError: If ``location`` is not an s3:// URI or lacks a
                bucket or key.
        """
        if not location.startswith("s3://"):
            raise ValidationError(f"Expected an s3:// URI, got: {location!r}")
        path = location[len("s3://") :]
        parts = path.split("/", 1)
        if len(parts) != 2 or not parts[0] or not parts[1]:
            raise ValidationError(
                f"Invalid S3 URI (expected s3://bucket/key): {location!r}"
            )
        return parts[0], parts[1]

    def save(self, key: str, data: ReadableBuffer) -> str:
        """Upload ``data`` to ``<prefix>/<key>.bin`` and return its s3:// URI."""
        s3_key = f"{self.prefix}/{key}.bin"
        buffer = io.BytesIO(data)
        # upload_fileobj switches to multipart upload for large payloads,
        # avoiding put_object's 5 GB single-request limit.
        self.s3.upload_fileobj(buffer, self.bucket_name, s3_key)
        return f"s3://{self.bucket_name}/{s3_key}"

    def load(self, location: str) -> bytes:
        """
        Download the object at ``location``.

        Raises:
            ValidationError: If ``location`` is not a valid s3:// URI.
            CacheCorruptedError: If the object cannot be fetched.
        """
        bucket, key = self._parse_s3_uri(location)
        try:
            resp = self.s3.get_object(Bucket=bucket, Key=key)
            body = resp["Body"]
            try:
                return body.read()
            finally:
                # Release the underlying HTTP connection even if read() fails.
                body.close()
        except ClientError as e:
            raise CacheCorruptedError(f"S3 blob lost: {location}") from e

    def delete(self, location: str) -> None:
        """Delete the object at ``location``; client errors are logged, not raised."""
        bucket, key = self._parse_s3_uri(location)
        try:
            self.s3.delete_object(Bucket=bucket, Key=key)
        except ClientError as e:
            # delete_object does not fail for missing objects, so reaching
            # this branch means a serious problem (permissions, network).
            # Do not swallow it silently: log it and let a later GC retry.
            logging.getLogger("beautyspot").warning(
                "Failed to delete S3 object %s: %s. It will be handled by subsequent GC.",
                location,
                e,
            )

    def get_mtime(self, location: str) -> float:
        """
        Return the object's LastModified time as a POSIX timestamp.

        Raises:
            CacheCorruptedError: If the object is missing or inaccessible.
        """
        bucket, key = self._parse_s3_uri(location)
        try:
            resp = self.s3.head_object(Bucket=bucket, Key=key)
            # LastModified is a datetime object; convert to a POSIX timestamp.
            return resp["LastModified"].timestamp()
        except ClientError as e:
            raise CacheCorruptedError(
                f"S3 blob lost or inaccessible: {location}"
            ) from e

    def list_keys(self) -> Iterator[str]:
        """
        Yield s3:// URIs for all objects stored under this storage's prefix.

        The listing prefix ends with "/" to match how ``save`` builds keys.
        Sibling prefixes that merely start with the same characters (e.g.
        ``blobs-old/`` next to ``blobs/``) are therefore not listed, and GC
        never deletes objects that do not belong to this storage.
        """
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(
            Bucket=self.bucket_name, Prefix=f"{self.prefix}/"
        ):
            for obj in page.get("Contents", []):
                yield f"s3://{self.bucket_name}/{obj['Key']}"
def create_storage(path: str, options: dict | None = None) -> BlobStorageMaintenable:
    """
    Build a blob storage backend from a path or URI.

    Args:
        path: An ``s3://`` URI selects ``S3Storage``. Anything else is
            treated as a local directory for ``LocalStorage``.
        options: Client options forwarded to ``S3Storage``. Ignored for
            local storage.

    Returns:
        The constructed storage backend.
    """
    is_s3 = path.startswith("s3://")
    if not is_s3:
        return LocalStorage(path)
    return S3Storage(path, options)