Coverage for llm_dataset_engine/adapters/llm_client.py: 30%

127 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2LLM client abstractions and implementations. 

3 

4Provides unified interface for multiple LLM providers following the 

5Adapter pattern and Dependency Inversion principle. 

6""" 

7 

8import os 

9import time 

10from abc import ABC, abstractmethod 

11from decimal import Decimal 

12from typing import Any, Dict, List, Optional 

13 

14import tiktoken 

15from llama_index.core.llms import ChatMessage 

16from llama_index.llms.anthropic import Anthropic 

17from llama_index.llms.azure_openai import AzureOpenAI 

18from llama_index.llms.groq import Groq 

19from llama_index.llms.openai import OpenAI 

20 

21from llm_dataset_engine.core.models import LLMResponse 

22from llm_dataset_engine.core.specifications import LLMProvider, LLMSpec 

23 

24 

class LLMClient(ABC):
    """
    Abstract base class for LLM clients.

    Defines the contract that all LLM provider implementations must follow,
    enabling easy swapping of providers (Strategy pattern).
    """

    def __init__(self, spec: LLMSpec):
        """
        Initialize LLM client.

        Args:
            spec: LLM specification. Its ``model``, ``temperature`` and
                ``max_tokens`` are mirrored onto the client as attributes.
        """
        self.spec = spec
        self.model = spec.model
        self.temperature = spec.temperature
        self.max_tokens = spec.max_tokens

    @abstractmethod
    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke LLM with a single prompt.

        Args:
            prompt: Text prompt
            **kwargs: Additional model parameters

        Returns:
            LLMResponse with result and metadata
        """
        pass

    @abstractmethod
    def estimate_tokens(self, text: str) -> int:
        """
        Estimate token count for text.

        Args:
            text: Input text

        Returns:
            Estimated token count
        """
        pass

    def batch_invoke(
        self, prompts: List[str], **kwargs: Any
    ) -> List[LLMResponse]:
        """
        Invoke LLM with multiple prompts.

        Default implementation: sequential invocation, one ``invoke`` call
        per prompt, in order. Subclasses can override for provider-optimized
        batch processing.

        Args:
            prompts: List of text prompts
            **kwargs: Additional model parameters, passed to every call

        Returns:
            List of LLMResponse objects, in the same order as ``prompts``
        """
        return [self.invoke(prompt, **kwargs) for prompt in prompts]

    def calculate_cost(
        self, tokens_in: int, tokens_out: int
    ) -> Decimal:
        """
        Calculate cost for token usage.

        Rates are read from ``spec.input_cost_per_1k_tokens`` and
        ``spec.output_cost_per_1k_tokens``. A missing (``None``) or zero rate
        contributes nothing to the total.

        Args:
            tokens_in: Input tokens
            tokens_out: Output tokens

        Returns:
            Total cost, in the currency unit of the configured rates
        """
        input_rate = self._as_decimal(self.spec.input_cost_per_1k_tokens)
        output_rate = self._as_decimal(self.spec.output_cost_per_1k_tokens)

        input_cost = (Decimal(tokens_in) / 1000) * input_rate
        output_cost = (Decimal(tokens_out) / 1000) * output_rate
        return input_cost + output_cost

    @staticmethod
    def _as_decimal(rate: Any) -> Decimal:
        """Coerce a per-1k-token rate to Decimal; falsy rates become zero."""
        if not rate:
            return Decimal("0.0")
        if isinstance(rate, Decimal):
            return rate
        # Go through str() so a float such as 0.002 becomes Decimal("0.002")
        # rather than its long binary expansion; Decimal * float would
        # otherwise raise TypeError.
        return Decimal(str(rate))

110 

111 

class OpenAIClient(LLMClient):
    """OpenAI LLM client implementation."""

    def __init__(self, spec: LLMSpec):
        """
        Initialize OpenAI client.

        Args:
            spec: LLM specification. The API key is taken from
                ``spec.api_key`` or, failing that, from the
                ``OPENAI_API_KEY`` environment variable.

        Raises:
            ValueError: If no API key is found in either place.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY not found in spec or environment")

        self.client = OpenAI(
            model=spec.model,
            api_key=api_key,
            temperature=spec.temperature,
            max_tokens=spec.max_tokens,
        )

        # tiktoken raises KeyError for model names it does not know; fall
        # back to the cl100k_base encoding in that case.
        try:
            self.tokenizer = tiktoken.encoding_for_model(spec.model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke OpenAI API with a single user message.

        Args:
            prompt: Text prompt, sent as one ``user`` chat message
            **kwargs: Additional model parameters, forwarded to the
                underlying ``chat`` call

        Returns:
            LLMResponse with the completion text, tokenizer-based token
            counts, the derived cost and the call latency in milliseconds
        """
        # perf_counter is monotonic, so the latency cannot go negative or
        # jump if the system clock is adjusted mid-call.
        started = time.perf_counter()
        message = ChatMessage(role="user", content=prompt)
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - started) * 1000

        # str(response) renders as "<role>: <content>"; take only the content
        # so the role prefix does not leak into the text or the token count.
        text = response.message.content or ""

        # Token counts are computed locally with the tokenizer, not read from
        # the provider's usage metadata.
        tokens_in = self.estimate_tokens(prompt)
        tokens_out = self.estimate_tokens(text)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=self.calculate_cost(tokens_in, tokens_out),
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate tokens using tiktoken.

        Args:
            text: Input text

        Returns:
            Number of tokens the configured tiktoken encoding produces
        """
        # disallowed_special=() makes special-token literals such as
        # "<|endoftext|>" encode as ordinary text instead of raising
        # ValueError on arbitrary dataset content.
        return len(self.tokenizer.encode(text, disallowed_special=()))

163 

164 

class AzureOpenAIClient(LLMClient):
    """Azure OpenAI LLM client implementation."""

    def __init__(self, spec: LLMSpec):
        """
        Initialize Azure OpenAI client.

        Args:
            spec: LLM specification. The API key is taken from
                ``spec.api_key`` or, failing that, from the
                ``AZURE_OPENAI_API_KEY`` environment variable.
                ``spec.azure_endpoint`` and ``spec.azure_deployment`` are
                mandatory; ``spec.api_version`` defaults to
                ``"2024-02-15-preview"`` when unset.

        Raises:
            ValueError: If the API key, endpoint or deployment is missing.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "AZURE_OPENAI_API_KEY not found in spec or environment"
            )

        if not spec.azure_endpoint:
            raise ValueError("azure_endpoint required for Azure OpenAI")

        if not spec.azure_deployment:
            raise ValueError("azure_deployment required for Azure OpenAI")

        self.client = AzureOpenAI(
            model=spec.model,
            deployment_name=spec.azure_deployment,
            api_key=api_key,
            azure_endpoint=spec.azure_endpoint,
            api_version=spec.api_version or "2024-02-15-preview",
            temperature=spec.temperature,
            max_tokens=spec.max_tokens,
        )

        # tiktoken raises KeyError for model names it does not know; fall
        # back to the cl100k_base encoding in that case.
        try:
            self.tokenizer = tiktoken.encoding_for_model(spec.model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke Azure OpenAI API with a single user message.

        Args:
            prompt: Text prompt, sent as one ``user`` chat message
            **kwargs: Additional model parameters, forwarded to the
                underlying ``chat`` call

        Returns:
            LLMResponse with the completion text, tokenizer-based token
            counts, the derived cost and the call latency in milliseconds
        """
        # perf_counter is monotonic, so the latency cannot go negative or
        # jump if the system clock is adjusted mid-call.
        started = time.perf_counter()
        message = ChatMessage(role="user", content=prompt)
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - started) * 1000

        # str(response) renders as "<role>: <content>"; take only the content
        # so the role prefix does not leak into the text or the token count.
        text = response.message.content or ""

        # Token counts are computed locally with the tokenizer, not read from
        # the provider's usage metadata.
        tokens_in = self.estimate_tokens(prompt)
        tokens_out = self.estimate_tokens(text)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=self.calculate_cost(tokens_in, tokens_out),
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate tokens using tiktoken.

        Args:
            text: Input text

        Returns:
            Number of tokens the configured tiktoken encoding produces
        """
        # disallowed_special=() makes special-token literals such as
        # "<|endoftext|>" encode as ordinary text instead of raising
        # ValueError on arbitrary dataset content.
        return len(self.tokenizer.encode(text, disallowed_special=()))

227 

228 

class AnthropicClient(LLMClient):
    """Anthropic Claude LLM client implementation."""

    def __init__(self, spec: LLMSpec):
        """
        Initialize Anthropic client.

        Args:
            spec: LLM specification. The API key is taken from
                ``spec.api_key`` or, failing that, from the
                ``ANTHROPIC_API_KEY`` environment variable. When
                ``spec.max_tokens`` is unset, 1024 is used.

        Raises:
            ValueError: If no API key is found in either place.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY not found in spec or environment"
            )

        self.client = Anthropic(
            model=spec.model,
            api_key=api_key,
            temperature=spec.temperature,
            max_tokens=spec.max_tokens or 1024,
        )

        # Token counts for Anthropic are approximated with the cl100k_base
        # tiktoken encoding.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke Anthropic API with a single user message.

        Args:
            prompt: Text prompt, sent as one ``user`` chat message
            **kwargs: Additional model parameters, forwarded to the
                underlying ``chat`` call

        Returns:
            LLMResponse with the completion text, approximate token counts,
            the derived cost and the call latency in milliseconds
        """
        # perf_counter is monotonic, so the latency cannot go negative or
        # jump if the system clock is adjusted mid-call.
        started = time.perf_counter()
        message = ChatMessage(role="user", content=prompt)
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - started) * 1000

        # str(response) renders as "<role>: <content>"; take only the content
        # so the role prefix does not leak into the text or the token count.
        text = response.message.content or ""

        # Approximate token usage, computed locally with the tokenizer.
        tokens_in = self.estimate_tokens(prompt)
        tokens_out = self.estimate_tokens(text)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=self.calculate_cost(tokens_in, tokens_out),
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate tokens (approximate for Anthropic).

        Args:
            text: Input text

        Returns:
            Number of tokens the cl100k_base encoding produces for ``text``
        """
        # disallowed_special=() makes special-token literals such as
        # "<|endoftext|>" encode as ordinary text instead of raising
        # ValueError on arbitrary dataset content.
        return len(self.tokenizer.encode(text, disallowed_special=()))

279 

280 

class GroqClient(LLMClient):
    """Groq LLM client implementation."""

    def __init__(self, spec: LLMSpec):
        """
        Initialize Groq client.

        Args:
            spec: LLM specification. The API key is taken from
                ``spec.api_key`` or, failing that, from the ``GROQ_API_KEY``
                environment variable.

        Raises:
            ValueError: If no API key is found in either place.
        """
        super().__init__(spec)

        api_key = spec.api_key or os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError(
                "GROQ_API_KEY not found in spec or environment"
            )

        self.client = Groq(
            model=spec.model,
            api_key=api_key,
            temperature=spec.temperature,
            max_tokens=spec.max_tokens,
        )

        # Use tiktoken's cl100k_base encoding for token estimation.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """
        Invoke Groq API with a single user message.

        Args:
            prompt: Text prompt, sent as one ``user`` chat message
            **kwargs: Additional model parameters, forwarded to the
                underlying ``chat`` call

        Returns:
            LLMResponse with the completion text, tokenizer-based token
            counts, the derived cost and the call latency in milliseconds
        """
        # perf_counter is monotonic, so the latency cannot go negative or
        # jump if the system clock is adjusted mid-call.
        started = time.perf_counter()
        message = ChatMessage(role="user", content=prompt)
        response = self.client.chat([message], **kwargs)
        latency_ms = (time.perf_counter() - started) * 1000

        # str(response) renders as "<role>: <content>"; take only the content
        # so the role prefix does not leak into the text or the token count.
        text = response.message.content or ""

        # Token counts are computed locally with the tokenizer, not read from
        # the provider's usage metadata.
        tokens_in = self.estimate_tokens(prompt)
        tokens_out = self.estimate_tokens(text)

        return LLMResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            model=self.model,
            cost=self.calculate_cost(tokens_in, tokens_out),
            latency_ms=latency_ms,
        )

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate tokens using tiktoken.

        Args:
            text: Input text

        Returns:
            Number of tokens the cl100k_base encoding produces for ``text``
        """
        # disallowed_special=() makes special-token literals such as
        # "<|endoftext|>" encode as ordinary text instead of raising
        # ValueError on arbitrary dataset content.
        return len(self.tokenizer.encode(text, disallowed_special=()))

331 

332 

def create_llm_client(spec: LLMSpec) -> LLMClient:
    """
    Build the LLM client that matches ``spec.provider``.

    Args:
        spec: LLM specification; its ``provider`` selects the client class
            and the spec itself is handed to that class's constructor.

    Returns:
        Configured LLM client

    Raises:
        ValueError: If provider not supported
    """
    # Ordered (provider, client class) pairs. Matching uses ``==`` rather
    # than a dict lookup so the comparison semantics stay exactly those of
    # an equality check against each enum member.
    candidates = (
        (LLMProvider.OPENAI, OpenAIClient),
        (LLMProvider.AZURE_OPENAI, AzureOpenAIClient),
        (LLMProvider.ANTHROPIC, AnthropicClient),
        (LLMProvider.GROQ, GroqClient),
    )
    for provider, client_cls in candidates:
        if spec.provider == provider:
            return client_cls(spec)

    raise ValueError(f"Unsupported provider: {spec.provider}")

356