Coverage for src / dataknobs_llm / prompts / versioning / ab_testing.py: 18%
115 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:28 -0700
1"""A/B testing management for prompt experiments.
3This module provides:
4- Experiment creation and management
5- Random variant selection
6- User-sticky variant assignment
7- Traffic split management
8"""
10import uuid
11import hashlib
12import random
13from typing import Any, Dict, List
14from datetime import datetime
16from dataknobs_llm.exceptions import VersioningError
18from .types import (
19 PromptExperiment,
20 PromptVariant,
21)
class ABTestManager:
    """Manages A/B test experiments for prompts.

    Supports multiple selection strategies:

    - Random: each request gets a random variant drawn according to the
      experiment's traffic split.
    - User-sticky: the same user always gets the same variant (consistent
      experience), chosen by a deterministic hash of the user ID.

    Experiments and user assignments are held in memory. When the storage
    backend exposes an async ``set`` method, experiments (key
    ``experiment:<id>``) and user assignments (key
    ``assignment:<experiment_id>:<user_id>``) are also written to it; when it
    exposes an async ``delete`` method, experiment deletion is mirrored too.
    This class never reads data back from the backend.

    Example:
        ```python
        manager = ABTestManager(storage_backend)

        # Create experiment
        experiment = await manager.create_experiment(
            name="greeting",
            prompt_type="system",
            variants=[
                PromptVariant("1.0.0", 0.5, "Control"),
                PromptVariant("1.1.0", 0.5, "Treatment")
            ]
        )

        # Get variant for user (sticky assignment)
        variant_version = await manager.get_variant_for_user(
            experiment.experiment_id,
            user_id="user123"
        )

        # Get random variant
        variant_version = await manager.get_random_variant(
            experiment.experiment_id
        )
        ```
    """

    # Absolute tolerance when checking that a traffic split sums to 1.0, so
    # splits like {a: 0.1, b: 0.2, c: 0.7} are not rejected due to float rounding.
    _SPLIT_TOLERANCE = 1e-6

    def __init__(self, storage: Any | None = None):
        """Initialize A/B test manager.

        Args:
            storage: Backend storage. Persistence is used only if the object
                has async ``set`` (and, for deletions, ``delete``) methods.
                If None, an in-memory dictionary is used (a plain dict has no
                ``set`` method, so nothing is persisted in that case).
        """
        self.storage = storage if storage is not None else {}
        # experiment_id -> PromptExperiment
        self._experiments: Dict[str, PromptExperiment] = {}
        # experiment_id -> {user_id -> assigned version}
        self._user_assignments: Dict[str, Dict[str, str]] = {}

    async def create_experiment(
        self,
        name: str,
        prompt_type: str,
        variants: List[PromptVariant],
        traffic_split: Dict[str, float] | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> PromptExperiment:
        """Create a new A/B test experiment.

        Args:
            name: Prompt name
            prompt_type: Prompt type
            variants: List of variants to test (at least two, unique versions)
            traffic_split: Optional custom traffic split mapping variant
                version to its share of traffic. If None, it is derived by
                normalizing the variant weights.
            metadata: Additional metadata

        Returns:
            Created PromptExperiment (status "running")

        Raises:
            VersioningError: If there are fewer than two variants, variant
                versions are duplicated, variant weights do not sum to a
                positive value, or the traffic split references unknown
                versions, contains negative shares, or doesn't sum to 1.0.
        """
        if len(variants) < 2:
            raise VersioningError("Experiment must have at least 2 variants")

        versions = [v.version for v in variants]
        # Duplicate versions would silently collapse into one traffic entry.
        if len(set(versions)) != len(versions):
            raise VersioningError("Experiment variants must have unique versions")

        # Generate experiment ID
        experiment_id = str(uuid.uuid4())

        if traffic_split is None:
            # Derive traffic split by normalizing weights so they sum to 1.0.
            total_weight = sum(v.weight for v in variants)
            if total_weight <= 0:
                raise VersioningError("Variant weights must sum to a positive value")
            traffic_split = {v.version: v.weight / total_weight for v in variants}
        else:
            # Copy so later mutation of the caller's dict cannot alter the experiment.
            traffic_split = dict(traffic_split)

        self._validate_traffic_split(traffic_split, versions)

        experiment = PromptExperiment(
            experiment_id=experiment_id,
            name=name,
            prompt_type=prompt_type,
            variants=list(variants),
            traffic_split=traffic_split,
            start_date=self._utcnow(),
            status="running",
            metadata=dict(metadata or {}),
        )

        self._experiments[experiment_id] = experiment
        self._user_assignments[experiment_id] = {}

        # Persist to backend if available
        if hasattr(self.storage, "set"):
            await self._persist_experiment(experiment)

        return experiment

    async def get_experiment(
        self,
        experiment_id: str,
    ) -> PromptExperiment | None:
        """Retrieve an experiment by ID.

        Args:
            experiment_id: Experiment ID

        Returns:
            PromptExperiment if found, None otherwise
        """
        return self._experiments.get(experiment_id)

    async def list_experiments(
        self,
        name: str | None = None,
        prompt_type: str | None = None,
        status: str | None = None,
    ) -> List[PromptExperiment]:
        """List experiments with optional filters.

        Falsy filter values (None or "") are ignored.

        Args:
            name: Filter by prompt name
            prompt_type: Filter by prompt type
            status: Filter by status ("running", "paused", "completed")

        Returns:
            List of matching experiments
        """
        experiments = list(self._experiments.values())

        if name:
            experiments = [e for e in experiments if e.name == name]

        if prompt_type:
            experiments = [e for e in experiments if e.prompt_type == prompt_type]

        if status:
            experiments = [e for e in experiments if e.status == status]

        return experiments

    async def get_random_variant(
        self,
        experiment_id: str,
    ) -> str:
        """Get a random variant based on traffic split.

        Each call returns a potentially different variant; nothing is recorded.

        Args:
            experiment_id: Experiment ID

        Returns:
            Version string of selected variant

        Raises:
            VersioningError: If experiment not found or not running
        """
        experiment = self._experiments.get(experiment_id)
        if not experiment:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        if experiment.status != "running":
            raise VersioningError(
                f"Experiment {experiment_id} is not running (status: {experiment.status})"
            )

        # Weighted random selection
        versions = list(experiment.traffic_split.keys())
        weights = list(experiment.traffic_split.values())

        return random.choices(versions, weights=weights)[0]

    async def get_variant_for_user(
        self,
        experiment_id: str,
        user_id: str,
    ) -> str:
        """Get variant for a specific user (sticky assignment).

        The first call for a user computes a hash-based assignment and caches
        it; later calls return the cached value for a consistent experience.

        Args:
            experiment_id: Experiment ID
            user_id: User identifier

        Returns:
            Version string of assigned variant

        Raises:
            VersioningError: If experiment not found or not running
        """
        experiment = self._experiments.get(experiment_id)
        if not experiment:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        if experiment.status != "running":
            raise VersioningError(
                f"Experiment {experiment_id} is not running (status: {experiment.status})"
            )

        assignments = self._user_assignments.setdefault(experiment_id, {})
        existing = assignments.get(user_id)
        if existing is not None:
            return existing

        assigned_version = self._hash_based_assignment(
            user_id,
            experiment.traffic_split,
        )
        assignments[user_id] = assigned_version

        # Persist assignment if backend available
        if hasattr(self.storage, "set"):
            key = f"assignment:{experiment_id}:{user_id}"
            await self.storage.set(key, assigned_version)

        return assigned_version

    async def update_experiment_status(
        self,
        experiment_id: str,
        status: str,
        end_date: datetime | None = None,
    ) -> PromptExperiment:
        """Update experiment status.

        Args:
            experiment_id: Experiment ID
            status: New status ("running", "paused", "completed"); the value
                is stored as given and not validated here.
            end_date: Optional end date (auto-set to the current naive UTC
                time if status is "completed" and no end date is given)

        Returns:
            Updated experiment

        Raises:
            VersioningError: If experiment not found
        """
        experiment = self._experiments.get(experiment_id)
        if not experiment:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        experiment.status = status

        if status == "completed" and end_date is None:
            experiment.end_date = self._utcnow()
        elif end_date:
            experiment.end_date = end_date

        # Persist if backend available
        if hasattr(self.storage, "set"):
            await self._persist_experiment(experiment)

        return experiment

    async def get_user_assignment(
        self,
        experiment_id: str,
        user_id: str,
    ) -> str | None:
        """Get existing user assignment without creating a new one.

        Args:
            experiment_id: Experiment ID
            user_id: User ID

        Returns:
            Assigned version if exists, None otherwise
        """
        return self._user_assignments.get(experiment_id, {}).get(user_id)

    async def get_experiment_assignments(
        self,
        experiment_id: str,
    ) -> Dict[str, str]:
        """Get all user assignments for an experiment.

        Args:
            experiment_id: Experiment ID

        Returns:
            A copy of the mapping user_id -> assigned version (empty if the
            experiment is unknown). Mutating it does not affect the manager.
        """
        return dict(self._user_assignments.get(experiment_id, {}))

    async def delete_experiment(
        self,
        experiment_id: str,
    ) -> bool:
        """Delete an experiment.

        Note: This also removes all user assignments, both in memory and (when
        the backend supports ``delete``) their persisted
        ``assignment:<experiment_id>:<user_id>`` keys.

        Args:
            experiment_id: Experiment ID

        Returns:
            True if deleted, False if not found
        """
        if experiment_id not in self._experiments:
            return False

        del self._experiments[experiment_id]
        # Capture assignments before dropping them so their persisted keys
        # can be removed too.
        assignments = self._user_assignments.pop(experiment_id, {})

        # Persist deletion if backend available
        if hasattr(self.storage, "delete"):
            await self.storage.delete(f"experiment:{experiment_id}")
            for user_id in assignments:
                await self.storage.delete(f"assignment:{experiment_id}:{user_id}")

        return True

    # ===== Helper Methods =====

    @classmethod
    def _validate_traffic_split(
        cls,
        traffic_split: Dict[str, float],
        versions: List[str],
    ) -> None:
        """Raise VersioningError unless traffic_split is a distribution over versions."""
        unknown = set(traffic_split) - set(versions)
        if unknown:
            raise VersioningError(
                f"Traffic split references unknown variant versions: {sorted(unknown)}"
            )

        negative = sorted(v for v, share in traffic_split.items() if share < 0)
        if negative:
            raise VersioningError(
                f"Traffic split has negative shares for versions: {negative}"
            )

        total = sum(traffic_split.values())
        if abs(total - 1.0) > cls._SPLIT_TOLERANCE:
            raise VersioningError(f"Traffic split must sum to 1.0 (got {total})")

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value shape as the deprecated ``datetime.utcnow()`` (naive, UTC)
        without its DeprecationWarning on Python 3.12+.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _hash_based_assignment(
        self,
        user_id: str,
        traffic_split: Dict[str, float],
    ) -> str:
        """Assign user to variant using consistent hash-based selection.

        This ensures the same user always gets the same variant for a given
        traffic split.

        NOTE(review): the hash depends only on user_id, not the experiment,
        so a user lands in the same bucket across all experiments with the
        same split (correlated assignments). Salting with the experiment ID
        would fix that but would reassign existing users.

        Args:
            user_id: User identifier
            traffic_split: Version to percentage mapping

        Returns:
            Selected version string
        """
        # MD5 is used for bucketing only, not security; the flag keeps it
        # usable on FIPS-restricted builds and yields the same digest.
        hash_hex = hashlib.md5(user_id.encode(), usedforsecurity=False).hexdigest()
        hash_val = int(hash_hex, 16)

        # Map to [0, 1) with 0.001 granularity (1000 buckets).
        normalized = (hash_val % 1000) / 1000.0

        # Walk versions in sorted order so the bucket mapping is independent
        # of dict insertion order.
        cumulative = 0.0
        versions = sorted(traffic_split.keys())

        for version in versions:
            cumulative += traffic_split[version]
            if normalized < cumulative:
                return version

        # Fallback to last version (handles floating point shortfall in the sum)
        return versions[-1]

    async def _persist_experiment(self, experiment: PromptExperiment):
        """Persist experiment to backend storage under ``experiment:<id>``."""
        if hasattr(self.storage, "set"):
            key = f"experiment:{experiment.experiment_id}"
            await self.storage.set(key, experiment.to_dict())

    async def get_variant_distribution(
        self,
        experiment_id: str,
    ) -> Dict[str, int]:
        """Get actual distribution of sticky-assigned users across variants.

        Only assignments made via ``get_variant_for_user`` are counted.

        Args:
            experiment_id: Experiment ID

        Returns:
            Dictionary mapping each variant version to its user count

        Raises:
            VersioningError: If experiment not found
        """
        experiment = self._experiments.get(experiment_id)
        if not experiment:
            raise VersioningError(f"Experiment not found: {experiment_id}")

        assignments = self._user_assignments.get(experiment_id, {})
        distribution: Dict[str, int] = {v.version: 0 for v in experiment.variants}

        for version in assignments.values():
            if version in distribution:
                distribution[version] += 1

        return distribution