Coverage for llm_dataset_engine/adapters/checkpoint_storage.py: 37%
95 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Checkpoint storage for fault tolerance.
4Provides persistent storage of execution state to enable resume after
5failures.
6"""
8import json
9import pickle
10from abc import ABC, abstractmethod
11from datetime import datetime
12from pathlib import Path
13from typing import Any, Dict, List, Optional
14from uuid import UUID
16from llm_dataset_engine.core.models import CheckpointInfo
class CheckpointStorage(ABC):
    """
    Interface every checkpoint storage backend must implement.

    Concrete subclasses plug in a persistence mechanism (Strategy pattern);
    callers depend only on this contract.
    """

    @abstractmethod
    def save(self, session_id: UUID, data: Dict[str, Any]) -> bool:
        """
        Persist the checkpoint payload for a session.

        Args:
            session_id: Identifier of the session being checkpointed
            data: Payload to persist

        Returns:
            True when the payload was stored successfully
        """

    @abstractmethod
    def load(self, session_id: UUID) -> Optional[Dict[str, Any]]:
        """
        Retrieve the most recent checkpoint payload for a session.

        Args:
            session_id: Identifier of the session to look up

        Returns:
            The stored payload, or None when no checkpoint is available
        """

    @abstractmethod
    def list_checkpoints(self) -> List[CheckpointInfo]:
        """
        Enumerate every checkpoint the backend knows about.

        Returns:
            A list of CheckpointInfo records
        """

    @abstractmethod
    def delete(self, session_id: UUID) -> bool:
        """
        Remove the checkpoint belonging to a session.

        Args:
            session_id: Identifier of the session whose checkpoint to remove

        Returns:
            True when a checkpoint was removed
        """

    @abstractmethod
    def exists(self, session_id: UUID) -> bool:
        """
        Report whether a checkpoint is stored for a session.

        Args:
            session_id: Identifier of the session to check

        Returns:
            True when a checkpoint is present
        """
class LocalFileCheckpointStorage(CheckpointStorage):
    """
    Local filesystem checkpoint storage implementation.

    Each session's checkpoint is stored as
    ``<checkpoint_dir>/checkpoint_<session_id>.json`` (JSON, human readable)
    or ``.pkl`` (pickle). Writes go to a temporary sibling file that is then
    moved over the final path, so an interrupted save never leaves a
    truncated checkpoint behind.
    """

    # Filename prefix shared by every checkpoint file this class manages.
    _FILE_PREFIX = "checkpoint_"

    def __init__(
        self,
        checkpoint_dir: Path = Path(".checkpoints"),
        use_json: bool = True,
    ):
        """
        Initialize local file checkpoint storage.

        Args:
            checkpoint_dir: Directory for checkpoints (a ``str`` is accepted
                and converted to ``Path``); created if it does not exist.
            use_json: Use JSON format (True) or pickle (False)
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.use_json = use_json

        # Create directory if it doesn't exist
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _extension(self) -> str:
        """Return the file extension for the configured serialization format."""
        return ".json" if self.use_json else ".pkl"

    def _glob_pattern(self) -> str:
        """Return a glob that matches only this storage's checkpoint files."""
        return f"{self._FILE_PREFIX}*{self._extension()}"

    def _get_checkpoint_path(self, session_id: UUID) -> Path:
        """Get checkpoint file path for session."""
        filename = f"{self._FILE_PREFIX}{session_id}{self._extension()}"
        return self.checkpoint_dir / filename

    def _read_file(self, path: Path) -> Optional[Dict[str, Any]]:
        """
        Deserialize one checkpoint file and return its ``data`` payload.

        Returns:
            The payload, or None if the file is missing, unreadable, or
            malformed (best-effort: errors are not raised).
        """
        try:
            if self.use_json:
                with open(path, "r", encoding="utf-8") as f:
                    checkpoint_data = json.load(f)
            else:
                # SECURITY: pickle.load can execute arbitrary code embedded in
                # the file; only point this storage at a trusted directory.
                with open(path, "rb") as f:
                    checkpoint_data = pickle.load(f)
            return checkpoint_data.get("data")
        except Exception:
            return None

    def save(self, session_id: UUID, data: Dict[str, Any]) -> bool:
        """
        Save checkpoint to local file.

        The payload is wrapped with version/session/timestamp metadata,
        written to ``<final name>.tmp`` and then atomically moved into place
        with ``Path.replace``. Readers therefore see either the previous
        checkpoint or the new one, never a partially written file.

        Args:
            session_id: Unique session identifier
            data: Checkpoint data to save. In JSON mode, values that are not
                JSON-serializable are stored as ``str(value)``.

        Returns:
            True if successful, False if serialization or I/O failed
        """
        checkpoint_path = self._get_checkpoint_path(session_id)
        # The ".tmp" suffix keeps the temp file out of the checkpoint glob.
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")

        # Add metadata
        checkpoint_data = {
            "version": "1.0",
            "session_id": str(session_id),
            "timestamp": datetime.now().isoformat(),
            "data": data,
        }

        try:
            if self.use_json:
                with open(tmp_path, "w", encoding="utf-8") as f:
                    json.dump(
                        checkpoint_data,
                        f,
                        indent=2,
                        default=str,  # Handle non-serializable types
                    )
            else:
                with open(tmp_path, "wb") as f:
                    pickle.dump(checkpoint_data, f)
            tmp_path.replace(checkpoint_path)
            return True
        except Exception:
            # Best-effort contract: report failure rather than raise, but do
            # not leave a half-written temp file lying around.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            return False

    def load(self, session_id: UUID) -> Optional[Dict[str, Any]]:
        """Load checkpoint payload for a session, or None if absent/invalid."""
        return self._read_file(self._get_checkpoint_path(session_id))

    def list_checkpoints(self) -> List[CheckpointInfo]:
        """
        List all checkpoints in the directory, newest first.

        Only files named ``checkpoint_<uuid><ext>`` are considered. Files
        whose name is not a valid UUID, that vanish while being inspected, or
        whose values CheckpointInfo rejects are skipped.
        """
        checkpoints = []
        prefix_len = len(self._FILE_PREFIX)

        for checkpoint_file in self.checkpoint_dir.glob(self._glob_pattern()):
            try:
                # stem drops the extension; strip the fixed prefix to get the UUID
                session_id = UUID(checkpoint_file.stem[prefix_len:])
                stat = checkpoint_file.stat()

                # Read the file actually found on disk rather than a path
                # rebuilt from the UUID (which may differ in case/format).
                data = self._read_file(checkpoint_file)
                if not isinstance(data, dict):
                    data = {}

                checkpoints.append(
                    CheckpointInfo(
                        session_id=session_id,
                        checkpoint_path=str(checkpoint_file),
                        row_index=data.get("last_processed_row", 0),
                        stage_index=data.get("current_stage_index", 0),
                        timestamp=datetime.fromtimestamp(stat.st_mtime),
                        size_bytes=stat.st_size,
                    )
                )
            except Exception:
                # Skip invalid checkpoint files
                continue

        return sorted(checkpoints, key=lambda x: x.timestamp, reverse=True)

    def delete(self, session_id: UUID) -> bool:
        """Delete checkpoint file; return True only if a file was removed."""
        try:
            self._get_checkpoint_path(session_id).unlink()
        except OSError:
            # Missing file (FileNotFoundError) or permission/I/O failure.
            return False
        return True

    def exists(self, session_id: UUID) -> bool:
        """Check if checkpoint exists."""
        return self._get_checkpoint_path(session_id).exists()

    def cleanup_old_checkpoints(self, days: int = 7) -> int:
        """
        Delete checkpoints whose modification time is older than ``days``.

        Only files matching this storage's checkpoint naming scheme are
        touched; unrelated files in the directory are left alone.

        Args:
            days: Age threshold in days

        Returns:
            Number of checkpoints deleted
        """
        deleted = 0
        cutoff = datetime.now().timestamp() - (days * 86400)

        for checkpoint_file in self.checkpoint_dir.glob(self._glob_pattern()):
            try:
                if checkpoint_file.stat().st_mtime >= cutoff:
                    continue
                checkpoint_file.unlink()
            except OSError:
                # File vanished between glob and stat/unlink, or no permission.
                continue
            deleted += 1

        return deleted