Coverage for src / dataknobs_llm / prompts / versioning / ab_testing.py: 18%

115 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:28 -0700

1"""A/B testing management for prompt experiments. 

2 

3This module provides: 

4- Experiment creation and management 

5- Random variant selection 

6- User-sticky variant assignment 

7- Traffic split management 

8""" 

9 

10import uuid 

11import hashlib 

12import random 

13from typing import Any, Dict, List 

14from datetime import datetime 

15 

16from dataknobs_llm.exceptions import VersioningError 

17 

18from .types import ( 

19 PromptExperiment, 

20 PromptVariant, 

21) 

22 

23 

class ABTestManager:
    """Manages A/B test experiments for prompts.

    Supports multiple selection strategies:
    - Random: Each request gets a random variant based on traffic split
    - User-sticky: Same user always gets same variant (consistent experience)

    Sticky assignments hash the experiment ID together with the user ID, so a
    user's bucket in one experiment is independent of their bucket in others.

    Example:
        ```python
        manager = ABTestManager(storage_backend)

        # Create experiment
        experiment = await manager.create_experiment(
            name="greeting",
            prompt_type="system",
            variants=[
                PromptVariant("1.0.0", 0.5, "Control"),
                PromptVariant("1.1.0", 0.5, "Treatment")
            ]
        )

        # Get variant for user (sticky assignment)
        variant_version = await manager.get_variant_for_user(
            experiment.experiment_id,
            user_id="user123"
        )

        # Get random variant
        variant_version = await manager.get_random_variant(
            experiment.experiment_id
        )
        ```
    """

    # Allowed deviation of an explicit traffic split's total from 1.0. Kept
    # lenient so hand-written splits like {a: 0.333, b: 0.333, c: 0.333} pass.
    _TRAFFIC_SPLIT_TOLERANCE = 0.01

    # Number of hash buckets for sticky assignment (0.1% split resolution).
    _HASH_BUCKETS = 1000

    def __init__(self, storage: Any | None = None):
        """Initialize A/B test manager.

        Args:
            storage: Backend storage (dict for in-memory, database for persistence)
                If None, uses in-memory dictionary. Persistence is only
                attempted when the object exposes async ``set``/``delete``.
        """
        self.storage = storage if storage is not None else {}
        # experiment_id -> PromptExperiment
        self._experiments: Dict[str, PromptExperiment] = {}
        # experiment_id -> {user_id -> version}
        self._user_assignments: Dict[str, Dict[str, str]] = {}

    async def create_experiment(
        self,
        name: str,
        prompt_type: str,
        variants: List[PromptVariant],
        traffic_split: Dict[str, float] | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> PromptExperiment:
        """Create a new A/B test experiment in the "running" state.

        Args:
            name: Prompt name
            prompt_type: Prompt type
            variants: List of variants to test (versions must be unique)
            traffic_split: Optional custom traffic split (if None, derives from
                variant weights, normalized to sum to 1.0)
            metadata: Additional metadata

        Returns:
            Created PromptExperiment

        Raises:
            VersioningError: If fewer than 2 variants are given, variant
                versions are not unique, weights are negative or sum to zero,
                or an explicit traffic split references unknown versions,
                contains negative shares, or doesn't sum to 1.0.
        """
        if len(variants) < 2:
            raise VersioningError("Experiment must have at least 2 variants")

        versions = [v.version for v in variants]
        if len(set(versions)) != len(versions):
            # Traffic splits are keyed by version; duplicates would collapse.
            raise VersioningError(
                f"Experiment variants must have unique versions: {versions}"
            )

        if traffic_split is None:
            traffic_split = self._split_from_weights(variants)
        else:
            traffic_split = self._validated_split(traffic_split, set(versions))

        experiment = PromptExperiment(
            experiment_id=str(uuid.uuid4()),
            name=name,
            prompt_type=prompt_type,
            variants=variants,
            traffic_split=traffic_split,
            start_date=self._utcnow(),
            status="running",
            metadata=metadata or {},
        )

        self._experiments[experiment.experiment_id] = experiment
        self._user_assignments[experiment.experiment_id] = {}

        # Persist to backend if available
        if hasattr(self.storage, "set"):
            await self._persist_experiment(experiment)

        return experiment

    async def get_experiment(
        self,
        experiment_id: str,
    ) -> PromptExperiment | None:
        """Retrieve an experiment by ID.

        Args:
            experiment_id: Experiment ID

        Returns:
            PromptExperiment if found, None otherwise
        """
        return self._experiments.get(experiment_id)

    async def list_experiments(
        self,
        name: str | None = None,
        prompt_type: str | None = None,
        status: str | None = None,
    ) -> List[PromptExperiment]:
        """List experiments with optional filters.

        Empty or None filter values are ignored.

        Args:
            name: Filter by prompt name
            prompt_type: Filter by prompt type
            status: Filter by status ("running", "paused", "completed")

        Returns:
            List of matching experiments
        """
        return [
            e
            for e in self._experiments.values()
            if (not name or e.name == name)
            and (not prompt_type or e.prompt_type == prompt_type)
            and (not status or e.status == status)
        ]

    async def get_random_variant(
        self,
        experiment_id: str,
    ) -> str:
        """Get a random variant based on traffic split.

        Each call returns a potentially different variant.

        Args:
            experiment_id: Experiment ID

        Returns:
            Version string of selected variant

        Raises:
            VersioningError: If experiment not found or not running
        """
        experiment = self._get_running_experiment(experiment_id)

        # Weighted random selection
        versions = list(experiment.traffic_split.keys())
        weights = list(experiment.traffic_split.values())
        return random.choices(versions, weights=weights)[0]

    async def get_variant_for_user(
        self,
        experiment_id: str,
        user_id: str,
    ) -> str:
        """Get variant for a specific user (sticky assignment).

        The same user always gets the same variant for consistent experience.
        Uses a hash of the experiment ID and user ID, so the selection is
        deterministic per experiment but independent across experiments.

        Args:
            experiment_id: Experiment ID
            user_id: User identifier

        Returns:
            Version string of assigned variant

        Raises:
            VersioningError: If experiment not found or not running
        """
        experiment = self._get_running_experiment(experiment_id)

        assignments = self._user_assignments.setdefault(experiment_id, {})
        existing = assignments.get(user_id)
        if existing is not None:
            return existing

        assigned_version = self._hash_based_assignment(
            user_id,
            experiment.traffic_split,
            salt=experiment_id,
        )
        assignments[user_id] = assigned_version

        # Persist assignment if backend available
        if hasattr(self.storage, "set"):
            key = f"assignment:{experiment_id}:{user_id}"
            await self.storage.set(key, assigned_version)

        return assigned_version

    async def update_experiment_status(
        self,
        experiment_id: str,
        status: str,
        end_date: datetime | None = None,
    ) -> PromptExperiment:
        """Update experiment status.

        Args:
            experiment_id: Experiment ID
            status: New status ("running", "paused", "completed")
            end_date: Optional end date (auto-set to now, in naive UTC, if
                status is "completed" and no end date is given)

        Returns:
            Updated experiment

        Raises:
            VersioningError: If experiment not found
        """
        experiment = self._experiments.get(experiment_id)
        if experiment is None:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        experiment.status = status

        if end_date is not None:
            experiment.end_date = end_date
        elif status == "completed":
            experiment.end_date = self._utcnow()

        # Persist if backend available
        if hasattr(self.storage, "set"):
            await self._persist_experiment(experiment)

        return experiment

    async def get_user_assignment(
        self,
        experiment_id: str,
        user_id: str,
    ) -> str | None:
        """Get existing user assignment without creating a new one.

        Args:
            experiment_id: Experiment ID
            user_id: User ID

        Returns:
            Assigned version if exists, None otherwise
        """
        return self._user_assignments.get(experiment_id, {}).get(user_id)

    async def get_experiment_assignments(
        self,
        experiment_id: str,
    ) -> Dict[str, str]:
        """Get all user assignments for an experiment.

        Args:
            experiment_id: Experiment ID

        Returns:
            A copy of the mapping from user_id to assigned version (mutating
            it does not affect the manager's state)
        """
        return dict(self._user_assignments.get(experiment_id, {}))

    async def delete_experiment(
        self,
        experiment_id: str,
    ) -> bool:
        """Delete an experiment.

        Note: This also removes all user assignments, both in memory and, when
        the backend supports it, the persisted assignment keys.

        Args:
            experiment_id: Experiment ID

        Returns:
            True if deleted, False if not found
        """
        if experiment_id not in self._experiments:
            return False

        del self._experiments[experiment_id]
        assignments = self._user_assignments.pop(experiment_id, {})

        # Persist deletion if backend available
        if hasattr(self.storage, "delete"):
            await self.storage.delete(f"experiment:{experiment_id}")
            # Assignment keys are only written when the backend has `set`,
            # so only try to remove them in that case.
            if hasattr(self.storage, "set"):
                for user_id in assignments:
                    await self.storage.delete(
                        f"assignment:{experiment_id}:{user_id}"
                    )

        return True

    # ===== Helper Methods =====

    def _get_running_experiment(self, experiment_id: str) -> PromptExperiment:
        """Return the experiment, raising if it is missing or not running."""
        experiment = self._experiments.get(experiment_id)
        if experiment is None:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        if experiment.status != "running":
            raise VersioningError(
                f"Experiment {experiment_id} is not running (status: {experiment.status})"
            )

        return experiment

    @staticmethod
    def _split_from_weights(variants: List[PromptVariant]) -> Dict[str, float]:
        """Derive a traffic split from variant weights, normalized to 1.0.

        Raises:
            VersioningError: If any weight is negative or all weights are zero.
        """
        weights = {v.version: v.weight for v in variants}
        if any(weight < 0 for weight in weights.values()):
            raise VersioningError("Variant weights must be non-negative")

        total_weight = sum(weights.values())
        if total_weight <= 0:
            raise VersioningError("Variant weights must sum to a positive value")

        return {version: weight / total_weight for version, weight in weights.items()}

    @classmethod
    def _validated_split(
        cls,
        traffic_split: Dict[str, float],
        known_versions: set,
    ) -> Dict[str, float]:
        """Validate a caller-supplied traffic split and return a private copy.

        Raises:
            VersioningError: If the split references versions not among the
                variants, has negative shares, or doesn't sum to ~1.0.
        """
        unknown = set(traffic_split) - known_versions
        if unknown:
            raise VersioningError(
                f"Traffic split references unknown versions: {sorted(unknown)}"
            )

        if any(share < 0 for share in traffic_split.values()):
            raise VersioningError("Traffic split values must be non-negative")

        total = sum(traffic_split.values())
        if abs(total - 1.0) > cls._TRAFFIC_SPLIT_TOLERANCE:
            raise VersioningError(f"Traffic split must sum to 1.0 (got {total})")

        # Copy so later mutation of the caller's dict can't change the experiment.
        return dict(traffic_split)

    def _hash_based_assignment(
        self,
        user_id: str,
        traffic_split: Dict[str, float],
        salt: str = "",
    ) -> str:
        """Assign user to variant using consistent hash-based selection.

        This ensures the same user (and salt) always gets the same variant.

        Args:
            user_id: User identifier
            traffic_split: Version to percentage mapping
            salt: Optional value mixed into the hash (e.g. the experiment ID)
                so bucket placement differs between experiments. An empty salt
                hashes ``user_id`` alone.

        Returns:
            Selected version string
        """
        hash_input = f"{salt}:{user_id}" if salt else user_id
        # MD5 is used only for stable, uniform bucketing, not for security.
        digest = hashlib.md5(hash_input.encode(), usedforsecurity=False).hexdigest()

        # Map to [0, 1) range
        buckets = self._HASH_BUCKETS
        normalized = (int(digest, 16) % buckets) / buckets

        # Select variant based on cumulative traffic split; sort for stability.
        versions = sorted(traffic_split.keys())
        cumulative = 0.0
        for version in versions:
            cumulative += traffic_split[version]
            if normalized < cumulative:
                return version

        # Only reachable when the split sums to slightly under 1.0 (floating
        # point or a tolerated rounding gap): give the remainder to the last
        # version that actually receives traffic.
        for version in reversed(versions):
            if traffic_split[version] > 0:
                return version
        return versions[-1]

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()`` without the
        DeprecationWarning emitted on Python 3.12+.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _persist_experiment(self, experiment: PromptExperiment):
        """Persist experiment to backend storage under ``experiment:{id}``."""
        if hasattr(self.storage, "set"):
            key = f"experiment:{experiment.experiment_id}"
            await self.storage.set(key, experiment.to_dict())

    async def get_variant_distribution(
        self,
        experiment_id: str,
    ) -> Dict[str, int]:
        """Get actual distribution of users across variants.

        Args:
            experiment_id: Experiment ID

        Returns:
            Dictionary mapping version to user count (every variant present,
            zero if no users assigned)

        Raises:
            VersioningError: If experiment not found
        """
        experiment = self._experiments.get(experiment_id)
        if experiment is None:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        distribution: Dict[str, int] = {v.version: 0 for v in experiment.variants}
        for version in self._user_assignments.get(experiment_id, {}).values():
            if version in distribution:
                distribution[version] += 1

        return distribution