Coverage for src/dataknobs_llm/llm/providers/openai.py: 17%

122 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""OpenAI LLM provider implementation. 

2 

3This module provides OpenAI API integration for dataknobs-llm, supporting: 

4- GPT-4, GPT-3.5-turbo, and other OpenAI chat models 

5- Text embeddings (ada-002, etc.) 

6- Function calling / tool use 

7- Streaming responses 

8- JSON mode for structured outputs 

9- Vision models (GPT-4V) 

10 

11The OpenAIProvider uses the official OpenAI Python SDK and supports all 

12standard OpenAI API parameters. 

13 

14Example: 

15 ```python 

16 from dataknobs_llm.llm.providers import OpenAIProvider 

17 from dataknobs_llm.llm.base import LLMConfig 

18 

19 # Create provider 

20 config = LLMConfig( 

21 provider="openai", 

22 model="gpt-4", 

23 api_key="sk-...", # or set OPENAI_API_KEY env var 

24 temperature=0.7, 

25 max_tokens=500 

26 ) 

27 

28 async with OpenAIProvider(config) as llm: 

29 # Simple completion 

30 response = await llm.complete("What is Python?") 

31 print(response.content) 

32 

33 # Streaming 

34 async for chunk in llm.stream_complete("Tell a story"): 

35 print(chunk.delta, end="", flush=True) 

36 

37 # Embeddings 

38 embedding = await llm.embed("sample text") 

39 print(f"Dimensions: {len(embedding)}") 

40 ``` 

41 

42See Also: 

43 - OpenAI API Documentation: https://platform.openai.com/docs 

44 - openai Python package: https://github.com/openai/openai-python 

45""" 

46 

47import os 

48from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator 

49 

50from ..base import ( 

51 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

52 AsyncLLMProvider, ModelCapability, 

53 LLMAdapter, normalize_llm_config 

54) 

55from dataknobs_llm.prompts import AsyncPromptBuilder 

56 

57if TYPE_CHECKING: 

58 from dataknobs_config.config import Config 

59 

60 

class OpenAIAdapter(LLMAdapter):
    """Adapter between dataknobs-llm types and the OpenAI chat API format.

    Converts outgoing messages and configuration into the keyword arguments
    passed to the chat completions endpoint, and converts the returned
    completion object into an ``LLMResponse``.
    """

    def adapt_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI chat-message dicts.

        Args:
            messages: Messages to convert.

        Returns:
            One dict per message holding ``role`` and ``content``; ``name``
            and ``function_call`` are included only when set (truthy) on
            the message.
        """
        adapted: List[Dict[str, Any]] = []
        for msg in messages:
            message: Dict[str, Any] = {
                'role': msg.role,
                'content': msg.content,
            }
            # Optional fields are omitted entirely rather than sent as null.
            if msg.name:
                message['name'] = msg.name
            if msg.function_call:
                message['function_call'] = msg.function_call
            adapted.append(message)
        return adapted

    def adapt_response(self, response: Any) -> LLMResponse:
        """Convert an OpenAI chat completion into an ``LLMResponse``.

        Only the first choice is used.

        Args:
            response: Completion object exposing ``choices``, ``model`` and
                ``usage`` attributes.

        Returns:
            LLMResponse with the message content ('' when the content is
            empty or None), model, finish reason, token usage (None when the
            response carries no usage) and the message's ``function_call``
            (None when absent).

        Raises:
            IndexError: If ``response.choices`` is empty.
        """
        choice = response.choices[0]
        message = choice.message

        usage = None
        if response.usage:
            usage = {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens,
            }

        # NOTE(review): function_call is passed through untouched; it may be
        # an SDK object rather than a plain dict -- confirm what consumers
        # of LLMResponse.function_call expect.
        return LLMResponse(
            content=message.content or '',
            model=response.model,
            finish_reason=choice.finish_reason,
            usage=usage,
            function_call=getattr(message, 'function_call', None),
        )

    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Convert an ``LLMConfig`` into chat-completion keyword arguments.

        The model and sampling parameters are always included; every other
        parameter is included only when it is set on the config.

        Args:
            config: Configuration to translate.

        Returns:
            Dict of keyword arguments for the chat completions call.
        """
        params: Dict[str, Any] = {
            'model': config.model,
            'temperature': config.temperature,
            'top_p': config.top_p,
            'frequency_penalty': config.frequency_penalty,
            'presence_penalty': config.presence_penalty,
        }

        if config.max_tokens:
            params['max_tokens'] = config.max_tokens
        if config.stop_sequences:
            params['stop'] = config.stop_sequences
        # 0 is a valid seed, so test against None rather than truthiness.
        if config.seed is not None:
            params['seed'] = config.seed
        if config.logit_bias:
            params['logit_bias'] = config.logit_bias
        if config.user_id:
            params['user'] = config.user_id
        if config.response_format == 'json':
            params['response_format'] = {'type': 'json_object'}
        if config.functions:
            params['functions'] = config.functions
        if config.function_call:
            params['function_call'] = config.function_call

        return params

124 

125 

class OpenAIProvider(AsyncLLMProvider):
    """OpenAI LLM provider with full API support.

    Provides async access to OpenAI's chat, completion, embedding, and
    function calling APIs. Supports all GPT models including GPT-4, GPT-3.5,
    and specialized models (vision, embeddings).

    Features:
        - Full GPT-4 and GPT-3.5-turbo support
        - Streaming responses for real-time output
        - Function calling for tool use
        - JSON mode for structured outputs
        - Embeddings for semantic search
        - Custom API endpoints (e.g., Azure OpenAI)
        - Automatic retry with rate limiting
        - Cost tracking

    Example:
        ```python
        from dataknobs_llm.llm.providers import OpenAIProvider
        from dataknobs_llm.llm.base import LLMConfig, LLMMessage

        # Basic usage
        config = LLMConfig(
            provider="openai",
            model="gpt-4",
            api_key="sk-...",
            temperature=0.7
        )

        async with OpenAIProvider(config) as llm:
            # Simple question
            response = await llm.complete("Explain async/await")
            print(response.content)

            # Multi-turn conversation
            messages = [
                LLMMessage(role="system", content="You are a coding tutor"),
                LLMMessage(role="user", content="How do I use asyncio?")
            ]
            response = await llm.complete(messages)

        # JSON mode for structured output
        json_config = LLMConfig(
            provider="openai",
            model="gpt-4",
            response_format="json",
            system_prompt="Return JSON only"
        )

        llm = OpenAIProvider(json_config)
        await llm.initialize()
        response = await llm.complete(
            "List 3 Python libraries as JSON: {name, description}"
        )
        import json
        data = json.loads(response.content)

        # With Azure OpenAI
        azure_config = LLMConfig(
            provider="openai",
            model="gpt-4",
            api_base="https://your-resource.openai.azure.com/",
            api_key="azure-key"
        )

        # Function calling
        functions = [{
            "name": "search",
            "description": "Search for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }
        }]

        response = await llm.function_call(messages, functions)
        if response.function_call:
            print(f"Call: {response.function_call['name']}")
        ```

    Args:
        config: LLMConfig, dataknobs Config, or dict with provider settings
        prompt_builder: Optional AsyncPromptBuilder for prompt rendering

    Attributes:
        adapter (OpenAIAdapter): Format adapter for OpenAI API
        _client: OpenAI AsyncOpenAI client instance

    See Also:
        LLMConfig: Configuration options
        AsyncLLMProvider: Base provider interface
        OpenAIAdapter: Format conversion
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        # Normalize config first so the base class always receives an LLMConfig.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.adapter = OpenAIAdapter()

    async def initialize(self) -> None:
        """Create the AsyncOpenAI client.

        The API key is taken from the config, falling back to the
        ``OPENAI_API_KEY`` environment variable.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is available.
        """
        # Keep the try body minimal so that only a failed import is reported
        # as a missing package.
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai package not installed. Install with: pip install openai") from e

        api_key = self.config.api_key or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("OpenAI API key not provided")

        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the OpenAI client, if any, and mark the provider uninitialized."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
            # Drop the closed client so it cannot be reused by accident; the
            # next API call re-initializes lazily.
            self._client = None
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check whether the configured model is listed by the API.

        Returns:
            True if ``config.model`` appears in the model list; False if it
            does not or if initialization or the listing request fails.
        """
        try:
            # Lazily initialize like the other API methods; kept inside the
            # try so this check stays best-effort and never raises.
            if not self._is_initialized:
                await self.initialize()
            models = await self._client.models.list()
            model_ids = [m.id for m in models.data]
            return self.config.model in model_ids
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get model capabilities, inferred from substrings of the model name.

        Returns:
            Text generation, chat and streaming for every model; function
            calling and JSON mode for names containing 'gpt-4' or 'gpt-3.5';
            vision for names containing 'vision'; embeddings for names
            containing 'embedding'.
        """
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        if 'gpt-4' in self.config.model or 'gpt-3.5' in self.config.model:
            capabilities.extend([
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.JSON_MODE
            ])

        if 'vision' in self.config.model:
            capabilities.append(ModelCapability.VISION)

        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)

        return capabilities

    def _prepare_messages(
        self,
        messages: Union[str, List[LLMMessage]]
    ) -> List[LLMMessage]:
        """Return a new message list with the configured system prompt applied.

        A plain string becomes a single user message. The caller's list is
        copied, never mutated. The system prompt is prepended only when one
        is configured and the first message is not already a system message.
        """
        if isinstance(messages, str):
            prepared = [LLMMessage(role='user', content=messages)]
        else:
            prepared = list(messages)

        system_prompt = self.config.system_prompt
        if system_prompt and (not prepared or prepared[0].role != 'system'):
            prepared.insert(0, LLMMessage(role='system', content=system_prompt))
        return prepared

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate a chat completion.

        Args:
            messages: A prompt string (sent as one user message) or a list
                of messages. The list is not modified.
            **kwargs: Extra request parameters; they override values derived
                from the config.

        Returns:
            The adapted response for the first choice.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params.update(kwargs)

        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate a streaming chat completion.

        Args:
            messages: A prompt string (sent as one user message) or a list
                of messages. The list is not modified.
            **kwargs: Extra request parameters; they override values derived
                from the config (including ``stream``).

        Yields:
            One LLMStreamResponse per chunk that carries content or a finish
            reason. The chunk carrying the finish reason is yielded with
            ``is_final=True`` even when its delta is empty.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['stream'] = True
        params.update(kwargs)

        stream = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        async for chunk in stream:
            # Some chunks (e.g. usage-only chunks) carry no choices at all.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            content = choice.delta.content
            finish_reason = choice.finish_reason
            # The terminal chunk usually has a finish reason but no content;
            # it must still be emitted so consumers observe is_final=True.
            if content or finish_reason is not None:
                yield LLMStreamResponse(
                    delta=content or '',
                    is_final=finish_reason is not None,
                    finish_reason=finish_reason
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: A single string or a list of strings to embed.
            **kwargs: Extra request parameters forwarded to the embeddings
                endpoint; ``model`` overrides the configured model.

        Returns:
            A single embedding vector when ``texts`` is a string, otherwise
            a list of vectors, one per input.
        """
        if not self._is_initialized:
            await self.initialize()

        single = isinstance(texts, str)
        if single:
            texts = [texts]

        params: Dict[str, Any] = {
            'model': self.config.model or 'text-embedding-ada-002'
        }
        params.update(kwargs)

        response = await self._client.embeddings.create(
            input=texts,
            **params
        )

        embeddings = [e.embedding for e in response.data]
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Run a chat completion with function calling enabled.

        Args:
            messages: Conversation messages. The list is not modified.
            functions: Function definitions offered to the model; these
                replace any functions set on the config.
            **kwargs: Extra request parameters; they override values derived
                from the config. ``function_call`` selects the calling mode.

        Returns:
            The adapted response for the first choice.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['functions'] = functions
        # Precedence: explicit kwarg > config.function_call > 'auto'.
        params.setdefault('function_call', 'auto')
        params.update(kwargs)

        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)