Coverage for src / dataknobs_llm / llm / providers / openai.py: 15%

124 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:29 -0700

1"""OpenAI LLM provider implementation. 

2 

3This module provides OpenAI API integration for dataknobs-llm, supporting: 

4- GPT-4, GPT-3.5-turbo, and other OpenAI chat models 

5- Text embeddings (ada-002, etc.) 

6- Function calling / tool use 

7- Streaming responses 

8- JSON mode for structured outputs 

9- Vision models (GPT-4V) 

10 

11The OpenAIProvider uses the official OpenAI Python SDK and supports all 

12standard OpenAI API parameters. 

13 

14Example: 

15 ```python 

16 from dataknobs_llm.llm.providers import OpenAIProvider 

17 from dataknobs_llm.llm.base import LLMConfig 

18 

19 # Create provider 

20 config = LLMConfig( 

21 provider="openai", 

22 model="gpt-4", 

23 api_key="sk-...", # or set OPENAI_API_KEY env var 

24 temperature=0.7, 

25 max_tokens=500 

26 ) 

27 

28 async with OpenAIProvider(config) as llm: 

29 # Simple completion 

30 response = await llm.complete("What is Python?") 

31 print(response.content) 

32 

33 # Streaming 

34 async for chunk in llm.stream_complete("Tell a story"): 

35 print(chunk.delta, end="", flush=True) 

36 

37 # Embeddings 

38 embedding = await llm.embed("sample text") 

39 print(f"Dimensions: {len(embedding)}") 

40 ``` 

41 

42See Also: 

43 - OpenAI API Documentation: https://platform.openai.com/docs 

44 - openai Python package: https://github.com/openai/openai-python 

45""" 

46 

47import os 

48from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator 

49 

50from ..base import ( 

51 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

52 AsyncLLMProvider, ModelCapability, 

53 LLMAdapter, normalize_llm_config 

54) 

55from dataknobs_llm.prompts import AsyncPromptBuilder 

56 

57if TYPE_CHECKING: 

58 from dataknobs_config.config import Config 

59 

60 

class OpenAIAdapter(LLMAdapter):
    """Adapter for OpenAI API format.

    Converts the provider-neutral message, config, and response objects
    used by this package to and from the shapes expected by the OpenAI
    chat completions API.
    """

    def adapt_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI format.

        Args:
            messages: Messages to convert.

        Returns:
            One dict per message containing ``role`` and ``content``.
            ``name`` and ``function_call`` are included only when they are
            set (truthy) on the message.
        """
        adapted: List[Dict[str, Any]] = []
        for msg in messages:
            message: Dict[str, Any] = {
                'role': msg.role,
                'content': msg.content,
            }
            if msg.name:
                message['name'] = msg.name
            if msg.function_call:
                message['function_call'] = msg.function_call
            adapted.append(message)
        return adapted

    def adapt_response(self, response: Any) -> LLMResponse:
        """Convert OpenAI response to standard format.

        Only the first entry of ``response.choices`` is used.

        Args:
            response: Chat completion response object returned by the client.

        Returns:
            LLMResponse with the message content (empty string when the
            message has no content), model name, finish reason, token usage
            (``None`` when the response carries no usage), and the function
            call if the message exposes one.
        """
        choice = response.choices[0]
        message = choice.message

        usage = None
        if response.usage:
            usage = {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens,
            }

        return LLMResponse(
            content=message.content or '',
            model=response.model,
            finish_reason=choice.finish_reason,
            usage=usage,
            function_call=getattr(message, 'function_call', None),
        )

    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Convert config to OpenAI parameters.

        Args:
            config: Configuration to translate.

        Returns:
            Keyword arguments for the chat completions call. The sampling
            parameters are always present; optional settings are added only
            when configured.
        """
        params: Dict[str, Any] = {
            'model': config.model,
            'temperature': config.temperature,
            'top_p': config.top_p,
            'frequency_penalty': config.frequency_penalty,
            'presence_penalty': config.presence_penalty,
        }

        if config.max_tokens:
            params['max_tokens'] = config.max_tokens
        if config.stop_sequences:
            params['stop'] = config.stop_sequences
        # Compare against None rather than truthiness: 0 is a valid seed and
        # must not be silently dropped.
        if config.seed is not None:
            params['seed'] = config.seed
        if config.logit_bias:
            params['logit_bias'] = config.logit_bias
        if config.user_id:
            params['user'] = config.user_id
        if config.response_format == 'json':
            params['response_format'] = {'type': 'json_object'}
        if config.functions:
            params['functions'] = config.functions
        if config.function_call:
            params['function_call'] = config.function_call

        return params

124 

125 

class OpenAIProvider(AsyncLLMProvider):
    """OpenAI LLM provider with full API support.

    Provides async access to OpenAI's chat, completion, embedding, and
    function calling APIs. Supports all GPT models including GPT-4, GPT-3.5,
    and specialized models (vision, embeddings).

    Features:
        - Full GPT-4 and GPT-3.5-turbo support
        - Streaming responses for real-time output
        - Function calling for tool use
        - JSON mode for structured outputs
        - Embeddings for semantic search
        - Custom API endpoints (e.g., Azure OpenAI)
        - Automatic retry with rate limiting
        - Cost tracking

    Example:
        ```python
        from dataknobs_llm.llm.providers import OpenAIProvider
        from dataknobs_llm.llm.base import LLMConfig, LLMMessage

        # Basic usage
        config = LLMConfig(
            provider="openai",
            model="gpt-4",
            api_key="sk-...",
            temperature=0.7
        )

        async with OpenAIProvider(config) as llm:
            # Simple question
            response = await llm.complete("Explain async/await")
            print(response.content)

            # Multi-turn conversation
            messages = [
                LLMMessage(role="system", content="You are a coding tutor"),
                LLMMessage(role="user", content="How do I use asyncio?")
            ]
            response = await llm.complete(messages)

        # JSON mode for structured output
        json_config = LLMConfig(
            provider="openai",
            model="gpt-4",
            response_format="json",
            system_prompt="Return JSON only"
        )

        llm = OpenAIProvider(json_config)
        await llm.initialize()
        response = await llm.complete(
            "List 3 Python libraries as JSON: {name, description}"
        )
        import json
        data = json.loads(response.content)

        # With Azure OpenAI
        azure_config = LLMConfig(
            provider="openai",
            model="gpt-4",
            api_base="https://your-resource.openai.azure.com/",
            api_key="azure-key"
        )

        # Function calling
        functions = [{
            "name": "search",
            "description": "Search for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }
        }]

        response = await llm.function_call(messages, functions)
        if response.function_call:
            print(f"Call: {response.function_call['name']}")
        ```

    Args:
        config: LLMConfig, dataknobs Config, or dict with provider settings
        prompt_builder: Optional AsyncPromptBuilder for prompt rendering

    Attributes:
        adapter (OpenAIAdapter): Format adapter for OpenAI API
        _client: OpenAI AsyncOpenAI client instance

    See Also:
        LLMConfig: Configuration options
        AsyncLLMProvider: Base provider interface
        OpenAIAdapter: Format conversion
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        # Normalize config first so the base class always receives an LLMConfig
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.adapter = OpenAIAdapter()

    @staticmethod
    def _prepare_chat_messages(
        messages: Union[str, List[LLMMessage]],
        system_prompt: str | None
    ) -> List[LLMMessage]:
        """Build the message list for a request without mutating the input.

        A bare string becomes a single user message. When ``system_prompt``
        is set and the conversation does not already start with a system
        message, one is prepended.

        Args:
            messages: Prompt string or list of messages from the caller.
            system_prompt: Configured system prompt, if any.

        Returns:
            A new list; the caller's list is never modified.
        """
        if isinstance(messages, str):
            prepared = [LLMMessage(role='user', content=messages)]
        else:
            # Copy so that inserting the system prompt below does not leak
            # into the caller's list across calls.
            prepared = list(messages)

        # "not prepared" guards the empty-list case, which would otherwise
        # raise IndexError on prepared[0].
        if system_prompt and (not prepared or prepared[0].role != 'system'):
            prepared.insert(0, LLMMessage(role='system', content=system_prompt))
        return prepared

    async def initialize(self) -> None:
        """Initialize OpenAI client.

        The API key is taken from the config, falling back to the
        ``OPENAI_API_KEY`` environment variable.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is available.
        """
        # Keep the try body limited to the import so only a missing package
        # is reported as such.
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai package not installed. Install with: pip install openai") from e

        api_key = self.config.api_key or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("OpenAI API key not provided")

        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close OpenAI client and mark the provider as uninitialized."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability.

        Returns:
            True if the configured model id appears in the client's model
            listing; False if it does not or if the listing fails for any
            reason (best-effort check, errors are not propagated).
        """
        try:
            # List available models
            models = await self._client.models.list()
            model_ids = [m.id for m in models.data]
            return self.config.model in model_ids
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get OpenAI model capabilities.

        Capabilities are inferred from substrings of the configured model
        name.

        Returns:
            Text generation, chat, and streaming for every model, plus
            function calling / JSON mode, vision, and embeddings when the
            model name contains the corresponding marker.
        """
        model = self.config.model
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        if 'gpt-4' in model or 'gpt-3.5' in model:
            capabilities.extend([
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.JSON_MODE
            ])

        if 'vision' in model:
            capabilities.append(ModelCapability.VISION)

        if 'embedding' in model:
            capabilities.append(ModelCapability.EMBEDDINGS)

        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        Initializes the client on first use. The caller's message list is
        not modified.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional provider-specific parameters; these take
                precedence over values derived from the config

        Returns:
            The adapted response for the first choice.
        """
        if not self._is_initialized:
            await self.initialize()

        # Get runtime config (with overrides applied if provided)
        runtime_config = self._get_runtime_config(config_overrides)

        # Normalize input and add the system prompt if configured
        prepared = self._prepare_chat_messages(messages, runtime_config.system_prompt)

        # Adapt messages and config
        adapted_messages = self.adapter.adapt_messages(prepared)
        params = self.adapter.adapt_config(runtime_config)
        params.update(kwargs)

        # Make API call
        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Initializes the client on first use. The caller's message list is
        not modified.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional provider-specific parameters; these take
                precedence over values derived from the config

        Yields:
            One LLMStreamResponse per chunk that carries text or a finish
            reason. The terminating item has ``is_final=True`` and may have
            an empty ``delta``.
        """
        if not self._is_initialized:
            await self.initialize()

        # Get runtime config (with overrides applied if provided)
        runtime_config = self._get_runtime_config(config_overrides)

        # Normalize input and add the system prompt if configured
        prepared = self._prepare_chat_messages(messages, runtime_config.system_prompt)

        # Adapt messages and config
        adapted_messages = self.adapter.adapt_messages(prepared)
        params = self.adapter.adapt_config(runtime_config)
        params['stream'] = True
        params.update(kwargs)

        # Stream API call
        stream = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        async for chunk in stream:
            # A chunk may arrive with an empty choices list (for example a
            # usage-only chunk); there is nothing to emit for it.
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            text = choice.delta.content
            finish_reason = choice.finish_reason

            # Emit on a finish reason even when there is no text, otherwise
            # consumers never observe is_final=True.
            if text or finish_reason is not None:
                yield LLMStreamResponse(
                    delta=text or '',
                    is_final=finish_reason is not None,
                    finish_reason=finish_reason
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Initializes the client on first use. Uses the configured model,
        falling back to ``text-embedding-ada-002`` when none is set.

        Args:
            texts: A single string or a list of strings to embed.
            **kwargs: Accepted for interface compatibility; not forwarded to
                the API by this implementation.

        Returns:
            A single embedding vector when ``texts`` is a string, otherwise
            a list of vectors, one per input.
        """
        if not self._is_initialized:
            await self.initialize()

        single = isinstance(texts, str)
        if single:
            texts = [texts]

        response = await self._client.embeddings.create(
            input=texts,
            model=self.config.model or 'text-embedding-ada-002'
        )

        embeddings = [e.embedding for e in response.data]
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling.

        Initializes the client on first use. The caller's message list is
        not modified.

        Args:
            messages: Conversation messages.
            functions: Function definitions to expose to the model.
            **kwargs: Additional provider-specific parameters; these take
                precedence over derived values. ``function_call`` defaults
                to ``'auto'`` when not supplied here.

        Returns:
            The adapted response for the first choice.
        """
        if not self._is_initialized:
            await self.initialize()

        # Add system prompt if configured
        prepared = self._prepare_chat_messages(messages, self.config.system_prompt)

        # Adapt messages and config
        adapted_messages = self.adapter.adapt_messages(prepared)
        params = self.adapter.adapt_config(self.config)
        params['functions'] = functions
        # Default to 'auto'; an explicit function_call kwarg overrides it in
        # the update below.
        params['function_call'] = 'auto'
        params.update(kwargs)

        # Make API call
        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)