Coverage for src/dataknobs_llm/prompts/versioning/ab_testing.py: 83%
114 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
"""A/B testing management for prompt experiments.

This module provides:
- Experiment creation and management
- Random variant selection
- User-sticky variant assignment
- Traffic split management
"""
import hashlib
import random
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List

from .types import (
    PromptExperiment,
    PromptVariant,
    VersioningError,
)
23class ABTestManager:
24 """Manages A/B test experiments for prompts.
26 Supports multiple selection strategies:
27 - Random: Each request gets a random variant based on traffic split
28 - User-sticky: Same user always gets same variant (consistent experience)
30 Example:
31 ```python
32 manager = ABTestManager(storage_backend)
34 # Create experiment
35 experiment = await manager.create_experiment(
36 name="greeting",
37 prompt_type="system",
38 variants=[
39 PromptVariant("1.0.0", 0.5, "Control"),
40 PromptVariant("1.1.0", 0.5, "Treatment")
41 ]
42 )
44 # Get variant for user (sticky assignment)
45 variant_version = await manager.get_variant_for_user(
46 experiment.experiment_id,
47 user_id="user123"
48 )
50 # Get random variant
51 variant_version = await manager.get_random_variant(
52 experiment.experiment_id
53 )
54 ```
55 """
57 def __init__(self, storage: Any | None = None):
58 """Initialize A/B test manager.
60 Args:
61 storage: Backend storage (dict for in-memory, database for persistence)
62 If None, uses in-memory dictionary
63 """
64 self.storage = storage if storage is not None else {}
65 self._experiments: Dict[str, PromptExperiment] = {} # experiment_id -> PromptExperiment
66 self._user_assignments: Dict[str, Dict[str, str]] = {} # experiment_id -> {user_id -> version}
68 async def create_experiment(
69 self,
70 name: str,
71 prompt_type: str,
72 variants: List[PromptVariant],
73 traffic_split: Dict[str, float] | None = None,
74 metadata: Dict[str, Any] | None = None,
75 ) -> PromptExperiment:
76 """Create a new A/B test experiment.
78 Args:
79 name: Prompt name
80 prompt_type: Prompt type
81 variants: List of variants to test
82 traffic_split: Optional custom traffic split (if None, derives from variant weights)
83 metadata: Additional metadata
85 Returns:
86 Created PromptExperiment
88 Raises:
89 VersioningError: If variants are invalid or traffic split doesn't sum to 1.0
90 """
91 if len(variants) < 2:
92 raise VersioningError("Experiment must have at least 2 variants")
94 # Generate experiment ID
95 experiment_id = str(uuid.uuid4())
97 # Derive traffic split from variant weights if not provided
98 if traffic_split is None:
99 # Normalize weights to ensure they sum to 1.0
100 total_weight = sum(v.weight for v in variants)
101 traffic_split = {
102 v.version: v.weight / total_weight
103 for v in variants
104 }
106 # Create experiment
107 experiment = PromptExperiment(
108 experiment_id=experiment_id,
109 name=name,
110 prompt_type=prompt_type,
111 variants=variants,
112 traffic_split=traffic_split,
113 start_date=datetime.utcnow(),
114 status="running",
115 metadata=metadata or {},
116 )
118 # Store experiment
119 self._experiments[experiment_id] = experiment
121 # Initialize user assignments
122 self._user_assignments[experiment_id] = {}
124 # Persist to backend if available
125 if hasattr(self.storage, "set"):
126 await self._persist_experiment(experiment)
128 return experiment
130 async def get_experiment(
131 self,
132 experiment_id: str,
133 ) -> PromptExperiment | None:
134 """Retrieve an experiment by ID.
136 Args:
137 experiment_id: Experiment ID
139 Returns:
140 PromptExperiment if found, None otherwise
141 """
142 return self._experiments.get(experiment_id)
144 async def list_experiments(
145 self,
146 name: str | None = None,
147 prompt_type: str | None = None,
148 status: str | None = None,
149 ) -> List[PromptExperiment]:
150 """List experiments with optional filters.
152 Args:
153 name: Filter by prompt name
154 prompt_type: Filter by prompt type
155 status: Filter by status ("running", "paused", "completed")
157 Returns:
158 List of matching experiments
159 """
160 experiments = list(self._experiments.values())
162 # Apply filters
163 if name:
164 experiments = [e for e in experiments if e.name == name]
166 if prompt_type:
167 experiments = [e for e in experiments if e.prompt_type == prompt_type]
169 if status:
170 experiments = [e for e in experiments if e.status == status]
172 return experiments
174 async def get_random_variant(
175 self,
176 experiment_id: str,
177 ) -> str:
178 """Get a random variant based on traffic split.
180 Each call returns a potentially different variant.
182 Args:
183 experiment_id: Experiment ID
185 Returns:
186 Version string of selected variant
188 Raises:
189 VersioningError: If experiment not found
190 """
191 experiment = self._experiments.get(experiment_id)
192 if not experiment:
193 raise VersioningError(f"Experiment not found: {experiment_id}")
195 if experiment.status != "running":
196 raise VersioningError(
197 f"Experiment {experiment_id} is not running (status: {experiment.status})"
198 )
200 # Weighted random selection
201 versions = list(experiment.traffic_split.keys())
202 weights = list(experiment.traffic_split.values())
204 return random.choices(versions, weights=weights)[0]
206 async def get_variant_for_user(
207 self,
208 experiment_id: str,
209 user_id: str,
210 ) -> str:
211 """Get variant for a specific user (sticky assignment).
213 The same user always gets the same variant for consistent experience.
214 Uses hash-based assignment to ensure deterministic selection.
216 Args:
217 experiment_id: Experiment ID
218 user_id: User identifier
220 Returns:
221 Version string of assigned variant
223 Raises:
224 VersioningError: If experiment not found
225 """
226 experiment = self._experiments.get(experiment_id)
227 if not experiment:
228 raise VersioningError(f"Experiment not found: {experiment_id}")
230 if experiment.status != "running":
231 raise VersioningError(
232 f"Experiment {experiment_id} is not running (status: {experiment.status})"
233 )
235 # Check if user already has assignment
236 if experiment_id in self._user_assignments:
237 existing = self._user_assignments[experiment_id].get(user_id)
238 if existing:
239 return existing
241 # Assign user to variant using hash-based selection
242 assigned_version = self._hash_based_assignment(
243 user_id,
244 experiment.traffic_split
245 )
247 # Store assignment
248 if experiment_id not in self._user_assignments:
249 self._user_assignments[experiment_id] = {}
250 self._user_assignments[experiment_id][user_id] = assigned_version
252 # Persist assignment if backend available
253 if hasattr(self.storage, "set"):
254 key = f"assignment:{experiment_id}:{user_id}"
255 await self.storage.set(key, assigned_version)
257 return assigned_version
259 async def update_experiment_status(
260 self,
261 experiment_id: str,
262 status: str,
263 end_date: datetime | None = None,
264 ) -> PromptExperiment:
265 """Update experiment status.
267 Args:
268 experiment_id: Experiment ID
269 status: New status ("running", "paused", "completed")
270 end_date: Optional end date (auto-set to now if status is "completed")
272 Returns:
273 Updated experiment
275 Raises:
276 VersioningError: If experiment not found
277 """
278 experiment = self._experiments.get(experiment_id)
279 if not experiment:
280 raise VersioningError(f"Experiment not found: {experiment_id}")
282 experiment.status = status
284 if status == "completed" and end_date is None:
285 experiment.end_date = datetime.utcnow()
286 elif end_date:
287 experiment.end_date = end_date
289 # Persist if backend available
290 if hasattr(self.storage, "set"):
291 await self._persist_experiment(experiment)
293 return experiment
295 async def get_user_assignment(
296 self,
297 experiment_id: str,
298 user_id: str,
299 ) -> str | None:
300 """Get existing user assignment without creating a new one.
302 Args:
303 experiment_id: Experiment ID
304 user_id: User ID
306 Returns:
307 Assigned version if exists, None otherwise
308 """
309 if experiment_id not in self._user_assignments:
310 return None
311 return self._user_assignments[experiment_id].get(user_id)
313 async def get_experiment_assignments(
314 self,
315 experiment_id: str,
316 ) -> Dict[str, str]:
317 """Get all user assignments for an experiment.
319 Args:
320 experiment_id: Experiment ID
322 Returns:
323 Dictionary mapping user_id to assigned version
324 """
325 return self._user_assignments.get(experiment_id, {})
327 async def delete_experiment(
328 self,
329 experiment_id: str,
330 ) -> bool:
331 """Delete an experiment.
333 Note: This also removes all user assignments.
335 Args:
336 experiment_id: Experiment ID
338 Returns:
339 True if deleted, False if not found
340 """
341 if experiment_id not in self._experiments:
342 return False
344 # Remove experiment
345 del self._experiments[experiment_id]
347 # Remove user assignments
348 if experiment_id in self._user_assignments:
349 del self._user_assignments[experiment_id]
351 # Persist deletion if backend available
352 if hasattr(self.storage, "delete"):
353 await self.storage.delete(f"experiment:{experiment_id}")
355 return True
357 # ===== Helper Methods =====
359 def _hash_based_assignment(
360 self,
361 user_id: str,
362 traffic_split: Dict[str, float],
363 ) -> str:
364 """Assign user to variant using consistent hash-based selection.
366 This ensures the same user always gets the same variant.
368 Args:
369 user_id: User identifier
370 traffic_split: Version to percentage mapping
372 Returns:
373 Selected version string
374 """
375 # Hash user_id to get a deterministic number
376 hash_hex = hashlib.md5(user_id.encode()).hexdigest()
377 hash_val = int(hash_hex, 16)
379 # Map to [0, 1) range
380 normalized = (hash_val % 1000) / 1000.0
382 # Select variant based on cumulative traffic split
383 cumulative = 0.0
384 versions = sorted(traffic_split.keys()) # Sort for consistency
386 for version in versions:
387 cumulative += traffic_split[version]
388 if normalized < cumulative:
389 return version
391 # Fallback to last version (handles floating point errors)
392 return versions[-1]
394 async def _persist_experiment(self, experiment: PromptExperiment):
395 """Persist experiment to backend storage."""
396 if hasattr(self.storage, "set"):
397 key = f"experiment:{experiment.experiment_id}"
398 await self.storage.set(key, experiment.to_dict())
400 async def get_variant_distribution(
401 self,
402 experiment_id: str,
403 ) -> Dict[str, int]:
404 """Get actual distribution of users across variants.
406 Args:
407 experiment_id: Experiment ID
409 Returns:
410 Dictionary mapping version to user count
412 Raises:
413 VersioningError: If experiment not found
414 """
415 experiment = self._experiments.get(experiment_id)
416 if not experiment:
417 raise VersioningError(f"Experiment not found: {experiment_id}")
419 assignments = self._user_assignments.get(experiment_id, {})
420 distribution: Dict[str, int] = {v.version: 0 for v in experiment.variants}
422 for version in assignments.values():
423 if version in distribution:
424 distribution[version] += 1
426 return distribution