Coverage for src/dataknobs_llm/prompts/versioning/ab_testing.py: 83%

114 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""A/B testing management for prompt experiments. 

2 

3This module provides: 

4- Experiment creation and management 

5- Random variant selection 

6- User-sticky variant assignment 

7- Traffic split management 

8""" 

9 

10import uuid 

11import hashlib 

12import random 

13from typing import Any, Dict, List 

14from datetime import datetime 

15 

16from .types import ( 

17 PromptExperiment, 

18 PromptVariant, 

19 VersioningError, 

20) 

21 

22 

class ABTestManager:
    """Manages A/B test experiments for prompts.

    Supports multiple selection strategies:
    - Random: Each request gets a random variant based on traffic split
    - User-sticky: Same user always gets same variant (consistent experience)

    Experiments and user assignments are held in memory on this instance.
    When ``storage`` exposes an awaitable ``set`` method, experiments and user
    assignments are also written to it; when it exposes an awaitable
    ``delete`` method, deletions are propagated. This class never reads
    anything back from ``storage``.

    Example:
        ```python
        manager = ABTestManager(storage_backend)

        # Create experiment
        experiment = await manager.create_experiment(
            name="greeting",
            prompt_type="system",
            variants=[
                PromptVariant("1.0.0", 0.5, "Control"),
                PromptVariant("1.1.0", 0.5, "Treatment")
            ]
        )

        # Get variant for user (sticky assignment)
        variant_version = await manager.get_variant_for_user(
            experiment.experiment_id,
            user_id="user123"
        )

        # Get random variant
        variant_version = await manager.get_random_variant(
            experiment.experiment_id
        )
        ```
    """

    # Allowed absolute deviation of an explicit traffic split's total from 1.0.
    # Lenient enough to accept hand-written splits such as 0.333/0.333/0.333.
    _TRAFFIC_SPLIT_TOLERANCE = 0.01

    def __init__(self, storage: Any | None = None):
        """Initialize A/B test manager.

        Args:
            storage: Backend storage (dict for in-memory, database for persistence)
                If None, uses in-memory dictionary
        """
        self.storage = storage if storage is not None else {}
        # experiment_id -> PromptExperiment
        self._experiments: Dict[str, PromptExperiment] = {}
        # experiment_id -> {user_id -> assigned version}
        self._user_assignments: Dict[str, Dict[str, str]] = {}

    async def create_experiment(
        self,
        name: str,
        prompt_type: str,
        variants: List[PromptVariant],
        traffic_split: Dict[str, float] | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> PromptExperiment:
        """Create a new A/B test experiment with status "running".

        Args:
            name: Prompt name
            prompt_type: Prompt type
            variants: List of variants to test (versions must be unique)
            traffic_split: Optional mapping of variant version -> share of
                traffic. Every key must be the version of one of ``variants``,
                every share must be non-negative, and the shares must sum to
                1.0 (within ``_TRAFFIC_SPLIT_TOLERANCE``). A copy is stored.
                If None, the split is derived from the variant weights,
                normalized to sum to 1.0.
            metadata: Additional metadata

        Returns:
            Created PromptExperiment

        Raises:
            VersioningError: If fewer than two variants are given, variant
                versions are duplicated, weights/shares are negative or all
                zero, or the traffic split references unknown versions or
                doesn't sum to 1.0.
        """
        if len(variants) < 2:
            raise VersioningError("Experiment must have at least 2 variants")

        versions = [v.version for v in variants]
        if len(set(versions)) != len(versions):
            # Duplicates would collapse in the version-keyed traffic split.
            raise VersioningError(
                f"Experiment variants must have unique versions: {versions}"
            )

        if traffic_split is None:
            traffic_split = self._split_from_weights(variants)
        else:
            traffic_split = self._validate_traffic_split(traffic_split, versions)

        experiment = PromptExperiment(
            experiment_id=str(uuid.uuid4()),
            name=name,
            prompt_type=prompt_type,
            variants=variants,
            traffic_split=traffic_split,
            start_date=self._utcnow(),
            status="running",
            metadata=metadata or {},
        )

        self._experiments[experiment.experiment_id] = experiment
        self._user_assignments[experiment.experiment_id] = {}

        # Persist to backend if available
        if hasattr(self.storage, "set"):
            await self._persist_experiment(experiment)

        return experiment

    async def get_experiment(
        self,
        experiment_id: str,
    ) -> PromptExperiment | None:
        """Retrieve an experiment by ID.

        Args:
            experiment_id: Experiment ID

        Returns:
            PromptExperiment if found, None otherwise
        """
        return self._experiments.get(experiment_id)

    async def list_experiments(
        self,
        name: str | None = None,
        prompt_type: str | None = None,
        status: str | None = None,
    ) -> List[PromptExperiment]:
        """List experiments with optional filters.

        Empty or None filter values are ignored.

        Args:
            name: Filter by prompt name
            prompt_type: Filter by prompt type
            status: Filter by status ("running", "paused", "completed")

        Returns:
            List of matching experiments
        """
        experiments = list(self._experiments.values())

        if name:
            experiments = [e for e in experiments if e.name == name]

        if prompt_type:
            experiments = [e for e in experiments if e.prompt_type == prompt_type]

        if status:
            experiments = [e for e in experiments if e.status == status]

        return experiments

    async def get_random_variant(
        self,
        experiment_id: str,
    ) -> str:
        """Get a random variant based on traffic split.

        Each call returns a potentially different variant.

        Args:
            experiment_id: Experiment ID

        Returns:
            Version string of selected variant

        Raises:
            VersioningError: If experiment not found or not running
        """
        experiment = self._require_running_experiment(experiment_id)

        # Weighted random selection; weights are relative, so exact sums don't matter.
        versions = list(experiment.traffic_split.keys())
        weights = list(experiment.traffic_split.values())
        return random.choices(versions, weights=weights)[0]

    async def get_variant_for_user(
        self,
        experiment_id: str,
        user_id: str,
    ) -> str:
        """Get variant for a specific user (sticky assignment).

        The first call for a user computes a hash-based assignment and caches
        it in memory (and in ``storage`` when it has ``set``); later calls
        return the cached value.

        Args:
            experiment_id: Experiment ID
            user_id: User identifier

        Returns:
            Version string of assigned variant

        Raises:
            VersioningError: If experiment not found or not running
        """
        experiment = self._require_running_experiment(experiment_id)

        assignments = self._user_assignments.setdefault(experiment_id, {})
        existing = assignments.get(user_id)
        if existing is not None:
            return existing

        assigned_version = self._hash_based_assignment(
            user_id,
            experiment.traffic_split,
        )
        assignments[user_id] = assigned_version

        # Persist assignment if backend available
        if hasattr(self.storage, "set"):
            key = f"assignment:{experiment_id}:{user_id}"
            await self.storage.set(key, assigned_version)

        return assigned_version

    async def update_experiment_status(
        self,
        experiment_id: str,
        status: str,
        end_date: datetime | None = None,
    ) -> PromptExperiment:
        """Update experiment status.

        Args:
            experiment_id: Experiment ID
            status: New status ("running", "paused", "completed")
            end_date: Optional end date (auto-set to the current naive UTC
                time if status is "completed" and no end date is given)

        Returns:
            Updated experiment

        Raises:
            VersioningError: If experiment not found
        """
        experiment = self._experiments.get(experiment_id)
        if experiment is None:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        experiment.status = status

        if end_date is not None:
            experiment.end_date = end_date
        elif status == "completed":
            experiment.end_date = self._utcnow()

        # Persist if backend available
        if hasattr(self.storage, "set"):
            await self._persist_experiment(experiment)

        return experiment

    async def get_user_assignment(
        self,
        experiment_id: str,
        user_id: str,
    ) -> str | None:
        """Get existing user assignment without creating a new one.

        Only the in-memory assignments are consulted.

        Args:
            experiment_id: Experiment ID
            user_id: User ID

        Returns:
            Assigned version if exists, None otherwise
        """
        return self._user_assignments.get(experiment_id, {}).get(user_id)

    async def get_experiment_assignments(
        self,
        experiment_id: str,
    ) -> Dict[str, str]:
        """Get all user assignments for an experiment.

        Args:
            experiment_id: Experiment ID

        Returns:
            A copy of the mapping user_id -> assigned version (empty if the
            experiment is unknown). Mutating it does not affect the manager.
        """
        return dict(self._user_assignments.get(experiment_id, {}))

    async def delete_experiment(
        self,
        experiment_id: str,
    ) -> bool:
        """Delete an experiment.

        Note: This also removes all user assignments known to this instance,
        both in memory and (when ``storage`` supports ``set``/``delete``)
        their persisted ``assignment:`` keys.

        Args:
            experiment_id: Experiment ID

        Returns:
            True if deleted, False if not found
        """
        if experiment_id not in self._experiments:
            return False

        del self._experiments[experiment_id]
        assignments = self._user_assignments.pop(experiment_id, {})

        # Persist deletion if backend available
        if hasattr(self.storage, "delete"):
            await self.storage.delete(f"experiment:{experiment_id}")
            # get_variant_for_user writes one key per user when storage has
            # `set`; remove them so no orphaned assignment records remain.
            # NOTE(review): assignments persisted by another manager instance
            # are not known here and are not removed.
            if hasattr(self.storage, "set"):
                for user_id in assignments:
                    await self.storage.delete(
                        f"assignment:{experiment_id}:{user_id}"
                    )

        return True

    # ===== Helper Methods =====

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()`` (deprecated in
        Python 3.12) without emitting a DeprecationWarning.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _require_running_experiment(self, experiment_id: str) -> PromptExperiment:
        """Look up an experiment and ensure its status is "running".

        Raises:
            VersioningError: If the experiment is missing or not running.
        """
        experiment = self._experiments.get(experiment_id)
        if experiment is None:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        if experiment.status != "running":
            raise VersioningError(
                f"Experiment {experiment_id} is not running (status: {experiment.status})"
            )
        return experiment

    @staticmethod
    def _split_from_weights(variants: List[PromptVariant]) -> Dict[str, float]:
        """Normalize variant weights into a traffic split that sums to 1.0.

        Raises:
            VersioningError: If any weight is negative or no weight is positive.
        """
        if any(v.weight < 0 for v in variants):
            raise VersioningError("Variant weights must be non-negative")

        total_weight = sum(v.weight for v in variants)
        # `not ... > 0` also rejects NaN totals.
        if not total_weight > 0:
            raise VersioningError("At least one variant must have a positive weight")

        return {v.version: v.weight / total_weight for v in variants}

    @classmethod
    def _validate_traffic_split(
        cls,
        traffic_split: Dict[str, float],
        versions: List[str],
    ) -> Dict[str, float]:
        """Validate an explicit traffic split against the variant versions.

        Returns:
            A copy of ``traffic_split``, so later mutation of the caller's
            dict cannot alter the experiment.

        Raises:
            VersioningError: If the split references unknown versions, has a
                negative share, or doesn't sum to 1.0 within tolerance.
        """
        unknown = set(traffic_split) - set(versions)
        if unknown:
            raise VersioningError(
                f"Traffic split references unknown variant versions: "
                f"{sorted(unknown, key=str)}"
            )

        if any(share < 0 for share in traffic_split.values()):
            raise VersioningError("Traffic split shares must be non-negative")

        total = sum(traffic_split.values())
        # Written as `not <=` so a NaN total is rejected too.
        if not abs(total - 1.0) <= cls._TRAFFIC_SPLIT_TOLERANCE:
            raise VersioningError(f"Traffic split must sum to 1.0 (got {total})")

        return dict(traffic_split)

    def _hash_based_assignment(
        self,
        user_id: str,
        traffic_split: Dict[str, float],
    ) -> str:
        """Assign user to variant using consistent hash-based selection.

        The user_id is hashed into one of 1000 buckets in [0, 1), and the
        bucket is matched against the cumulative shares of the versions taken
        in sorted order, so the same inputs always give the same version.

        Args:
            user_id: User identifier
            traffic_split: Version to share mapping (expected to sum to ~1.0)

        Returns:
            Selected version string
        """
        # MD5 is used only for bucketing; usedforsecurity=False keeps this
        # working on FIPS-restricted builds and yields the same digest.
        digest = hashlib.md5(user_id.encode(), usedforsecurity=False).hexdigest()
        normalized = (int(digest, 16) % 1000) / 1000.0

        # NOTE(review): the bucket depends only on user_id, so a user gets the
        # same `normalized` value in every experiment, which correlates
        # assignments across experiments. Salting with the experiment ID would
        # fix this but would re-bucket users in experiments already running.

        versions = sorted(traffic_split)  # Sort for a stable cumulative order
        cumulative = 0.0
        for version in versions:
            cumulative += traffic_split[version]
            if normalized < cumulative:
                return version

        # Reached only when the shares sum to slightly less than 1.0
        # (floating-point error or a tolerated under-allocation). Prefer the
        # last version that actually receives traffic.
        for version in reversed(versions):
            if traffic_split[version] > 0:
                return version
        return versions[-1]

    async def _persist_experiment(self, experiment: PromptExperiment):
        """Persist experiment to backend storage under ``experiment:<id>``."""
        if hasattr(self.storage, "set"):
            key = f"experiment:{experiment.experiment_id}"
            await self.storage.set(key, experiment.to_dict())

    async def get_variant_distribution(
        self,
        experiment_id: str,
    ) -> Dict[str, int]:
        """Get actual distribution of users across variants.

        Counts the in-memory assignments. Every variant appears in the result,
        including those with zero users.

        Args:
            experiment_id: Experiment ID

        Returns:
            Dictionary mapping version to user count

        Raises:
            VersioningError: If experiment not found
        """
        experiment = self._experiments.get(experiment_id)
        if experiment is None:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        assignments = self._user_assignments.get(experiment_id, {})
        distribution: Dict[str, int] = {v.version: 0 for v in experiment.variants}

        for version in assignments.values():
            if version in distribution:
                distribution[version] += 1

        return distribution