Coverage for src / beautyspot / storage.py: 62%

220 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-03-11 19:10 +0900

1# src/beautyspot/storage.py 

2 

3import os 

4import io 

5import logging 

6import tempfile 

7import time 

8from pathlib import Path 

9from abc import ABC, abstractmethod 

10from typing import Any, TypeAlias, Iterator, Protocol, runtime_checkable 

11from dataclasses import dataclass, field 

12from beautyspot.exceptions import CacheCorruptedError, ValidationError 

13 

14try: 

15 import boto3 

16 from botocore.exceptions import ClientError 

17except ImportError: 

18 boto3 = None 

19 ClientError = Exception 

20 

21 

22ReadableBuffer: TypeAlias = bytes | bytearray | memoryview 

23 

24# --- Storage Policies --- 

25 

@runtime_checkable
class StoragePolicyProtocol(Protocol):
    """Structural interface for blob-versus-inline storage decisions.

    An implementation looks at the payload and reports whether it should be
    written to blob storage (file/object storage) instead of being kept
    directly in the database. The decision is usually based on size.
    """

    def should_save_as_blob(self, data: bytes) -> bool:
        """Return True when ``data`` should be stored as a blob."""
        ...

34 

35 

@dataclass
class ThresholdStoragePolicy(StoragePolicyProtocol):
    """Size-based policy: payloads larger than ``threshold`` become blobs.

    This is the recommended policy for automatic optimization.
    """

    # Payloads whose length is strictly greater than this are stored as blobs.
    threshold: int

    def should_save_as_blob(self, data: bytes) -> bool:
        """Return True only when ``len(data)`` exceeds ``threshold``."""
        payload_size = len(data)
        return payload_size > self.threshold

47 

48 

@dataclass
class WarningOnlyPolicy(StoragePolicyProtocol):
    """Backward-compatible policy (v2.0 behavior) that never forces a blob.

    A warning is logged when the payload is larger than ``warning_threshold``,
    but the answer is always "do not save as blob".
    """

    # Payload length above which a warning is emitted.
    warning_threshold: int
    # The logger is excluded from comparison and repr so that the
    # dataclass-generated __eq__ does not take the logger instance into account.
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("beautyspot"),
        compare=False,
        repr=False,
    )

    def should_save_as_blob(self, data: bytes) -> bool:
        """Warn about oversized payloads; always return False."""
        oversized = len(data) > self.warning_threshold
        if oversized:
            self.logger.warning(
                f"⚠️ Large data detected ({len(data)} bytes). "
                f"Consider using `save_blob=True` or a stricter StoragePolicy."
            )
        return False

72 

73 

@dataclass
class AlwaysBlobPolicy(StoragePolicyProtocol):
    """Unconditional policy: every payload is saved as a blob.

    Equivalent to setting `default_save_blob=True`.
    """

    def should_save_as_blob(self, data: bytes) -> bool:
        """Return True regardless of the content of ``data``."""
        return True

83 

84 

85# --- Blob Storage Implementations --- 

86 

@runtime_checkable
class BlobStorageCore(Protocol):
    """Minimal large-object storage interface required during execution."""

    def save(self, key: str, data: ReadableBuffer) -> str:
        """Store ``data`` under ``key`` and return a location identifier."""
        ...

    def load(self, location: str) -> bytes:
        """Return the bytes stored at ``location``."""
        ...

    def delete(self, location: str) -> None:
        """Remove the blob stored at ``location``."""
        ...

96 

@runtime_checkable
class Maintenable(Protocol):
    """Extra operations needed by maintenance tasks (GC)."""

    def list_keys(self) -> Iterator[str]:
        """Iterate over the location identifiers of the stored blobs."""
        ...

    def get_mtime(self, location: str) -> float:
        """Return the modification time of the blob at ``location``."""
        ...

105 

106 

@runtime_checkable
class BlobStorageMaintenable(BlobStorageCore, Maintenable, Protocol):
    """Combined protocol: core blob operations plus maintenance (GC) operations."""

110 

111 

class BlobStorageBase(ABC):
    """Abstract base class for large object (BLOB) storage backends.

    Implementations should at least fulfill BlobStorageCore.
    """

    @abstractmethod
    def save(self, key: str, data: ReadableBuffer) -> str:
        """Persist ``data`` under ``key`` and return a location identifier."""

    @abstractmethod
    def load(self, location: str) -> bytes:
        """Return the data stored at ``location``."""

    @abstractmethod
    def delete(self, location: str) -> None:
        """Delete the blob at ``location``.

        Should be idempotent: a missing blob is not an error.
        """

    @abstractmethod
    def list_keys(self) -> Iterator[str]:
        """Yield a location identifier for every stored blob.

        Used for Garbage Collection. Each yielded value MUST have the same
        format (path/URI) that `delete` accepts.
        """

    @abstractmethod
    def get_mtime(self, location: str) -> float:
        """Return the blob's last-modified time as a POSIX timestamp.

        Used to prevent race conditions during Garbage Collection.
        """

157 

158 

class LocalStorage(BlobStorageMaintenable):
    """Blob storage backed by a directory on the local filesystem.

    Blobs are written as ``<key>.bin`` files directly under ``base_dir``.
    Locations returned by `save` and `list_keys` are paths relative to
    ``base_dir``; `load`, `delete` and `get_mtime` reject any location that
    resolves outside of it.
    """

    def __init__(self, base_dir: str | Path):
        """Resolve ``base_dir`` to an absolute path and make sure it exists."""
        # Resolve to an absolute path explicitly on init so that the
        # containment checks compare resolved paths with resolved paths.
        self.base_dir = Path(base_dir).resolve()
        self._ensure_cache_dir(self.base_dir)

    @staticmethod
    def _ensure_cache_dir(directory: Path) -> None:
        """Create ``directory`` and drop a ``.gitignore`` in it.

        The ``.gitignore`` (content ``*``) keeps the cache out of Git. Failing
        to write it is not fatal: the problem is logged and ignored.
        """
        directory.mkdir(parents=True, exist_ok=True)
        gitignore_path = directory / ".gitignore"
        if gitignore_path.exists():
            return
        try:
            gitignore_path.write_text("*\n")
        except OSError as e:
            # E.g. a permission problem: keep going, only log. The named
            # package logger is used instead of the root-level
            # logging.warning(), which would implicitly call basicConfig().
            logging.getLogger("beautyspot").warning(
                f"Failed to create .gitignore in {directory}: {e}"
            )

    def _validate_key(self, key: str) -> None:
        """Validate a cache key passed to `save`.

        Note:
            This check only applies to the argument of `save` (normally a
            SHA-256 hash). Location strings yielded by `list_keys`
            (e.g. 'subdir/hash.bin') may contain path separators for
            compatibility with legacy data; `load` / `delete` perform their
            own path-traversal check instead.

        Raises:
            ValidationError: If ``key`` contains ``..``, ``/`` or ``\\``.
        """
        # Prevent path traversal.
        if ".." in key or "/" in key or "\\" in key:
            raise ValidationError(
                f"Invalid key: '{key}'. Keys must not contain path separators."
            )

    def save(self, key: str, data: ReadableBuffer) -> str:
        """Write ``data`` to ``<key>.bin`` under ``base_dir`` and return the file name.

        Instead of overwriting the target with a plain ``open(..., 'wb')``,
        the data is written to a unique temporary file created with
        `tempfile.mkstemp` and then renamed over the target. This guards
        against:

        1. Concurrent writers (threads/processes) clobbering or corrupting the
           same cache key.
        2. A partially written, broken file being left behind when the process
           is killed mid-write.

        Args:
            key: Cache key to save under.
            data: Bytes-like payload to store.

        Returns:
            The file name (location) of the stored blob, relative to
            ``base_dir``.

        Raises:
            ValidationError: If ``key`` fails `_validate_key`.
        """
        self._validate_key(key)
        filename = f"{key}.bin"
        filepath = self.base_dir / filename

        # mkstemp generates a unique temp file in the target directory, so
        # concurrent writers never share a temp path.
        fd, temp_path_str = tempfile.mkstemp(dir=self.base_dir, suffix=".spot_tmp")
        try:
            with os.fdopen(fd, "wb", closefd=True) as f:
                f.write(data)
                # flush + fsync so the data reaches disk before the rename.
                f.flush()
                os.fsync(f.fileno())
            Path(temp_path_str).replace(filepath)
        except BaseException:
            # Clean up on any failure, including KeyboardInterrupt.
            try:
                os.unlink(temp_path_str)
            except OSError:
                # If it cannot be removed (e.g. PermissionError) the temp file
                # lingers; clean_temp_files() removes stale ones later.
                pass
            raise

        return filename

    def load(self, location: str) -> bytes:
        """Read and return the blob stored at ``location``.

        ``location`` is resolved relative to ``base_dir``. If it is an absolute
        path (legacy data), pathlib's ``base / abs`` yields ``abs``, which is
        then subject to the same containment check.

        Raises:
            CacheCorruptedError: If the resolved path is outside ``base_dir``,
                the file is missing, or it cannot be read.
        """
        full_path = (self.base_dir / location).resolve()

        # Security check: the path must be within base_dir.
        if not full_path.is_relative_to(self.base_dir):
            raise CacheCorruptedError(
                f"Access denied: {location} resolves to {full_path}, which is outside {self.base_dir}"
            )

        # Open directly instead of checking exists() first, so a file that
        # disappears between check and open is still reported as lost.
        try:
            with open(full_path, "rb") as f:
                return f.read()
        except (FileNotFoundError, NotADirectoryError) as e:
            raise CacheCorruptedError(f"Local blob lost: {full_path}") from e
        except OSError as e:
            raise CacheCorruptedError(f"Failed to read blob: {e}") from e

    def delete(self, location: str) -> None:
        """Delete the file at ``location``; a missing file is not an error.

        Locations resolving outside ``base_dir`` are silently ignored. Other
        OS errors are logged as warnings and not raised.

        Note:
            For performance reasons, this method does not synchronously remove
            empty parent directories. Directory cleanup is deferred to the
            asynchronous maintenance task (`prune_empty_dirs` / `beautyspot gc`).
        """
        full_path = (self.base_dir / location).resolve()

        if not full_path.is_relative_to(self.base_dir):
            return

        try:
            os.remove(full_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            # Locks held by another process or permission problems
            # (PermissionError is an OSError): log and carry on.
            logging.getLogger("beautyspot").warning(
                f"Failed to delete blob at {full_path}: {e}. It will be handled by subsequent GC."
            )

    def list_keys(self) -> Iterator[str]:
        """Yield the relative POSIX path of every ``.bin`` file under ``base_dir``.

        Subdirectories are searched recursively.
        Example: 'hash.bin' or 'subdir/hash.bin'
        """
        if not self.base_dir.exists():
            return

        for entry in self.base_dir.rglob("*.bin"):
            if entry.is_file():
                # Report the path relative to base_dir.
                yield entry.relative_to(self.base_dir).as_posix()

    def get_mtime(self, location: str) -> float:
        """Return the ``st_mtime`` of the blob at ``location``.

        Raises:
            ValueError: If ``location`` resolves outside ``base_dir``.
            CacheCorruptedError: If the file cannot be stat'ed.
        """
        full_path = (self.base_dir / location).resolve()
        if not full_path.is_relative_to(self.base_dir):
            raise ValueError(f"Access denied: {location}")
        try:
            return full_path.stat().st_mtime
        except OSError as e:
            raise CacheCorruptedError(f"Failed to get mtime for blob: {e}") from e

    def clean_temp_files(self, max_age_seconds: int = 86400) -> int:
        """Remove ``.spot_tmp`` files older than ``max_age_seconds``.

        Provides a fail-safe against temporary files leaked by `save`
        (for example when cleanup failed because of a file lock).

        Returns:
            The number of files removed.
        """
        if not self.base_dir.exists():
            return 0

        removed_count = 0
        now = time.time()

        for entry in self.base_dir.rglob("*.spot_tmp"):
            if not entry.is_file():
                continue
            try:
                # Only remove files past the grace period (24 hours by default).
                if now - entry.stat().st_mtime > max_age_seconds:
                    entry.unlink()
                    removed_count += 1
            except OSError:
                # Still locked (e.g. by antivirus software): skip for now.
                pass

        return removed_count

    def prune_empty_dirs(self) -> int:
        """Recursively remove empty directories under ``base_dir``.

        Directories containing only system-generated files (.DS_Store, etc.)
        are emptied and removed as well. Those files are left untouched when
        the directory still contains a subdirectory and therefore could not be
        removed anyway.

        Note:
            ``base_dir`` itself is never removed: a later `save` would fail
            with FileNotFoundError if it were.

        Returns:
            The number of directories removed.
        """
        if not self.base_dir.exists():
            return 0

        IGNORED_FILES = {".DS_Store", "Thumbs.db", "desktop.ini"}
        removed_count = 0

        # topdown=False visits the deepest directories first.
        for root, dirs, files in os.walk(self.base_dir, topdown=False):
            path = Path(root)

            # Never delete base_dir itself.
            if path == self.base_dir:
                continue

            existing_files = set(files)

            # Anything other than ignorable files present -> keep the directory.
            if existing_files - IGNORED_FILES:
                continue

            # `dirs` was listed before the children were visited, so re-check
            # which subdirectories (or symlinks) survived. If any did, the
            # rmdir below cannot succeed; leave the ignorable files alone.
            if any(os.path.lexists(path / d) for d in dirs):
                continue

            # Only ignorable files remain: delete them to empty the directory.
            for f in existing_files:
                try:
                    (path / f).unlink()
                except OSError:
                    pass

            try:
                path.rmdir()
                removed_count += 1
            except OSError:
                pass

        return removed_count

373 

374 

class S3Storage(BlobStorageMaintenable):
    """Blob storage backed by an S3 bucket.

    Objects are stored as ``<prefix>/<key>.bin``; locations are full
    ``s3://bucket/key`` URIs.
    """

    def __init__(
        self,
        s3_uri: str,
        s3_opts: dict[str, Any] | None = None,
    ):
        """Parse ``s3_uri`` and create the boto3 S3 client.

        Args:
            s3_uri: ``s3://bucket[/prefix]``. When no prefix is given (or it is
                empty after stripping trailing slashes) ``"blobs"`` is used.
            s3_opts: Extra keyword arguments forwarded to
                ``boto3.client("s3", ...)``.

        Raises:
            ImportError: If boto3 is not installed.
            ValidationError: If the URI does not contain a bucket name.
        """
        if not boto3:
            raise ImportError("Run `pip install beautyspot[s3]` to use S3 storage.")

        # Strip the scheme only at the start; str.replace() would also remove
        # any later "s3://" occurrence inside the prefix.
        bucket, _, raw_prefix = s3_uri.removeprefix("s3://").partition("/")
        if not bucket:
            raise ValidationError(
                f"Invalid S3 URI (expected s3://bucket[/prefix]): {s3_uri!r}"
            )
        self.bucket_name = bucket
        self.prefix = raw_prefix.rstrip("/") or "blobs"

        opts = s3_opts or {}
        self.s3 = boto3.client("s3", **opts)

    @staticmethod
    def _parse_s3_uri(location: str) -> tuple[str, str]:
        """Parse an s3:// URI into (bucket, key).

        Raises:
            ValidationError: If ``location`` does not start with ``s3://`` or
                lacks a non-empty bucket or key.
        """
        if not location.startswith("s3://"):
            raise ValidationError(f"Expected an s3:// URI, got: {location!r}")
        path = location[len("s3://") :]
        parts = path.split("/", 1)
        if len(parts) != 2 or not parts[0] or not parts[1]:
            raise ValidationError(
                f"Invalid S3 URI (expected s3://bucket/key): {location!r}"
            )
        return parts[0], parts[1]

    def save(self, key: str, data: ReadableBuffer) -> str:
        """Upload ``data`` as ``<prefix>/<key>.bin`` and return its s3:// URI."""
        s3_key = f"{self.prefix}/{key}.bin"
        buffer = io.BytesIO(data)
        # Per the original author's note: upload_fileobj automatically uses
        # multipart upload for large data (>5GB), avoiding put_object's 5GB limit.
        self.s3.upload_fileobj(buffer, self.bucket_name, s3_key)
        return f"s3://{self.bucket_name}/{s3_key}"

    def load(self, location: str) -> bytes:
        """Download and return the object at ``location``.

        Raises:
            ValidationError: If ``location`` is not a valid s3:// URI.
            CacheCorruptedError: If the S3 client raises ``ClientError``.
        """
        bucket, key = self._parse_s3_uri(location)
        try:
            resp = self.s3.get_object(Bucket=bucket, Key=key)
            body = resp["Body"]
            try:
                return body.read()
            finally:
                # Always release the streaming body.
                body.close()
        except ClientError as e:
            raise CacheCorruptedError(f"S3 blob lost: {location}") from e

    def delete(self, location: str) -> None:
        """Delete the object at ``location``; client errors are logged, not raised.

        Raises:
            ValidationError: If ``location`` is not a valid s3:// URI.
        """
        bucket, key = self._parse_s3_uri(location)
        try:
            self.s3.delete_object(Bucket=bucket, Key=key)
        except ClientError as e:
            # Per the original author's note: delete_object does not error on
            # a missing object, so reaching this point means something serious
            # such as a permission error or a network failure. Log it rather
            # than swallowing it, so that a later GC run can pick it up.
            logging.getLogger("beautyspot").warning(
                f"Failed to delete S3 object {location}: {e}. "
                "It will be handled by subsequent GC."
            )

    def get_mtime(self, location: str) -> float:
        """Return the object's ``LastModified`` time as a POSIX timestamp.

        Raises:
            ValidationError: If ``location`` is not a valid s3:// URI.
            CacheCorruptedError: If the S3 client raises ``ClientError``.
        """
        bucket, key = self._parse_s3_uri(location)
        try:
            resp = self.s3.head_object(Bucket=bucket, Key=key)
            # LastModified is a datetime object; convert to POSIX timestamp.
            return resp["LastModified"].timestamp()
        except ClientError as e:
            raise CacheCorruptedError(
                f"S3 blob lost or inaccessible: {location}"
            ) from e

    def list_keys(self) -> Iterator[str]:
        """Yield s3:// URIs for all objects under ``<prefix>/``."""
        paginator = self.s3.get_paginator("list_objects_v2")
        # The trailing slash matters: S3 prefixes are plain string prefixes, so
        # "blobs" alone would also match sibling prefixes like "blobs-old/...".
        pages = paginator.paginate(
            Bucket=self.bucket_name, Prefix=f"{self.prefix}/"
        )
        for page in pages:
            for obj in page.get("Contents", []):
                yield f"s3://{self.bucket_name}/{obj['Key']}"

455 

456 

def create_storage(path: str, options: dict | None = None) -> BlobStorageMaintenable:
    """Build the storage backend that matches ``path``.

    A path starting with ``s3://`` yields an `S3Storage` (``options`` is
    passed through to it); anything else yields a `LocalStorage`, in which
    case ``options`` is not used.
    """
    is_s3 = path.startswith("s3://")
    return S3Storage(path, options) if is_s3 else LocalStorage(path)

462