Coverage for llm_dataset_engine/stages/llm_invocation_stage.py: 20%
86 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""LLM invocation stage with concurrency and retry logic."""
3import concurrent.futures
4import time
5from decimal import Decimal
6from typing import Any, List
8from llm_dataset_engine.adapters.llm_client import LLMClient
9from llm_dataset_engine.core.error_handler import ErrorAction, ErrorHandler
10from llm_dataset_engine.core.models import (
11 CostEstimate,
12 LLMResponse,
13 PromptBatch,
14 ResponseBatch,
15 ValidationResult,
16)
17from llm_dataset_engine.core.specifications import ErrorPolicy
18from llm_dataset_engine.stages.pipeline_stage import PipelineStage
19from llm_dataset_engine.utils import (
20 NetworkError,
21 RateLimitError,
22 RateLimiter,
23 RetryHandler,
24)
class LLMInvocationStage(
    PipelineStage[List[PromptBatch], List[ResponseBatch]]
):
    """
    Invoke LLM with prompts using concurrency and retries.

    Responsibilities:
    - Execute LLM calls with rate limiting
    - Handle retries for transient failures
    - Track tokens and costs
    - Support concurrent processing
    """

    def __init__(
        self,
        llm_client: LLMClient,
        concurrency: int = 5,
        rate_limiter: RateLimiter | None = None,
        retry_handler: RetryHandler | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.SKIP,
        max_retries: int = 3,
    ):
        """
        Initialize LLM invocation stage.

        Args:
            llm_client: LLM client instance
            concurrency: Max concurrent requests (used as the thread pool's
                ``max_workers``)
            rate_limiter: Optional rate limiter; when given, ``acquire()`` is
                called before every LLM invocation attempt
            retry_handler: Optional retry handler; a default ``RetryHandler()``
                is created when omitted
            error_policy: Policy for handling errors
            max_retries: Maximum retry attempts (forwarded to the error
                handler)
        """
        super().__init__("LLMInvocation")
        self.llm_client = llm_client
        self.concurrency = concurrency
        self.rate_limiter = rate_limiter
        self.retry_handler = retry_handler or RetryHandler()
        self.error_handler = ErrorHandler(
            policy=error_policy,
            max_retries=max_retries,
            # Zero-cost, empty response used when the handler decides
            # USE_DEFAULT for a failed prompt.
            default_value_factory=lambda: LLMResponse(
                text="",
                tokens_in=0,
                tokens_out=0,
                model=llm_client.model,
                cost=Decimal("0.0"),
                latency_ms=0.0,
            ),
        )

    def process(
        self, batches: List[PromptBatch], context: Any
    ) -> List[ResponseBatch]:
        """
        Execute LLM calls for all prompt batches.

        Batches are handled one after another; the prompts inside a batch are
        invoked concurrently. After each batch the accumulated cost/tokens and
        the last row index of the batch are reported to ``context``.

        Args:
            batches: Prompt batches to send to the LLM
            context: Execution context; must provide ``add_cost(cost, tokens)``
                and ``update_row(row_index)``

        Returns:
            One ResponseBatch per input batch, in the same order.
        """
        response_batches: List[ResponseBatch] = []

        for batch in batches:
            # NOTE: self.logger is presumably provided by PipelineStage.
            self.logger.info(
                f"Processing batch {batch.batch_id} "
                f"({len(batch.prompts)} prompts)"
            )

            # Process batch with concurrency
            responses = self._process_batch_concurrent(
                batch.prompts, context
            )

            # Calculate batch metrics
            total_tokens = sum(r.tokens_in + r.tokens_out for r in responses)
            # Start from a Decimal so an empty batch yields Decimal("0")
            # rather than the int 0 that a bare sum() would return.
            total_cost = sum((r.cost for r in responses), Decimal("0"))
            latencies = [r.latency_ms for r in responses]

            # Create response batch
            response_batch = ResponseBatch(
                responses=[r.text for r in responses],
                metadata=batch.metadata,
                tokens_used=total_tokens,
                cost=total_cost,
                batch_id=batch.batch_id,
                latencies_ms=latencies,
            )
            response_batches.append(response_batch)

            # Update context
            context.add_cost(total_cost, total_tokens)
            context.update_row(
                batch.metadata[-1].row_index
                if batch.metadata
                else 0
            )

        return response_batches

    def _placeholder_response(self, error: Exception) -> LLMResponse:
        """Build the zero-cost "[SKIPPED]" response recorded for a failed prompt."""
        return LLMResponse(
            text="[SKIPPED]",
            tokens_in=0,
            tokens_out=0,
            model=self.llm_client.spec.model,
            cost=Decimal("0.0"),
            latency_ms=0.0,
            metadata={"error": str(error), "action": "skipped"},
        )

    def _process_batch_concurrent(
        self, prompts: List[str], context: Any
    ) -> List[Any]:
        """
        Process prompts concurrently while maintaining order.

        Every prompt produces exactly one entry in the returned list, at the
        same position as the prompt, so callers can pair responses with
        per-prompt metadata by index.

        Args:
            prompts: Prompt strings to invoke
            context: Execution context (currently unused here; kept for
                interface stability)

        Returns:
            Responses in submission order. Failed prompts are represented by
            the error handler's default value (USE_DEFAULT) or by a
            "[SKIPPED]" placeholder (any other non-fatal decision).

        Raises:
            Exception: The original invocation error, when the error handler
                decides FAIL.
        """
        responses: List[Any] = []

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.concurrency
        ) as executor:
            # Submit all tasks and keep them in order
            futures = [
                executor.submit(self._invoke_with_retry_and_ratelimit, prompt)
                for prompt in prompts
            ]

            # Collect results in submission order
            for idx, future in enumerate(futures):
                try:
                    response = future.result()
                except Exception as e:
                    # Apply error policy
                    decision = self.error_handler.handle_error(
                        e,
                        context={
                            "row_index": idx,
                            "stage": self.name,
                            "prompt": prompts[idx][:100],
                        },
                    )

                    if decision.action == ErrorAction.FAIL:
                        # Leaving the `with` block waits for every queued
                        # task; cancel the ones that have not started so no
                        # further LLM calls are made for a failed pipeline.
                        for pending in futures[idx + 1:]:
                            pending.cancel()
                        # Re-raise to fail pipeline
                        raise

                    if decision.action == ErrorAction.USE_DEFAULT:
                        response = decision.default_value
                    else:
                        # SKIP, or any other non-fatal decision (e.g. RETRY:
                        # the retry handler has already used its attempts by
                        # the time the error reaches this point). A
                        # placeholder is always recorded so the result list
                        # stays aligned with `prompts`.
                        response = self._placeholder_response(e)

                responses.append(response)

        return responses

    def _invoke_with_retry_and_ratelimit(self, prompt: str) -> Any:
        """
        Invoke LLM with rate limiting and retries.

        Args:
            prompt: Prompt text passed to ``llm_client.invoke``

        Returns:
            Whatever ``llm_client.invoke`` returns (presumably an
            LLMResponse; the caller reads ``text``, ``tokens_in``,
            ``tokens_out``, ``cost`` and ``latency_ms`` from it).

        Raises:
            RateLimitError: When the client error message mentions "rate".
            NetworkError: When the client error message mentions "network"
                or "timeout".
            Exception: Any other client error, re-raised unchanged.
        """

        def _invoke() -> Any:
            # Acquire rate limit token (once per attempt, retries included)
            if self.rate_limiter:
                self.rate_limiter.acquire()

            # Invoke LLM
            try:
                return self.llm_client.invoke(prompt)
            except (RateLimitError, NetworkError):
                # Already classified by the client; re-wrapping would drop
                # the original instance and could mislabel it.
                raise
            except Exception as e:
                # Classify errors for retry logic (substring heuristic on the
                # message); chain the cause so the original traceback is kept.
                message = str(e).lower()
                if "rate" in message:
                    raise RateLimitError(str(e)) from e
                if "network" in message or "timeout" in message:
                    raise NetworkError(str(e)) from e
                raise

        # Execute with retry handler
        return self.retry_handler.execute(_invoke)

    def validate_input(
        self, batches: List[PromptBatch]
    ) -> ValidationResult:
        """
        Validate prompt batches.

        Records an error when no batches are given, when a batch has no
        prompts, or when a batch's prompt and metadata counts differ.

        Args:
            batches: Prompt batches to check

        Returns:
            ValidationResult carrying any errors that were added.
        """
        result = ValidationResult(is_valid=True)

        if not batches:
            result.add_error("No prompt batches provided")

        for batch in batches:
            if not batch.prompts:
                result.add_error(f"Batch {batch.batch_id} has no prompts")

            if len(batch.prompts) != len(batch.metadata):
                result.add_error(
                    f"Batch {batch.batch_id} prompt/metadata mismatch"
                )

        return result

    def estimate_cost(self, batches: List[PromptBatch]) -> CostEstimate:
        """
        Estimate LLM invocation cost.

        Input tokens come from ``llm_client.estimate_tokens`` per prompt;
        output tokens are assumed to be half of each prompt's input tokens.

        Args:
            batches: Prompt batches that would be sent to the LLM

        Returns:
            CostEstimate with total cost, token counts and row count.
        """
        total_input_tokens = 0
        total_output_tokens = 0

        # Estimate tokens for all prompts
        for batch in batches:
            for prompt in batch.prompts:
                input_tokens = self.llm_client.estimate_tokens(prompt)
                total_input_tokens += input_tokens

                # Assume average output length (can be made configurable)
                estimated_output = int(input_tokens * 0.5)
                total_output_tokens += estimated_output

        total_cost = self.llm_client.calculate_cost(
            total_input_tokens, total_output_tokens
        )

        return CostEstimate(
            total_cost=total_cost,
            total_tokens=total_input_tokens + total_output_tokens,
            input_tokens=total_input_tokens,
            output_tokens=total_output_tokens,
            rows=sum(len(b.prompts) for b in batches),
            confidence="estimate",
        )