Coverage for llm_dataset_engine/stages/multi_run_stage.py: 40%
58 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Multi-run stage for executing stages multiple times with aggregation.
4Implements Decorator pattern to wrap any stage and run it multiple times.
5"""
7from abc import ABC, abstractmethod
8from collections import Counter
9from decimal import Decimal
10from typing import Any, Generic, List, TypeVar
12from llm_dataset_engine.core.models import CostEstimate, ValidationResult
13from llm_dataset_engine.stages.pipeline_stage import PipelineStage, TInput, TOutput
# Generic type variable for the value type handled by an AggregationStrategy
# (used as the element type of `results` and the return type of `aggregate`).
T = TypeVar("T")
class AggregationStrategy(ABC, Generic[T]):
    """
    Abstract base class for aggregation strategies (Strategy pattern).

    A concrete strategy decides how the results of several runs of the
    same stage are combined into a single value.
    """

    @abstractmethod
    def aggregate(self, results: List[T]) -> T:
        """
        Combine the results of several runs into one value.

        Args:
            results: Results collected from the individual runs

        Returns:
            The combined result
        """
class ConsensusStrategy(AggregationStrategy[str]):
    """Majority-vote strategy: picks the result that occurs most often."""

    def aggregate(self, results: List[str]) -> str:
        """
        Return the most frequent result.

        Ties are resolved in favour of the value encountered first, following
        ``Counter.most_common`` ordering. An empty list yields ``""``.
        """
        votes = Counter(results)
        if not votes:
            return ""
        [(winner, _tally)] = votes.most_common(1)
        return winner
class FirstSuccessStrategy(AggregationStrategy[T]):
    """Picks the earliest result that is not ``None``."""

    def aggregate(self, results: List[T]) -> T:
        """
        Return the first non-None result.

        When every result is ``None`` (or the list is empty) the answer is
        ``None``.
        """
        return next((item for item in results if item is not None), None)
class AllStrategy(AggregationStrategy[T]):
    """Pass-through strategy that keeps every run's result."""

    def aggregate(self, results: List[T]) -> List[T]:
        """
        Hand back the results unaggregated.

        The very same list object that was passed in is returned (no copy).
        """
        collected = results
        return collected
class AverageStrategy(AggregationStrategy[float]):
    """Arithmetic-mean strategy for numeric results."""

    def aggregate(self, results: List[float]) -> float:
        """Return the arithmetic mean of the results, or 0.0 when empty."""
        count = len(results)
        return sum(results) / count if count else 0.0
class MultiRunStage(PipelineStage[TInput, TOutput]):
    """
    Decorator stage that runs a wrapped stage multiple times and aggregates
    the results.

    Every call to ``process`` invokes the wrapped stage exactly ``num_runs``
    times; there is no early exit once a run succeeds. Runs that raise are
    logged and skipped, and the results of the successful runs are combined
    by the configured aggregation strategy.

    Use cases:
    - Run LLM 3 times, take consensus (reduce hallucinations)
    - Tolerate failing runs: a result is produced as long as at least one
      run succeeds
    - Collect multiple responses for analysis

    Example:
        multi_llm = MultiRunStage(
            wrapped_stage=LLMInvocationStage(...),
            num_runs=3,
            aggregation_strategy=ConsensusStrategy(),
        )
    """

    def __init__(
        self,
        wrapped_stage: PipelineStage[TInput, TOutput],
        num_runs: int = 3,
        aggregation_strategy: AggregationStrategy | None = None,
    ):
        """
        Initialize multi-run stage.

        Args:
            wrapped_stage: Stage to execute multiple times
            num_runs: Number of times to run; must be at least 1
            aggregation_strategy: Strategy for aggregating results;
                ConsensusStrategy is used when None

        Raises:
            ValueError: If num_runs is less than 1
        """
        # Fail fast: with zero runs `process` could never succeed and would
        # only report a misleading "All 0 runs failed" error.
        if num_runs < 1:
            raise ValueError(f"num_runs must be at least 1, got {num_runs}")

        super().__init__(f"MultiRun({wrapped_stage.name})")
        self.wrapped_stage = wrapped_stage
        self.num_runs = num_runs
        # Explicit None check so a caller-supplied strategy is never replaced
        # based on its truthiness.
        self.aggregation_strategy = (
            aggregation_strategy
            if aggregation_strategy is not None
            else ConsensusStrategy()
        )

    def process(self, input_data: TInput, context: Any) -> TOutput:
        """
        Execute the wrapped stage ``num_runs`` times and aggregate.

        Args:
            input_data: Input passed unchanged to every run
            context: Context passed unchanged to every run

        Returns:
            The aggregation strategy's result over all successful runs

        Raises:
            RuntimeError: If every run raised; chained to the last failure
        """
        results: List[TOutput] = []
        last_error: Exception | None = None

        self.logger.info(
            f"Running {self.wrapped_stage.name} {self.num_runs} times"
        )

        for run_num in range(1, self.num_runs + 1):
            try:
                result = self.wrapped_stage.process(input_data, context)
            except Exception as e:
                # Best-effort: a failed run is logged and skipped so the
                # remaining runs can still produce a result.
                last_error = e
                self.logger.error(
                    f"Run {run_num}/{self.num_runs} failed: {e}"
                )
                continue
            results.append(result)

        if not results:
            # Chain the last failure so its traceback is not lost.
            raise RuntimeError(
                f"All {self.num_runs} runs failed for {self.wrapped_stage.name}"
            ) from last_error

        aggregated = self.aggregation_strategy.aggregate(results)

        self.logger.info(
            f"Aggregated {len(results)} results using "
            f"{self.aggregation_strategy.__class__.__name__}"
        )

        return aggregated

    def validate_input(self, input_data: TInput) -> ValidationResult:
        """Delegate validation to the wrapped stage."""
        return self.wrapped_stage.validate_input(input_data)

    def estimate_cost(self, input_data: TInput) -> CostEstimate:
        """
        Estimate cost as ``num_runs`` times the wrapped stage's estimate.

        Cost and token counts are scaled by ``num_runs``. ``rows`` is passed
        through unchanged because every run processes the same input.
        """
        single_run_cost = self.wrapped_stage.estimate_cost(input_data)

        return CostEstimate(
            total_cost=single_run_cost.total_cost * self.num_runs,
            total_tokens=single_run_cost.total_tokens * self.num_runs,
            input_tokens=single_run_cost.input_tokens * self.num_runs,
            output_tokens=single_run_cost.output_tokens * self.num_runs,
            rows=single_run_cost.rows,
            confidence=f"{single_run_cost.confidence} × {self.num_runs} runs",
        )