Coverage for src / dataknobs_llm / prompts / versioning / types.py: 65%

135 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:28 -0700

"""Core type definitions for prompt versioning and A/B testing.

This module defines:
- Version data structures
- Experiment configurations
- Metrics tracking types
"""

8 

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List

13 

14 

class VersionStatus(Enum):
    """Lifecycle status of a prompt version.

    Every member's value is a lowercase string, so a member can be looked up
    from its stored form with ``VersionStatus("draft")``.

    Attributes:
        DRAFT: Version is in development
        ACTIVE: Version is active and can be used
        PRODUCTION: Version is deployed in production
        DEPRECATED: Version is deprecated but still available
        ARCHIVED: Version is archived and should not be used
    """

    DRAFT = "draft"
    ACTIVE = "active"
    PRODUCTION = "production"
    DEPRECATED = "deprecated"
    ARCHIVED = "archived"

30 

31 

@dataclass
class PromptVersion:
    """Represents a versioned prompt.

    Attributes:
        version_id: Unique identifier for this version. It is a required
            field; this class does not generate it.
        name: Name of the prompt
        prompt_type: Type of prompt ("system", "user", "message")
        version: Semantic version string (e.g., "1.2.3")
        template: The prompt template content
        defaults: Default parameter values
        validation: Validation configuration
        metadata: Additional metadata (author, description, etc.)
        created_at: Timestamp when version was created
        created_by: Username/ID of creator
        parent_version: Previous version ID (for history tracking)
        tags: List of tags (e.g., ["production", "experiment-A"])
        status: Current status of this version (defaults to ACTIVE)
    """

    version_id: str
    name: str
    prompt_type: str
    version: str
    template: str
    defaults: Dict[str, Any] = field(default_factory=dict)
    validation: Dict[str, Any] | None = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    # NOTE(review): datetime.utcnow() returns a naive datetime and is
    # deprecated since Python 3.12. It is left unchanged here because an
    # aware timestamp would alter the string produced by to_dict().
    created_at: datetime = field(default_factory=datetime.utcnow)
    created_by: str | None = None
    parent_version: str | None = None
    tags: List[str] = field(default_factory=list)
    status: VersionStatus = VersionStatus.ACTIVE

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage.

        ``created_at`` is written as an ISO-8601 string and ``status`` as the
        enum's string value. Container fields (``defaults``, ``validation``,
        ``metadata``, ``tags``) are placed in the result by reference, not
        copied.

        Returns:
            A dictionary with one key per dataclass field.
        """
        return {
            "version_id": self.version_id,
            "name": self.name,
            "prompt_type": self.prompt_type,
            "version": self.version,
            "template": self.template,
            "defaults": self.defaults,
            "validation": self.validation,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
            "created_by": self.created_by,
            "parent_version": self.parent_version,
            "tags": self.tags,
            "status": self.status.value,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PromptVersion":
        """Create from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. ``created_at`` may be an ISO-8601 string or a
                ``datetime``; ``status`` may be a string value or a
                ``VersionStatus``. The mapping itself is not modified.

        Returns:
            A new ``PromptVersion``.

        Raises:
            ValueError: If ``created_at`` is a string that is not valid ISO
                format, or ``status`` is a string that matches no
                ``VersionStatus`` value.
            TypeError: If ``data`` has keys that are not fields, or lacks a
                required field.
        """
        # Shallow copy so the conversions below do not touch the caller's dict.
        data = data.copy()
        # Parse datetime
        if isinstance(data.get("created_at"), str):
            data["created_at"] = datetime.fromisoformat(data["created_at"])
        # Parse status enum
        if isinstance(data.get("status"), str):
            data["status"] = VersionStatus(data["status"])
        return cls(**data)

94 

95 

@dataclass
class PromptVariant:
    """A variant in an A/B test experiment.

    Attributes:
        version: Version string of this variant
        weight: Traffic allocation weight (relative weight, must be > 0.0).
            Weights are normalized to sum to 1.0 when creating experiment
        description: Human-readable description
        metadata: Additional variant metadata
    """

    version: str
    weight: float
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate weight is positive.

        Raises:
            ValueError: If ``weight`` is zero, negative, or NaN.
        """
        # Written as "not (w > 0)" rather than "w <= 0": every comparison
        # with NaN is False, so the latter would let a NaN weight through.
        if not (self.weight > 0.0):
            raise ValueError(f"Variant weight must be positive, got {self.weight}")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage.

        Returns:
            A dictionary with one key per field. ``metadata`` is included by
            reference, not copied.
        """
        return {
            "version": self.version,
            "weight": self.weight,
            "description": self.description,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PromptVariant":
        """Create from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``.

        Returns:
            A new ``PromptVariant``.

        Raises:
            ValueError: If the weight fails validation.
            TypeError: If ``data`` has keys that are not fields, or lacks a
                required field.
        """
        return cls(**data)

130 

131 

@dataclass
class PromptExperiment:
    """Configuration for an A/B test experiment.

    Attributes:
        experiment_id: Unique identifier for this experiment
        name: Name of the prompt being tested
        prompt_type: Type of prompt ("system", "user", "message")
        variants: List of variants in this experiment
        traffic_split: Mapping of version to traffic fraction; the values
            must sum to 1.0 (within a 0.01 tolerance)
        start_date: When experiment started
        end_date: When experiment ended (None if still running)
        status: Current status ("running", "paused", "completed")
        metrics: Aggregated metrics for the experiment
        metadata: Additional experiment metadata
    """

    experiment_id: str
    name: str
    prompt_type: str
    variants: List[PromptVariant]
    traffic_split: Dict[str, float]
    # NOTE(review): datetime.utcnow() returns a naive datetime and is
    # deprecated since Python 3.12. It is left unchanged here because an
    # aware timestamp would alter the string produced by to_dict().
    start_date: datetime = field(default_factory=datetime.utcnow)
    end_date: datetime | None = None
    status: str = "running"
    metrics: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate traffic split sums to 1.0.

        Raises:
            ValueError: If the values of ``traffic_split`` do not sum to a
                number in the closed range [0.99, 1.01]. An empty split sums
                to 0 and is therefore rejected.
        """
        total = sum(self.traffic_split.values())
        if not (0.99 <= total <= 1.01):  # Allow small floating point error
            raise ValueError(
                f"Traffic split must sum to 1.0, got {total}. "
                f"Split: {self.traffic_split}"
            )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage.

        Each variant is serialized with its own ``to_dict``; dates become
        ISO-8601 strings (``end_date`` stays ``None`` when unset).
        ``traffic_split``, ``metrics`` and ``metadata`` are included by
        reference, not copied.

        Returns:
            A dictionary with one key per dataclass field.
        """
        return {
            "experiment_id": self.experiment_id,
            "name": self.name,
            "prompt_type": self.prompt_type,
            "variants": [v.to_dict() for v in self.variants],
            "traffic_split": self.traffic_split,
            "start_date": self.start_date.isoformat(),
            "end_date": self.end_date.isoformat() if self.end_date else None,
            "status": self.status,
            "metrics": self.metrics,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PromptExperiment":
        """Create from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. Dates may be ISO-8601 strings or ``datetime``
                objects; each entry of ``variants`` may be a dict or an
                already-built ``PromptVariant``. The mapping itself is not
                modified.

        Returns:
            A new ``PromptExperiment``.

        Raises:
            ValueError: If a date string is not valid ISO format, a variant
                fails its own validation, or the traffic split does not sum
                to 1.0.
            TypeError: If ``data`` has keys that are not fields, or lacks a
                required field.
        """
        # Shallow copy so the conversions below do not touch the caller's dict.
        data = data.copy()
        # Parse datetimes
        if isinstance(data.get("start_date"), str):
            data["start_date"] = datetime.fromisoformat(data["start_date"])
        if isinstance(data.get("end_date"), str):
            data["end_date"] = datetime.fromisoformat(data["end_date"])
        # Parse variants (entries that are not dicts are passed through as-is)
        if data.get("variants"):
            data["variants"] = [
                PromptVariant.from_dict(v) if isinstance(v, dict) else v
                for v in data["variants"]
            ]
        return cls(**data)

199 

200 

201@dataclass 

202class PromptMetrics: 

203 """Performance metrics for a prompt version. 

204 

205 Attributes: 

206 version_id: Version ID these metrics belong to 

207 total_uses: Total number of times this version was used 

208 success_count: Number of successful uses 

209 error_count: Number of errors/failures 

210 total_response_time: Total response time across all uses (seconds) 

211 total_tokens: Total tokens used across all uses 

212 user_ratings: List of user ratings (1-5 scale) 

213 last_used: Timestamp of last use 

214 metadata: Additional custom metrics 

215 """ 

216 version_id: str 

217 total_uses: int = 0 

218 success_count: int = 0 

219 error_count: int = 0 

220 total_response_time: float = 0.0 

221 total_tokens: int = 0 

222 user_ratings: List[float] = field(default_factory=list) 

223 last_used: datetime | None = None 

224 metadata: Dict[str, Any] = field(default_factory=dict) 

225 

226 @property 

227 def success_rate(self) -> float: 

228 """Calculate success rate.""" 

229 if self.total_uses == 0: 

230 return 0.0 

231 return self.success_count / self.total_uses 

232 

233 @property 

234 def avg_response_time(self) -> float: 

235 """Calculate average response time.""" 

236 if self.total_uses == 0: 

237 return 0.0 

238 return self.total_response_time / self.total_uses 

239 

240 @property 

241 def avg_tokens(self) -> float: 

242 """Calculate average tokens per use.""" 

243 if self.total_uses == 0: 

244 return 0.0 

245 return self.total_tokens / self.total_uses 

246 

247 @property 

248 def avg_rating(self) -> float: 

249 """Calculate average user rating.""" 

250 if not self.user_ratings: 

251 return 0.0 

252 return sum(self.user_ratings) / len(self.user_ratings) 

253 

254 def to_dict(self) -> Dict[str, Any]: 

255 """Convert to dictionary for storage.""" 

256 return { 

257 "version_id": self.version_id, 

258 "total_uses": self.total_uses, 

259 "success_count": self.success_count, 

260 "error_count": self.error_count, 

261 "total_response_time": self.total_response_time, 

262 "total_tokens": self.total_tokens, 

263 "user_ratings": self.user_ratings, 

264 "last_used": self.last_used.isoformat() if self.last_used else None, 

265 "metadata": self.metadata, 

266 # Include computed properties 

267 "success_rate": self.success_rate, 

268 "avg_response_time": self.avg_response_time, 

269 "avg_tokens": self.avg_tokens, 

270 "avg_rating": self.avg_rating, 

271 } 

272 

273 @classmethod 

274 def from_dict(cls, data: Dict[str, Any]) -> "PromptMetrics": 

275 """Create from dictionary.""" 

276 data = data.copy() 

277 # Parse datetime 

278 if isinstance(data.get("last_used"), str): 

279 data["last_used"] = datetime.fromisoformat(data["last_used"]) 

280 # Remove computed properties (they're recalculated) 

281 for key in ["success_rate", "avg_response_time", "avg_tokens", "avg_rating"]: 

282 data.pop(key, None) 

283 return cls(**data) 

284 

285 

286@dataclass 

287class MetricEvent: 

288 """Single event for metrics tracking. 

289 

290 Attributes: 

291 version_id: Version ID this event belongs to 

292 timestamp: When the event occurred 

293 success: Whether the use was successful 

294 response_time: Response time in seconds (None if not applicable) 

295 tokens: Number of tokens used (None if not applicable) 

296 user_rating: User rating 1-5 (None if not provided) 

297 metadata: Additional event metadata 

298 """ 

299 version_id: str 

300 timestamp: datetime = field(default_factory=datetime.utcnow) 

301 success: bool = True 

302 response_time: float | None = None 

303 tokens: int | None = None 

304 user_rating: float | None = None 

305 metadata: Dict[str, Any] = field(default_factory=dict) 

306 

307 def to_dict(self) -> Dict[str, Any]: 

308 """Convert to dictionary for storage.""" 

309 return { 

310 "version_id": self.version_id, 

311 "timestamp": self.timestamp.isoformat(), 

312 "success": self.success, 

313 "response_time": self.response_time, 

314 "tokens": self.tokens, 

315 "user_rating": self.user_rating, 

316 "metadata": self.metadata, 

317 } 

318 

319 @classmethod 

320 def from_dict(cls, data: Dict[str, Any]) -> "MetricEvent": 

321 """Create from dictionary.""" 

322 data = data.copy() 

323 if isinstance(data.get("timestamp"), str): 

324 data["timestamp"] = datetime.fromisoformat(data["timestamp"]) 

325 return cls(**data)