Coverage for llm_dataset_engine/stages/multi_run_stage.py: 40%

58 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2Multi-run stage for executing stages multiple times with aggregation. 

3 

4Implements Decorator pattern to wrap any stage and run it multiple times. 

5""" 

6 

7from abc import ABC, abstractmethod 

8from collections import Counter 

9from decimal import Decimal 

10from typing import Any, Generic, List, TypeVar 

11 

12from llm_dataset_engine.core.models import CostEstimate, ValidationResult 

13from llm_dataset_engine.stages.pipeline_stage import PipelineStage, TInput, TOutput 

14 

# Generic result type shared by the aggregation strategies below.
T = TypeVar("T")

16 

17 

class AggregationStrategy(ABC, Generic[T]):
    """
    Interface for combining the outputs of repeated runs into one value.

    This is the Strategy pattern: each concrete subclass supplies its own
    ``aggregate`` so the combining behaviour can be swapped freely.
    """

    @abstractmethod
    def aggregate(self, results: List[T]) -> T:
        """
        Combine several run results into a single result.

        Args:
            results: Outputs collected from the individual runs

        Returns:
            The combined result
        """
        ...

37 

38 

class ConsensusStrategy(AggregationStrategy[str]):
    """
    Returns the most common result (consensus voting).

    Ties are resolved in favour of the value that appears first in
    ``results``. Hashable results are counted with ``collections.Counter``.
    Unhashable results (e.g. lists or dicts returned by a wrapped stage) fall
    back to grouping by ``==``, so the strategy does not crash on non-scalar
    outputs.
    """

    def aggregate(self, results: List[str]) -> str:
        """
        Return the most frequent result.

        Args:
            results: Results from multiple runs.

        Returns:
            The most frequent value (earliest wins on ties), or ``""`` when
            ``results`` is empty.
        """
        if not results:
            return ""

        try:
            # Counter.most_common orders equal counts by first occurrence,
            # so ties resolve to the earliest result.
            counter = Counter(results)
        except TypeError:
            # At least one result is unhashable: count by equality instead.
            return self._most_common_by_equality(results)
        return counter.most_common(1)[0][0]

    @staticmethod
    def _most_common_by_equality(results: List[Any]) -> Any:
        """Return the most frequent value using ``==`` grouping; first wins ties."""
        groups: List[list] = []  # each entry: [representative value, count]
        for item in results:
            for group in groups:
                if group[0] == item:
                    group[1] += 1
                    break
            else:
                groups.append([item, 1])
        # max() returns the first maximal entry, preserving first-occurrence
        # tie-breaking; groups are in first-seen order.
        return max(groups, key=lambda group: group[1])[0]

52 

53 

class FirstSuccessStrategy(AggregationStrategy[T]):
    """Picks the earliest result that is not None."""

    def aggregate(self, results: List[T]) -> T:
        """
        Return the first result that is not None.

        When ``results`` is empty or contains only None, None is returned.
        """
        return next((item for item in results if item is not None), None)

63 

64 

class AllStrategy(AggregationStrategy[T]):
    """Pass-through strategy: hands back every run's result, unaggregated."""

    def aggregate(self, results: List[T]) -> List[T]:
        """Return ``results`` unchanged (the same list object, not a copy)."""
        return results

71 

72 

class AverageStrategy(AggregationStrategy[float]):
    """Combines numeric results by taking their arithmetic mean."""

    def aggregate(self, results: List[float]) -> float:
        """Return the mean of ``results``, or 0.0 when there are none."""
        count = len(results)
        if count == 0:
            return 0.0
        total = sum(results)
        return total / count

81 

82 

class MultiRunStage(PipelineStage[TInput, TOutput]):
    """
    Decorator stage that runs a wrapped stage multiple times.

    Every run receives the same ``input_data`` and ``context``. Results of
    successful runs are collected and combined by an ``AggregationStrategy``.
    Failed runs are logged and skipped; the stage only fails when every run
    fails.

    Use cases:
    - Run LLM 3 times, take consensus (reduce hallucinations)
    - Tolerate failures (e.g. with ``FirstSuccessStrategy``); note that all
      ``num_runs`` runs are always executed, there is no early stop
    - Collect multiple responses for analysis (e.g. with ``AllStrategy``)

    Example:
        multi_llm = MultiRunStage(
            wrapped_stage=LLMInvocationStage(...),
            num_runs=3,
            aggregation_strategy=ConsensusStrategy(),
        )
    """

    def __init__(
        self,
        wrapped_stage: PipelineStage[TInput, TOutput],
        num_runs: int = 3,
        aggregation_strategy: AggregationStrategy | None = None,
    ):
        """
        Initialize multi-run stage.

        Args:
            wrapped_stage: Stage to execute multiple times
            num_runs: Number of times to run; must be at least 1
            aggregation_strategy: Strategy for aggregating results
                (defaults to ``ConsensusStrategy``)

        Raises:
            ValueError: If ``num_runs`` is less than 1.
        """
        if num_runs < 1:
            # Zero/negative runs would make process() always fail with a
            # misleading "All 0 runs failed" error and make estimate_cost()
            # report zero or negative costs.
            raise ValueError(f"num_runs must be at least 1, got {num_runs}")

        super().__init__(f"MultiRun({wrapped_stage.name})")
        self.wrapped_stage = wrapped_stage
        self.num_runs = num_runs
        # Explicit None check: a caller-supplied strategy is honoured even if
        # it happens to evaluate as falsy.
        self.aggregation_strategy = (
            aggregation_strategy
            if aggregation_strategy is not None
            else ConsensusStrategy()
        )

    def process(self, input_data: TInput, context: Any) -> TOutput:
        """
        Execute the wrapped stage ``num_runs`` times and aggregate.

        Args:
            input_data: Input passed unchanged to every run.
            context: Context passed unchanged to every run.

        Returns:
            Whatever ``aggregation_strategy.aggregate`` returns for the list
            of successful run results.

        Raises:
            RuntimeError: If every run raised. The last run's exception is
                chained as ``__cause__`` so the root failure is not lost.
        """
        results: List[TOutput] = []
        last_error: Exception | None = None

        self.logger.info(
            f"Running {self.wrapped_stage.name} {self.num_runs} times"
        )

        for run_num in range(self.num_runs):
            try:
                result = self.wrapped_stage.process(input_data, context)
            except Exception as e:
                # Best-effort: record the failure and keep going so the
                # remaining runs can still produce a usable result.
                last_error = e
                self.logger.error(
                    f"Run {run_num + 1}/{self.num_runs} failed: {e}"
                )
                continue
            results.append(result)

        if not results:
            raise RuntimeError(
                f"All {self.num_runs} runs failed for {self.wrapped_stage.name}"
            ) from last_error

        aggregated = self.aggregation_strategy.aggregate(results)

        self.logger.info(
            f"Aggregated {len(results)} results using "
            f"{self.aggregation_strategy.__class__.__name__}"
        )

        return aggregated

    def validate_input(self, input_data: TInput) -> ValidationResult:
        """Delegate validation to wrapped stage."""
        return self.wrapped_stage.validate_input(input_data)

    def estimate_cost(self, input_data: TInput) -> CostEstimate:
        """
        Estimate cost as num_runs × wrapped stage cost.

        Cost and token counts are scaled by ``num_runs``. ``rows`` is copied
        unchanged, because every run processes the same input rows.
        """
        single_run_cost = self.wrapped_stage.estimate_cost(input_data)

        # NOTE(review): assumes CostEstimate.confidence accepts a string —
        # confirm against the model definition.
        return CostEstimate(
            total_cost=single_run_cost.total_cost * self.num_runs,
            total_tokens=single_run_cost.total_tokens * self.num_runs,
            input_tokens=single_run_cost.input_tokens * self.num_runs,
            output_tokens=single_run_cost.output_tokens * self.num_runs,
            rows=single_run_cost.rows,
            confidence=f"{single_run_cost.confidence} × {self.num_runs} runs",
        )

171