Coverage for llm_dataset_engine/stages/llm_invocation_stage.py: 20%

86 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1"""LLM invocation stage with concurrency and retry logic.""" 

2 

3import concurrent.futures 

4import time 

5from decimal import Decimal 

6from typing import Any, List 

7 

8from llm_dataset_engine.adapters.llm_client import LLMClient 

9from llm_dataset_engine.core.error_handler import ErrorAction, ErrorHandler 

10from llm_dataset_engine.core.models import ( 

11 CostEstimate, 

12 LLMResponse, 

13 PromptBatch, 

14 ResponseBatch, 

15 ValidationResult, 

16) 

17from llm_dataset_engine.core.specifications import ErrorPolicy 

18from llm_dataset_engine.stages.pipeline_stage import PipelineStage 

19from llm_dataset_engine.utils import ( 

20 NetworkError, 

21 RateLimitError, 

22 RateLimiter, 

23 RetryHandler, 

24) 

25 

26 

class LLMInvocationStage(
    PipelineStage[List[PromptBatch], List[ResponseBatch]]
):
    """
    Invoke LLM with prompts using concurrency and retries.

    Responsibilities:
    - Execute LLM calls with rate limiting
    - Handle retries for transient failures
    - Track tokens and costs
    - Support concurrent processing
    """

    # Output tokens assumed per input token by estimate_cost(). Override on a
    # subclass or an instance to tune the estimate; 0.5 is the historic value.
    OUTPUT_TOKEN_RATIO = 0.5

    def __init__(
        self,
        llm_client: LLMClient,
        concurrency: int = 5,
        rate_limiter: RateLimiter | None = None,
        retry_handler: RetryHandler | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.SKIP,
        max_retries: int = 3,
    ):
        """
        Initialize LLM invocation stage.

        Args:
            llm_client: LLM client instance
            concurrency: Max concurrent requests (thread pool size)
            rate_limiter: Optional rate limiter; when given, acquire() is
                called before every LLM call attempt
            retry_handler: Optional retry handler; a default RetryHandler()
                is created when omitted
            error_policy: Policy for handling errors
            max_retries: Maximum retry attempts (passed to the ErrorHandler)
        """
        super().__init__("LLMInvocation")
        self.llm_client = llm_client
        self.concurrency = concurrency
        self.rate_limiter = rate_limiter
        self.retry_handler = retry_handler or RetryHandler()
        # NOTE(review): this factory reads `llm_client.model` while the skip
        # placeholder reads `llm_client.spec.model` -- confirm LLMClient
        # exposes both attributes, or unify them.
        self.error_handler = ErrorHandler(
            policy=error_policy,
            max_retries=max_retries,
            default_value_factory=lambda: LLMResponse(
                text="",
                tokens_in=0,
                tokens_out=0,
                model=llm_client.model,
                cost=Decimal("0.0"),
                latency_ms=0.0,
            ),
        )

    def process(
        self, batches: List[PromptBatch], context: Any
    ) -> List[ResponseBatch]:
        """
        Execute LLM calls for all prompt batches.

        Args:
            batches: Prompt batches to send to the LLM
            context: Execution context; add_cost() and update_row() are
                called on it once per batch

        Returns:
            One ResponseBatch per input batch, in input order.
        """
        response_batches: List[ResponseBatch] = []

        for batch in batches:
            self.logger.info(
                f"Processing batch {batch.batch_id} "
                f"({len(batch.prompts)} prompts)"
            )

            # Process batch with concurrency
            responses = self._process_batch_concurrent(
                batch.prompts, context
            )

            # Calculate batch metrics. The Decimal start value keeps the cost
            # a Decimal even when the batch produced no responses.
            total_tokens = sum(r.tokens_in + r.tokens_out for r in responses)
            total_cost = sum((r.cost for r in responses), Decimal("0"))
            latencies = [r.latency_ms for r in responses]

            # Create response batch
            response_batches.append(
                ResponseBatch(
                    responses=[r.text for r in responses],
                    metadata=batch.metadata,
                    tokens_used=total_tokens,
                    cost=total_cost,
                    batch_id=batch.batch_id,
                    latencies_ms=latencies,
                )
            )

            # Update context
            context.add_cost(total_cost, total_tokens)
            context.update_row(
                batch.metadata[-1].row_index if batch.metadata else 0
            )

        return response_batches

    def _process_batch_concurrent(
        self, prompts: List[str], context: Any
    ) -> List[Any]:
        """
        Process prompts concurrently while maintaining order.

        Exactly one response is returned per prompt so the result stays
        aligned with the batch metadata. A failed prompt is resolved through
        the error handler: SKIP yields a "[SKIPPED]" placeholder, USE_DEFAULT
        yields the handler's default value, FAIL re-raises. Any other action
        also yields a placeholder instead of silently dropping the row.

        Args:
            prompts: Prompts to send, one LLM call each
            context: Execution context (not used by this method)

        Returns:
            Responses in the same order as ``prompts``.

        Raises:
            Exception: The original error when the handler decides FAIL.
        """
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.concurrency
        ) as executor:
            # Submit all tasks and keep them in order
            futures = [
                executor.submit(self._invoke_with_retry_and_ratelimit, prompt)
                for prompt in prompts
            ]

            responses: List[Any] = []
            try:
                # Collect results in submission order
                for idx, future in enumerate(futures):
                    try:
                        response = future.result()
                    except Exception as e:
                        # Apply error policy
                        decision = self.error_handler.handle_error(
                            e,
                            context={
                                "row_index": idx,
                                "stage": self.name,
                                "prompt": prompts[idx][:100],
                            },
                        )

                        if decision.action == ErrorAction.FAIL:
                            # Re-raise to fail pipeline
                            raise
                        if decision.action == ErrorAction.USE_DEFAULT:
                            response = decision.default_value
                        elif decision.action == ErrorAction.SKIP:
                            response = self._placeholder_response(
                                e, "skipped"
                            )
                        else:
                            # e.g. RETRY: retries already ran inside the
                            # retry handler, so nothing more can be done
                            # here. Keep the slot filled so later responses
                            # are not shifted onto the wrong row.
                            response = self._placeholder_response(
                                e, "unresolved"
                            )
                    responses.append(response)
            except BaseException:
                # Leaving the `with` block waits for every queued task, so
                # cancel the ones that have not started to avoid issuing
                # further LLM calls for a batch that is being abandoned.
                for pending in futures:
                    pending.cancel()
                raise

        return responses

    def _placeholder_response(
        self, error: Exception, action: str
    ) -> LLMResponse:
        """Build the zero-cost stand-in response for a row that failed."""
        return LLMResponse(
            text="[SKIPPED]",
            tokens_in=0,
            tokens_out=0,
            model=self.llm_client.spec.model,
            cost=Decimal("0.0"),
            latency_ms=0.0,
            metadata={"error": str(error), "action": action},
        )

    def _invoke_with_retry_and_ratelimit(self, prompt: str) -> Any:
        """
        Invoke LLM with rate limiting and retries.

        Errors raised by the client are classified by message text: a message
        containing "rate" becomes RateLimitError, one containing "network" or
        "timeout" becomes NetworkError, anything else propagates unchanged.

        Args:
            prompt: Prompt text passed to the LLM client

        Returns:
            Whatever ``retry_handler.execute`` returns for the call.
        """

        def _invoke() -> Any:
            # Acquire rate limit token (once per attempt)
            if self.rate_limiter:
                self.rate_limiter.acquire()

            # Invoke LLM
            try:
                return self.llm_client.invoke(prompt)
            except Exception as e:
                # Classify errors for retry logic; keep the original error
                # attached as the explicit cause.
                message = str(e)
                lowered = message.lower()
                if "rate" in lowered:
                    raise RateLimitError(message) from e
                if "network" in lowered or "timeout" in lowered:
                    raise NetworkError(message) from e
                raise

        # Execute with retry handler
        return self.retry_handler.execute(_invoke)

    def validate_input(
        self, batches: List[PromptBatch]
    ) -> ValidationResult:
        """
        Validate prompt batches.

        Records an error when no batches are given, when a batch has no
        prompts, or when a batch's prompt and metadata counts differ.
        """
        result = ValidationResult(is_valid=True)

        if not batches:
            result.add_error("No prompt batches provided")

        for batch in batches:
            if not batch.prompts:
                result.add_error(f"Batch {batch.batch_id} has no prompts")

            if len(batch.prompts) != len(batch.metadata):
                result.add_error(
                    f"Batch {batch.batch_id} prompt/metadata mismatch"
                )

        return result

    def estimate_cost(self, batches: List[PromptBatch]) -> CostEstimate:
        """
        Estimate LLM invocation cost.

        Input tokens come from the client's estimate_tokens(); output tokens
        are assumed to be OUTPUT_TOKEN_RATIO times the input tokens of each
        prompt, truncated to an int.
        """
        total_input_tokens = 0
        total_output_tokens = 0

        # Estimate tokens for all prompts
        for batch in batches:
            for prompt in batch.prompts:
                input_tokens = self.llm_client.estimate_tokens(prompt)
                total_input_tokens += input_tokens
                total_output_tokens += int(
                    input_tokens * self.OUTPUT_TOKEN_RATIO
                )

        total_cost = self.llm_client.calculate_cost(
            total_input_tokens, total_output_tokens
        )

        return CostEstimate(
            total_cost=total_cost,
            total_tokens=total_input_tokens + total_output_tokens,
            input_tokens=total_input_tokens,
            output_tokens=total_output_tokens,
            rows=sum(len(b.prompts) for b in batches),
            confidence="estimate",
        )

249