Coverage for llm_dataset_engine/adapters/llm_client.py: 30%
127 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2LLM client abstractions and implementations.
4Provides unified interface for multiple LLM providers following the
5Adapter pattern and Dependency Inversion principle.
6"""
8import os
9import time
10from abc import ABC, abstractmethod
11from decimal import Decimal
12from typing import Any, Dict, List, Optional
14import tiktoken
15from llama_index.core.llms import ChatMessage
16from llama_index.llms.anthropic import Anthropic
17from llama_index.llms.azure_openai import AzureOpenAI
18from llama_index.llms.groq import Groq
19from llama_index.llms.openai import OpenAI
21from llm_dataset_engine.core.models import LLMResponse
22from llm_dataset_engine.core.specifications import LLMProvider, LLMSpec
class LLMClient(ABC):
    """
    Common interface shared by every LLM provider client.

    Concrete subclasses supply ``invoke`` and ``estimate_tokens``; the
    batch and cost helpers defined here are built on top of those, so
    callers can swap one provider for another without changing their own
    code (Strategy pattern).
    """

    def __init__(self, spec: LLMSpec):
        """
        Store the specification and cache its frequently used fields.

        Args:
            spec: LLM specification
        """
        self.spec = spec
        # Convenience copies of the spec fields subclasses read most often.
        self.model = spec.model
        self.temperature = spec.temperature
        self.max_tokens = spec.max_tokens

    @abstractmethod
    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Send a single prompt to the model.

        Args:
            prompt: Text prompt
            **kwargs: Additional model parameters

        Returns:
            LLMResponse carrying the generated text and call metadata
        """
        ...

    @abstractmethod
    def estimate_tokens(self, text: str) -> int:
        """
        Estimate how many tokens ``text`` occupies for this provider.

        Args:
            text: Input text

        Returns:
            Estimated token count
        """
        ...

    def batch_invoke(
        self, prompts: List[str], **kwargs: Any
    ) -> List[LLMResponse]:
        """
        Run several prompts through the model.

        This base version simply calls ``invoke`` once per prompt, in
        order; subclasses may override it with a provider-native batch API.

        Args:
            prompts: List of text prompts
            **kwargs: Additional model parameters, passed to every call

        Returns:
            One LLMResponse per prompt, in the same order as ``prompts``
        """
        responses = []
        for single_prompt in prompts:
            responses.append(self.invoke(single_prompt, **kwargs))
        return responses

    def calculate_cost(
        self, tokens_in: int, tokens_out: int
    ) -> Decimal:
        """
        Price a call from its token counts using the spec's per-1k rates.

        A missing (falsy) rate is treated as zero.

        Args:
            tokens_in: Input tokens
            tokens_out: Output tokens

        Returns:
            Total cost in USD
        """
        no_charge = Decimal("0.0")
        in_rate = self.spec.input_cost_per_1k_tokens or no_charge
        out_rate = self.spec.output_cost_per_1k_tokens or no_charge

        # Rates are quoted per 1,000 tokens.
        in_total = (Decimal(tokens_in) / 1000) * in_rate
        out_total = (Decimal(tokens_out) / 1000) * out_rate
        return in_total + out_total
class OpenAIClient(LLMClient):
    """
    OpenAI LLM client implementation.

    Wraps llama_index's ``OpenAI`` chat model and reports token usage,
    cost, and latency for every call.
    """

    def __init__(self, spec: LLMSpec):
        """
        Initialize OpenAI client.

        Args:
            spec: LLM specification. ``spec.api_key`` takes precedence over
                the ``OPENAI_API_KEY`` environment variable.

        Raises:
            ValueError: If no API key is available from either source.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY not found in spec or environment")

        self.client = OpenAI(
            model=spec.model,
            api_key=api_key,
            temperature=spec.temperature,
            max_tokens=spec.max_tokens,
        )

        # Use the model-specific encoding when tiktoken knows the model,
        # otherwise fall back to cl100k_base.
        try:
            self.tokenizer = tiktoken.encoding_for_model(spec.model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke OpenAI API with a single user message.

        Args:
            prompt: Text prompt, sent as one user-role message.
            **kwargs: Additional model parameters, forwarded to ``chat``.

        Returns:
            LLMResponse with the reply text, token counts, cost and latency.
        """
        message = ChatMessage(role="user", content=prompt)

        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - start_time) * 1000

        # llama_index renders str(ChatResponse) as "assistant: <content>";
        # read the message content directly so the role prefix is not
        # returned as (or counted as) part of the model's answer.
        text = response.message.content or ""

        # Prefer provider-reported usage when the integration exposes it;
        # otherwise fall back to local tiktoken estimation.
        usage = getattr(response, "additional_kwargs", None) or {}
        tokens_in = usage.get("prompt_tokens")
        if not isinstance(tokens_in, int):
            tokens_in = self.estimate_tokens(prompt)
        tokens_out = usage.get("completion_tokens")
        if not isinstance(tokens_out, int):
            tokens_out = self.estimate_tokens(text)

        cost = self.calculate_cost(tokens_in, tokens_out)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=cost,
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """Estimate tokens using tiktoken."""
        return len(self.tokenizer.encode(text))
class AzureOpenAIClient(LLMClient):
    """
    Azure OpenAI LLM client implementation.

    Wraps llama_index's ``AzureOpenAI`` chat model for a specific Azure
    deployment and reports token usage, cost, and latency for every call.
    """

    def __init__(self, spec: LLMSpec):
        """
        Initialize Azure OpenAI client.

        Args:
            spec: LLM specification. Must provide ``azure_endpoint`` and
                ``azure_deployment``; ``spec.api_key`` takes precedence over
                the ``AZURE_OPENAI_API_KEY`` environment variable, and
                ``api_version`` defaults to "2024-02-15-preview".

        Raises:
            ValueError: If the API key, endpoint, or deployment is missing.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "AZURE_OPENAI_API_KEY not found in spec or environment"
            )

        if not spec.azure_endpoint:
            raise ValueError("azure_endpoint required for Azure OpenAI")

        if not spec.azure_deployment:
            raise ValueError("azure_deployment required for Azure OpenAI")

        self.client = AzureOpenAI(
            model=spec.model,
            deployment_name=spec.azure_deployment,
            api_key=api_key,
            azure_endpoint=spec.azure_endpoint,
            api_version=spec.api_version or "2024-02-15-preview",
            temperature=spec.temperature,
            max_tokens=spec.max_tokens,
        )

        # Use the model-specific encoding when tiktoken knows the model,
        # otherwise fall back to cl100k_base.
        try:
            self.tokenizer = tiktoken.encoding_for_model(spec.model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke Azure OpenAI API with a single user message.

        Args:
            prompt: Text prompt, sent as one user-role message.
            **kwargs: Additional model parameters, forwarded to ``chat``.

        Returns:
            LLMResponse with the reply text, token counts, cost and latency.
        """
        message = ChatMessage(role="user", content=prompt)

        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - start_time) * 1000

        # llama_index renders str(ChatResponse) as "assistant: <content>";
        # read the message content directly so the role prefix is not
        # returned as (or counted as) part of the model's answer.
        text = response.message.content or ""

        # Prefer provider-reported usage when the integration exposes it;
        # otherwise fall back to local tiktoken estimation.
        usage = getattr(response, "additional_kwargs", None) or {}
        tokens_in = usage.get("prompt_tokens")
        if not isinstance(tokens_in, int):
            tokens_in = self.estimate_tokens(prompt)
        tokens_out = usage.get("completion_tokens")
        if not isinstance(tokens_out, int):
            tokens_out = self.estimate_tokens(text)

        cost = self.calculate_cost(tokens_in, tokens_out)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=cost,
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """Estimate tokens using tiktoken."""
        return len(self.tokenizer.encode(text))
class AnthropicClient(LLMClient):
    """
    Anthropic Claude LLM client implementation.

    Wraps llama_index's ``Anthropic`` chat model and reports token usage,
    cost, and latency for every call.
    """

    def __init__(self, spec: LLMSpec):
        """
        Initialize Anthropic client.

        Args:
            spec: LLM specification. ``spec.api_key`` takes precedence over
                the ``ANTHROPIC_API_KEY`` environment variable; when
                ``spec.max_tokens`` is falsy, 1024 is passed to the client.

        Raises:
            ValueError: If no API key is available from either source.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY not found in spec or environment"
            )

        self.client = Anthropic(
            model=spec.model,
            api_key=api_key,
            temperature=spec.temperature,
            max_tokens=spec.max_tokens or 1024,
        )

        # cl100k_base is not Claude's tokenizer; it only approximates token
        # counts and is used when the provider does not report usage.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke Anthropic API with a single user message.

        Args:
            prompt: Text prompt, sent as one user-role message.
            **kwargs: Additional model parameters, forwarded to ``chat``.

        Returns:
            LLMResponse with the reply text, token counts, cost and latency.
        """
        message = ChatMessage(role="user", content=prompt)

        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - start_time) * 1000

        # llama_index renders str(ChatResponse) as "assistant: <content>";
        # read the message content directly so the role prefix is not
        # returned as (or counted as) part of the model's answer.
        text = response.message.content or ""

        # Prefer provider-reported usage (exact for Claude) when the
        # integration exposes it; otherwise fall back to the approximate
        # tiktoken estimate.
        usage = getattr(response, "additional_kwargs", None) or {}
        tokens_in = usage.get("prompt_tokens")
        if not isinstance(tokens_in, int):
            tokens_in = self.estimate_tokens(prompt)
        tokens_out = usage.get("completion_tokens")
        if not isinstance(tokens_out, int):
            tokens_out = self.estimate_tokens(text)

        cost = self.calculate_cost(tokens_in, tokens_out)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=cost,
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """Estimate tokens (approximate for Anthropic)."""
        return len(self.tokenizer.encode(text))
class GroqClient(LLMClient):
    """
    Groq LLM client implementation.

    Wraps llama_index's ``Groq`` chat model and reports token usage, cost,
    and latency for every call.
    """

    def __init__(self, spec: LLMSpec):
        """
        Initialize Groq client.

        Args:
            spec: LLM specification. ``spec.api_key`` takes precedence over
                the ``GROQ_API_KEY`` environment variable.

        Raises:
            ValueError: If no API key is available from either source.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError(
                "GROQ_API_KEY not found in spec or environment"
            )

        self.client = Groq(
            model=spec.model,
            api_key=api_key,
            temperature=spec.temperature,
            max_tokens=spec.max_tokens,
        )

        # cl100k_base is a generic encoding used only to estimate token
        # counts when the provider does not report usage.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke Groq API with a single user message.

        Args:
            prompt: Text prompt, sent as one user-role message.
            **kwargs: Additional model parameters, forwarded to ``chat``.

        Returns:
            LLMResponse with the reply text, token counts, cost and latency.
        """
        message = ChatMessage(role="user", content=prompt)

        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - start_time) * 1000

        # llama_index renders str(ChatResponse) as "assistant: <content>";
        # read the message content directly so the role prefix is not
        # returned as (or counted as) part of the model's answer.
        text = response.message.content or ""

        # Prefer provider-reported usage when the integration exposes it;
        # otherwise fall back to local tiktoken estimation.
        usage = getattr(response, "additional_kwargs", None) or {}
        tokens_in = usage.get("prompt_tokens")
        if not isinstance(tokens_in, int):
            tokens_in = self.estimate_tokens(prompt)
        tokens_out = usage.get("completion_tokens")
        if not isinstance(tokens_out, int):
            tokens_out = self.estimate_tokens(text)

        cost = self.calculate_cost(tokens_in, tokens_out)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=cost,
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """Estimate tokens using tiktoken."""
        return len(self.tokenizer.encode(text))
def create_llm_client(spec: LLMSpec) -> LLMClient:
    """
    Build the client class that matches ``spec.provider``.

    Providers are checked in a fixed order with ``==``; the first match
    is instantiated with ``spec``.

    Args:
        spec: LLM specification

    Returns:
        Configured LLM client

    Raises:
        ValueError: If provider not supported (client constructors may
            also raise ValueError for missing configuration)
    """
    candidates = (
        (LLMProvider.OPENAI, OpenAIClient),
        (LLMProvider.AZURE_OPENAI, AzureOpenAIClient),
        (LLMProvider.ANTHROPIC, AnthropicClient),
        (LLMProvider.GROQ, GroqClient),
    )
    for provider, client_cls in candidates:
        if spec.provider == provider:
            return client_cls(spec)

    raise ValueError(f"Unsupported provider: {spec.provider}")