Coverage for src/dataknobs_llm/llm/providers/openai.py: 17%
122 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
"""OpenAI LLM provider implementation.

This module provides OpenAI API integration for dataknobs-llm, supporting:
- GPT-4, GPT-3.5-turbo, and other OpenAI chat models
- Text embeddings (e.g. text-embedding-ada-002)
- Function calling / tool use
- Streaming responses
- JSON mode for structured outputs
- Vision-capable models (e.g. GPT-4V)

The OpenAIProvider uses the official OpenAI Python SDK. Common parameters are
mapped from LLMConfig, and extra keyword arguments passed to the completion
methods are forwarded to the OpenAI API.

Example:
    ```python
    from dataknobs_llm.llm.providers import OpenAIProvider
    from dataknobs_llm.llm.base import LLMConfig

    # Create provider
    config = LLMConfig(
        provider="openai",
        model="gpt-4",
        api_key="sk-...",  # or set OPENAI_API_KEY env var
        temperature=0.7,
        max_tokens=500
    )

    async with OpenAIProvider(config) as llm:
        # Simple completion
        response = await llm.complete("What is Python?")
        print(response.content)

        # Streaming
        async for chunk in llm.stream_complete("Tell a story"):
            print(chunk.delta, end="", flush=True)

    # Embeddings: embed() uses config.model, so configure an embedding model
    embed_config = LLMConfig(provider="openai", model="text-embedding-ada-002")
    async with OpenAIProvider(embed_config) as embedder:
        embedding = await embedder.embed("sample text")
        print(f"Dimensions: {len(embedding)}")
    ```

See Also:
    - OpenAI API Documentation: https://platform.openai.com/docs
    - openai Python package: https://github.com/openai/openai-python
"""
47import os
48from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator
50from ..base import (
51 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse,
52 AsyncLLMProvider, ModelCapability,
53 LLMAdapter, normalize_llm_config
54)
55from dataknobs_llm.prompts import AsyncPromptBuilder
57if TYPE_CHECKING:
58 from dataknobs_config.config import Config
class OpenAIAdapter(LLMAdapter):
    """Adapter between dataknobs LLM types and the OpenAI Chat Completions API."""

    def adapt_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert ``LLMMessage`` objects to OpenAI chat message dicts.

        Args:
            messages: Messages to convert.

        Returns:
            One dict per message with ``role`` and ``content``. ``name`` and
            ``function_call`` are added only when they are set on the message.
        """
        adapted = []
        for msg in messages:
            message: Dict[str, Any] = {
                'role': msg.role,
                'content': msg.content
            }
            if msg.name:
                message['name'] = msg.name
            if msg.function_call:
                message['function_call'] = msg.function_call
            adapted.append(message)
        return adapted

    def adapt_response(self, response: Any) -> LLMResponse:
        """Convert an OpenAI chat completion response to an ``LLMResponse``.

        Only the first choice is used.

        Args:
            response: Chat completion object returned by the OpenAI SDK.

        Returns:
            ``LLMResponse`` with the first choice's content (``''`` when the
            content is ``None``), model, finish reason, and token usage
            (``None`` when the response carries no usage). ``function_call``
            is a plain ``{'name', 'arguments'}`` dict, or ``None``.
        """
        choice = response.choices[0]
        message = choice.message

        function_call = getattr(message, 'function_call', None)
        if function_call is not None and not isinstance(function_call, dict):
            # The SDK returns a pydantic ``FunctionCall`` object, which is not
            # subscriptable. Expose a plain dict so callers can use
            # ``response.function_call['name']``. The same dict can also be
            # fed back through ``adapt_messages`` unchanged.
            function_call = {
                'name': function_call.name,
                'arguments': function_call.arguments,
            }

        usage = response.usage
        return LLMResponse(
            content=message.content or '',
            model=response.model,
            finish_reason=choice.finish_reason,
            usage={
                'prompt_tokens': usage.prompt_tokens,
                'completion_tokens': usage.completion_tokens,
                'total_tokens': usage.total_tokens
            } if usage else None,
            function_call=function_call
        )

    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Convert an ``LLMConfig`` into keyword arguments for the OpenAI API.

        Sampling parameters are always included. Optional parameters are
        included only when set.

        Args:
            config: Provider configuration.

        Returns:
            Dict of parameters for ``chat.completions.create``.
        """
        params: Dict[str, Any] = {
            'model': config.model,
            'temperature': config.temperature,
            'top_p': config.top_p,
            'frequency_penalty': config.frequency_penalty,
            'presence_penalty': config.presence_penalty,
        }

        if config.max_tokens:
            params['max_tokens'] = config.max_tokens
        if config.stop_sequences:
            params['stop'] = config.stop_sequences
        # ``0`` is a valid seed, so test for ``None`` rather than truthiness.
        if config.seed is not None:
            params['seed'] = config.seed
        if config.logit_bias:
            params['logit_bias'] = config.logit_bias
        if config.user_id:
            params['user'] = config.user_id
        if config.response_format == 'json':
            params['response_format'] = {'type': 'json_object'}
        if config.functions:
            params['functions'] = config.functions
        if config.function_call:
            params['function_call'] = config.function_call

        return params
class OpenAIProvider(AsyncLLMProvider):
    """OpenAI LLM provider built on the official ``openai`` async client.

    Provides async access to OpenAI's chat completion, embedding, and
    function calling APIs. Works with GPT chat models (GPT-4, GPT-3.5, ...)
    and with embedding models.

    Features:
        - GPT-4 and GPT-3.5-turbo chat completions
        - Streaming responses for real-time output
        - Function calling for tool use
        - JSON mode for structured outputs
        - Embeddings for semantic search
        - Custom API endpoints via ``api_base``. The client created is
          ``openai.AsyncOpenAI``, so classic Azure OpenAI deployments that
          require ``AsyncAzureOpenAI`` are not covered.
        - Token usage reporting on responses (``LLMResponse.usage``)

    Example:
        ```python
        from dataknobs_llm.llm.providers import OpenAIProvider
        from dataknobs_llm.llm.base import LLMConfig, LLMMessage

        # Basic usage
        config = LLMConfig(
            provider="openai",
            model="gpt-4",
            api_key="sk-...",
            temperature=0.7
        )

        async with OpenAIProvider(config) as llm:
            # Simple question
            response = await llm.complete("Explain async/await")
            print(response.content)

            # Multi-turn conversation
            messages = [
                LLMMessage(role="system", content="You are a coding tutor"),
                LLMMessage(role="user", content="How do I use asyncio?")
            ]
            response = await llm.complete(messages)

        # JSON mode for structured output
        json_config = LLMConfig(
            provider="openai",
            model="gpt-4",
            response_format="json",
            system_prompt="Return JSON only"
        )

        llm = OpenAIProvider(json_config)
        await llm.initialize()
        response = await llm.complete(
            "List 3 Python libraries as JSON: {name, description}"
        )
        import json
        data = json.loads(response.content)

        # Custom endpoint (OpenAI-compatible base URL)
        custom_config = LLMConfig(
            provider="openai",
            model="gpt-4",
            api_base="https://your-endpoint.example.com/v1",
            api_key="your-key"
        )

        # Function calling
        functions = [{
            "name": "search",
            "description": "Search for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }
        }]

        response = await llm.function_call(messages, functions)
        if response.function_call:
            print(f"Call: {response.function_call['name']}")
        ```

    Args:
        config: LLMConfig, dataknobs Config, or dict with provider settings
        prompt_builder: Optional AsyncPromptBuilder for prompt rendering

    Attributes:
        adapter (OpenAIAdapter): Format adapter for OpenAI API
        _client: ``openai.AsyncOpenAI`` client, created by ``initialize()``

    See Also:
        LLMConfig: Configuration options
        AsyncLLMProvider: Base provider interface
        OpenAIAdapter: Format conversion
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        # Normalize config first so the base class always receives an LLMConfig.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.adapter = OpenAIAdapter()

    async def initialize(self) -> None:
        """Create the ``openai.AsyncOpenAI`` client.

        The API key is taken from ``config.api_key``. If that is unset, the
        ``OPENAI_API_KEY`` environment variable is used.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is available.
        """
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai package not installed. Install with: pip install openai") from e

        api_key = self.config.api_key or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("OpenAI API key not provided")

        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the OpenAI client (if created) and mark the provider uninitialized."""
        if self._client:
            await self._client.close()  # type: ignore[unreachable]
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check whether ``config.model`` is listed by the OpenAI models endpoint.

        Initializes the client first if needed. This is a best-effort check:
        any failure (missing package or key, network or API error) returns
        ``False`` instead of raising.

        Returns:
            ``True`` if the configured model id is in the returned model list.
        """
        try:
            if not self._is_initialized:
                await self.initialize()
            models = await self._client.models.list()
            return self.config.model in {m.id for m in models.data}
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Infer capabilities from substrings of the configured model name."""
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        if 'gpt-4' in self.config.model or 'gpt-3.5' in self.config.model:
            capabilities.extend([
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.JSON_MODE
            ])

        if 'vision' in self.config.model:
            capabilities.append(ModelCapability.VISION)

        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)

        return capabilities

    def _build_request_messages(
        self,
        messages: Union[str, List[LLMMessage]]
    ) -> List[Dict[str, Any]]:
        """Normalize input, prepend the configured system prompt, and adapt to OpenAI format.

        A new list is always built, so the caller's ``messages`` list is never
        mutated. An empty list is accepted; it receives only the system
        prompt, if one is configured.
        """
        if isinstance(messages, str):
            message_list = [LLMMessage(role='user', content=messages)]
        else:
            message_list = list(messages)

        system_prompt = self.config.system_prompt
        if system_prompt and (not message_list or message_list[0].role != 'system'):
            message_list.insert(0, LLMMessage(role='system', content=system_prompt))

        return self.adapter.adapt_messages(message_list)

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate a chat completion.

        Args:
            messages: A single user prompt string or a list of messages.
            **kwargs: Extra OpenAI API parameters. These override the values
                derived from the config.

        Returns:
            The adapted ``LLMResponse`` for the first choice.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self._build_request_messages(messages)
        params = self.adapter.adapt_config(self.config)
        params.update(kwargs)

        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate a streaming chat completion.

        Yields one ``LLMStreamResponse`` per chunk that carries content. The
        chunk that carries the finish reason is also yielded (with an empty
        ``delta`` if it has no content), so consumers see ``is_final=True``.
        Chunks without choices (e.g. usage-only chunks) are skipped.

        Args:
            messages: A single user prompt string or a list of messages.
            **kwargs: Extra OpenAI API parameters. These override the values
                derived from the config.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self._build_request_messages(messages)
        params = self.adapter.adapt_config(self.config)
        params['stream'] = True
        params.update(kwargs)

        stream = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        async for chunk in stream:
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            content = choice.delta.content
            finish_reason = choice.finish_reason
            # The terminating chunk typically has no content but carries the
            # finish_reason. Yield it so the stream signals completion.
            if content or finish_reason is not None:
                yield LLMStreamResponse(
                    delta=content or '',
                    is_final=finish_reason is not None,
                    finish_reason=finish_reason
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: One text or a list of texts.
            **kwargs: ``model`` overrides the embedding model for this call.
                Other keyword arguments are ignored.

        Returns:
            A single vector for a string input, or a list of vectors for a
            list input. An empty list returns ``[]`` without an API call.
        """
        if not self._is_initialized:
            await self.initialize()

        single = isinstance(texts, str)
        if single:
            texts = [texts]
        elif not texts:
            return []

        response = await self._client.embeddings.create(
            input=texts,
            model=kwargs.get('model') or self.config.model or 'text-embedding-ada-002'
        )

        embeddings = [e.embedding for e in response.data]
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Run a chat completion with function definitions available to the model.

        Args:
            messages: Conversation messages. The list is not mutated.
            functions: OpenAI function definitions.
            **kwargs: Extra OpenAI API parameters. ``function_call`` defaults
                to ``'auto'`` and may be overridden here.

        Returns:
            The adapted ``LLMResponse`` for the first choice.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self._build_request_messages(messages)
        params = self.adapter.adapt_config(self.config)
        params['functions'] = functions
        # Default to letting the model decide. An explicit ``function_call``
        # in kwargs replaces this via the update below.
        params['function_call'] = 'auto'
        params.update(kwargs)

        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)