Coverage for llm_dataset_engine/adapters/checkpoint_storage.py: 37%

95 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2Checkpoint storage for fault tolerance. 

3 

4Provides persistent storage of execution state to enable resume after 

5failures. 

6""" 

7 

8import json 

9import pickle 

10from abc import ABC, abstractmethod 

11from datetime import datetime 

12from pathlib import Path 

13from typing import Any, Dict, List, Optional 

14from uuid import UUID 

15 

16from llm_dataset_engine.core.models import CheckpointInfo 

17 

18 

class CheckpointStorage(ABC):
    """
    Interface for persisting and retrieving per-session checkpoints.

    Concrete subclasses supply the actual storage medium (Strategy pattern),
    making storage backends interchangeable.
    """

    @abstractmethod
    def save(self, session_id: UUID, data: Dict[str, Any]) -> bool:
        """
        Persist ``data`` as the checkpoint for ``session_id``.

        Args:
            session_id: Identifier of the session being checkpointed
            data: Payload to store

        Returns:
            True when the checkpoint was stored successfully
        """
        ...

    @abstractmethod
    def load(self, session_id: UUID) -> Optional[Dict[str, Any]]:
        """
        Fetch the most recent checkpoint payload for ``session_id``.

        Args:
            session_id: Identifier of the session to restore

        Returns:
            The stored payload, or None when no checkpoint is available
        """
        ...

    @abstractmethod
    def list_checkpoints(self) -> List[CheckpointInfo]:
        """
        Enumerate every checkpoint this backend currently holds.

        Returns:
            One ``CheckpointInfo`` entry per available checkpoint
        """
        ...

    @abstractmethod
    def delete(self, session_id: UUID) -> bool:
        """
        Remove the checkpoint belonging to ``session_id``.

        Args:
            session_id: Identifier of the session whose checkpoint is removed

        Returns:
            True when a checkpoint was actually removed
        """
        ...

    @abstractmethod
    def exists(self, session_id: UUID) -> bool:
        """
        Report whether a checkpoint is stored for ``session_id``.

        Args:
            session_id: Identifier of the session to look up

        Returns:
            True when a checkpoint is present
        """
        ...

88 

89 

class LocalFileCheckpointStorage(CheckpointStorage):
    """
    Local filesystem checkpoint storage implementation.

    Each session is stored in its own file named ``checkpoint_<session_id>``
    with a ``.json`` extension (human-readable, the default) or a ``.pkl``
    extension (pickle). The file holds an envelope with a format version,
    the session id, the save timestamp and the caller's payload under
    ``"data"``.

    Writes go to a sibling ``.tmp`` file that is then renamed over the real
    checkpoint. A crash or serialization error mid-write therefore leaves
    the previous checkpoint intact instead of a truncated file.
    """

    # Filename prefix shared by every checkpoint file; used both to build
    # paths and to recognise checkpoint files when scanning the directory.
    _FILE_PREFIX = "checkpoint_"

    def __init__(
        self,
        checkpoint_dir: Path = Path(".checkpoints"),
        use_json: bool = True,
    ):
        """
        Initialize local file checkpoint storage.

        Args:
            checkpoint_dir: Directory for checkpoints. A ``str`` is also
                accepted and converted to ``Path``. The directory (and any
                missing parents) is created if it does not exist.
            use_json: Use JSON format (True) or pickle (False)
        """
        # Coerce so callers passing a plain string don't crash on .mkdir().
        self.checkpoint_dir = Path(checkpoint_dir)
        self.use_json = use_json

        # Create directory if it doesn't exist
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _extension(self) -> str:
        """Return the file extension for the configured serialization format."""
        return ".json" if self.use_json else ".pkl"

    def _get_checkpoint_path(self, session_id: UUID) -> Path:
        """Get checkpoint file path for session."""
        return self.checkpoint_dir / (
            f"{self._FILE_PREFIX}{session_id}{self._extension()}"
        )

    def _checkpoint_files(self) -> List[Path]:
        """
        Return the files in the directory that follow the checkpoint naming scheme.

        Only ``checkpoint_*<ext>`` files match. Unrelated files sharing the
        directory, and leftover ``.tmp`` files, are never listed or deleted.
        The result is materialised so callers may delete files while
        iterating.
        """
        pattern = f"{self._FILE_PREFIX}*{self._extension()}"
        return list(self.checkpoint_dir.glob(pattern))

    def save(self, session_id: UUID, data: Dict[str, Any]) -> bool:
        """
        Save checkpoint to local file.

        The payload is wrapped with metadata (version, session id, timestamp)
        and written to a temporary sibling file. That file is then renamed
        over the checkpoint with ``Path.replace``, which is atomic on POSIX
        filesystems. If writing fails, any existing checkpoint is untouched.

        In JSON mode, values that are not JSON-serializable are converted
        with ``str`` (``default=str``), so they load back as strings.

        Args:
            session_id: Unique session identifier
            data: Checkpoint data to save

        Returns:
            True if the checkpoint was written, False on any error.
        """
        checkpoint_path = self._get_checkpoint_path(session_id)
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")

        # Add metadata
        checkpoint_data = {
            "version": "1.0",
            "session_id": str(session_id),
            "timestamp": datetime.now().isoformat(),
            "data": data,
        }

        try:
            if self.use_json:
                with open(tmp_path, "w", encoding="utf-8") as f:
                    json.dump(
                        checkpoint_data,
                        f,
                        indent=2,
                        default=str,  # Handle non-serializable types
                    )
            else:
                with open(tmp_path, "wb") as f:
                    pickle.dump(checkpoint_data, f)

            # Swap the fully written file into place in one step.
            tmp_path.replace(checkpoint_path)
            return True
        except Exception:
            # Best-effort contract: report failure via the return value, but
            # don't leave a partially written temp file behind.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            return False

    def load(self, session_id: UUID) -> Optional[Dict[str, Any]]:
        """
        Load checkpoint from local file.

        Args:
            session_id: Session identifier

        Returns:
            The ``"data"`` payload of the checkpoint. Returns None if the file
            is missing, unreadable or corrupt, or if its top level has no
            ``.get`` (i.e. is not a mapping).
        """
        checkpoint_path = self._get_checkpoint_path(session_id)

        try:
            if self.use_json:
                with open(checkpoint_path, "r", encoding="utf-8") as f:
                    checkpoint_data = json.load(f)
            else:
                # SECURITY: pickle.load can execute arbitrary code embedded in
                # the file. Only use pickle mode on a trusted directory.
                with open(checkpoint_path, "rb") as f:
                    checkpoint_data = pickle.load(f)

            return checkpoint_data.get("data")
        except FileNotFoundError:
            return None
        except Exception:
            # Corrupt or unreadable checkpoint: treat as absent.
            return None

    def list_checkpoints(self) -> List[CheckpointInfo]:
        """
        List all checkpoints in the directory, newest first.

        Files whose name does not contain a valid UUID, or that fail to be
        inspected, are skipped. Row and stage indices are read from the
        payload keys ``last_processed_row`` and ``current_stage_index``. They
        default to 0 when the payload cannot be loaded or lacks the key.

        Returns:
            ``CheckpointInfo`` entries sorted by file modification time,
            most recent first.
        """
        checkpoints = []
        prefix_len = len(self._FILE_PREFIX)

        for checkpoint_file in self._checkpoint_files():
            try:
                # Extract session ID from filename (the glob guarantees the prefix).
                session_id = UUID(checkpoint_file.stem[prefix_len:])

                # Get file stats
                stat = checkpoint_file.stat()

                # Try to load checkpoint for additional info
                data = self.load(session_id)
                row_index = data.get("last_processed_row", 0) if data else 0
                stage_index = data.get("current_stage_index", 0) if data else 0

                checkpoints.append(
                    CheckpointInfo(
                        session_id=session_id,
                        checkpoint_path=str(checkpoint_file),
                        row_index=row_index,
                        stage_index=stage_index,
                        timestamp=datetime.fromtimestamp(stat.st_mtime),
                        size_bytes=stat.st_size,
                    )
                )
            except Exception:
                # Skip invalid checkpoint files
                continue

        return sorted(checkpoints, key=lambda x: x.timestamp, reverse=True)

    def delete(self, session_id: UUID) -> bool:
        """
        Delete checkpoint file.

        Args:
            session_id: Session identifier

        Returns:
            True if a file was removed. False if none existed or it could not
            be removed.
        """
        try:
            self._get_checkpoint_path(session_id).unlink()
        except OSError:
            # Includes FileNotFoundError; avoids an exists()/unlink() race.
            return False
        return True

    def exists(self, session_id: UUID) -> bool:
        """Check if checkpoint exists."""
        return self._get_checkpoint_path(session_id).exists()

    def cleanup_old_checkpoints(self, days: int = 7) -> int:
        """
        Delete checkpoints older than specified days.

        Only files following the checkpoint naming scheme for the current
        format are considered. Other files in the directory are never
        touched. Age is judged by file modification time.

        Args:
            days: Age threshold in days

        Returns:
            Number of checkpoints deleted
        """
        deleted = 0
        cutoff = datetime.now().timestamp() - (days * 86400)

        for checkpoint_file in self._checkpoint_files():
            try:
                # stat() is inside the try: the file may vanish between glob and stat.
                if checkpoint_file.stat().st_mtime < cutoff:
                    checkpoint_file.unlink()
                    deleted += 1
            except OSError:
                continue

        return deleted

244