Coverage for src/dataknobs_llm/llm/providers.py: 13%
662 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 16:04 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-31 16:04 -0600
1"""LLM provider implementations.
3This module provides implementations for various LLM providers.
4Supports both direct instantiation and dataknobs Config-based factory pattern.
5"""
7import os
8import json
9import hashlib
10from typing import Any, Dict, List, Union, AsyncIterator, Type, Optional
12from .base import (
13 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse,
14 AsyncLLMProvider, SyncLLMProvider, ModelCapability,
15 LLMAdapter, normalize_llm_config
16)
18# Import prompt builder types - clean one-way dependency (llm depends on prompts)
19from dataknobs_llm.prompts import AsyncPromptBuilder
class SyncProviderAdapter:
    """Sync adapter for async LLM providers.

    Every blocking method creates the wrapped provider's coroutine and drives
    it to completion with ``run_until_complete`` on the current thread's event
    loop.  These methods therefore cannot be called from code that is already
    executing inside that running loop.
    """

    def __init__(self, async_provider: AsyncLLMProvider):
        """Initialize with async provider.

        Args:
            async_provider: The async provider to wrap.
        """
        self.async_provider = async_provider

    @staticmethod
    def _get_loop():
        """Return an open event loop for the current thread.

        The thread's current loop is reused when there is one.  When none is
        available, or the current loop has already been closed, a new loop is
        created and installed as the current loop.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        # run_until_complete() raises on a closed loop, so replace it.
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    def _run(self, awaitable):
        """Run ``awaitable`` to completion and return its result."""
        return self._get_loop().run_until_complete(awaitable)

    def initialize(self) -> None:
        """Initialize the provider synchronously."""
        return self._run(self.async_provider.initialize())

    def close(self) -> None:
        """Close the provider synchronously."""
        return self._run(self.async_provider.close())

    def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion synchronously.

        Args:
            messages: Prompt string or message list, passed through unchanged.
            **kwargs: Forwarded to the provider's ``complete``.

        Returns:
            Whatever the wrapped provider's ``complete`` coroutine returns.
        """
        return self._run(self.async_provider.complete(messages, **kwargs))

    def stream(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ):
        """Stream completion synchronously.

        This is a generator: nothing runs until the first item is requested.
        Each item is obtained by running one step of the provider's
        ``stream_complete`` async iterator on the event loop.

        Args:
            messages: Prompt string or message list, passed through unchanged.
            **kwargs: Forwarded to the provider's ``stream_complete``.

        Yields:
            The chunks produced by the provider, in order.
        """
        loop = self._get_loop()

        # Wrapping in a local async generator guarantees aclose() exists,
        # whatever kind of async iterator the provider returns.
        async def _stream():
            async for chunk in self.async_provider.stream_complete(messages, **kwargs):
                yield chunk

        async_gen = _stream()
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Also reached when the consumer abandons the generator early.
            loop.run_until_complete(async_gen.aclose())

    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings synchronously.

        Args:
            texts: A single text or a list of texts.
            **kwargs: Forwarded to the provider's ``embed``.
        """
        return self._run(self.async_provider.embed(texts, **kwargs))

    def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Make function call synchronously.

        Args:
            messages: Conversation messages.
            functions: Function definitions offered to the model.
            **kwargs: Forwarded to the provider's ``function_call``.
        """
        return self._run(
            self.async_provider.function_call(messages, functions, **kwargs)
        )

    def validate_model(self) -> bool:
        """Validate model synchronously."""
        return self._run(self.async_provider.validate_model())  # type: ignore

    def get_capabilities(self) -> List[ModelCapability]:
        """Get capabilities synchronously."""
        return self.async_provider.get_capabilities()

    @property
    def is_initialized(self) -> bool:
        """Check if provider is initialized."""
        return self.async_provider.is_initialized
class OpenAIAdapter(LLMAdapter):
    """Adapter for OpenAI API format."""

    def adapt_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI format.

        Args:
            messages: Messages to convert.

        Returns:
            One dict per message with ``role`` and ``content``; ``name`` and
            ``function_call`` are included only when set on the message.
        """
        adapted = []
        for msg in messages:
            message: Dict[str, Any] = {
                'role': msg.role,
                'content': msg.content,
            }
            if msg.name:
                message['name'] = msg.name
            if msg.function_call:
                message['function_call'] = msg.function_call
            adapted.append(message)
        return adapted

    def adapt_response(self, response: Any) -> LLMResponse:
        """Convert OpenAI response to standard format.

        Only the first choice of the response is used.

        Args:
            response: Chat-completion response object.

        Returns:
            An ``LLMResponse``; ``usage`` is ``None`` when the response
            carries no usage block.
        """
        choice = response.choices[0]
        message = choice.message

        usage = None
        if response.usage:
            usage = {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens,
            }

        return LLMResponse(
            content=message.content or '',
            model=response.model,
            finish_reason=choice.finish_reason,
            usage=usage,
            function_call=getattr(message, 'function_call', None)
        )

    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Convert config to OpenAI parameters.

        Args:
            config: Provider configuration.

        Returns:
            Keyword arguments for the chat-completions call.  Optional
            settings are included only when configured.
        """
        params: Dict[str, Any] = {
            'model': config.model,
            'temperature': config.temperature,
            'top_p': config.top_p,
            'frequency_penalty': config.frequency_penalty,
            'presence_penalty': config.presence_penalty,
        }

        if config.max_tokens:
            params['max_tokens'] = config.max_tokens
        if config.stop_sequences:
            params['stop'] = config.stop_sequences
        # 0 is a valid seed, so test against None rather than truthiness.
        if config.seed is not None:
            params['seed'] = config.seed
        if config.logit_bias:
            params['logit_bias'] = config.logit_bias
        if config.user_id:
            params['user'] = config.user_id
        if config.response_format == 'json':
            params['response_format'] = {'type': 'json_object'}
        if config.functions:
            params['functions'] = config.functions
        if config.function_call:
            params['function_call'] = config.function_call

        return params
class OpenAIProvider(AsyncLLMProvider):
    """OpenAI LLM provider."""

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: Optional[AsyncPromptBuilder] = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration in any form accepted by
                ``normalize_llm_config``.
            prompt_builder: Optional prompt builder handed to the base class.
        """
        # Normalize config first
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.adapter = OpenAIAdapter()

    def _prepare_messages(
        self,
        messages: Union[str, List[LLMMessage]]
    ) -> List[LLMMessage]:
        """Return a new message list ready to be adapted for the API.

        A plain string becomes a single user message.  The configured system
        prompt is prepended unless the conversation already starts with a
        system message.  The caller's list is never modified.
        """
        if isinstance(messages, str):
            prepared = [LLMMessage(role='user', content=messages)]
        else:
            prepared = list(messages)

        if self.config.system_prompt and (not prepared or prepared[0].role != 'system'):
            prepared.insert(0, LLMMessage(role='system', content=self.config.system_prompt))
        return prepared

    async def initialize(self) -> None:
        """Initialize OpenAI client.

        The API key is taken from the config, falling back to the
        ``OPENAI_API_KEY`` environment variable.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is available.
        """
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai package not installed. Install with: pip install openai") from e

        api_key = self.config.api_key or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("OpenAI API key not provided")

        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close OpenAI client."""
        client = getattr(self, '_client', None)
        if client:
            await client.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability.

        Returns:
            True when the configured model id appears in the model listing;
            False otherwise, including when the listing call fails.
        """
        try:
            # List available models
            models = await self._client.models.list()
            return self.config.model in [m.id for m in models.data]
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get OpenAI model capabilities.

        Capabilities are inferred from substrings of the configured model
        name.
        """
        model = self.config.model
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        if 'gpt-4' in model or 'gpt-3.5' in model:
            capabilities.extend([
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.JSON_MODE
            ])

        if 'vision' in model:
            capabilities.append(ModelCapability.VISION)

        if 'embedding' in model:
            capabilities.append(ModelCapability.EMBEDDINGS)

        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        Args:
            messages: Prompt string or message list.
            **kwargs: Extra request parameters; they override the values
                derived from the config.

        Returns:
            The adapted response.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params.update(kwargs)

        # Make API call
        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Args:
            messages: Prompt string or message list.
            **kwargs: Extra request parameters; they override the values
                derived from the config.

        Yields:
            One ``LLMStreamResponse`` per chunk that carries text or a finish
            reason.  The chunk that carries the finish reason is yielded with
            ``is_final=True`` even when its text is empty.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['stream'] = True
        params.update(kwargs)

        # Stream API call
        stream = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        async for chunk in stream:
            # Some chunks carry no choices at all; nothing to report for them.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            text = getattr(choice.delta, 'content', None)
            finish_reason = choice.finish_reason
            if text or finish_reason is not None:
                yield LLMStreamResponse(
                    delta=text or '',
                    is_final=finish_reason is not None,
                    finish_reason=finish_reason
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: A single text or a list of texts.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            A single vector for a string input, otherwise a list of vectors.
        """
        if not self._is_initialized:
            await self.initialize()

        single = isinstance(texts, str)
        inputs = [texts] if single else texts

        response = await self._client.embeddings.create(
            input=inputs,
            model=self.config.model or 'text-embedding-ada-002'
        )

        embeddings = [e.embedding for e in response.data]
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling.

        Args:
            messages: Conversation messages.
            functions: Function definitions offered to the model.
            **kwargs: Extra request parameters; ``function_call`` defaults to
                ``'auto'`` when not supplied here.

        Returns:
            The adapted response.
        """
        if not self._is_initialized:
            await self.initialize()

        adapted_messages = self.adapter.adapt_messages(self._prepare_messages(messages))
        params = self.adapter.adapt_config(self.config)
        params['functions'] = functions
        params['function_call'] = 'auto'
        # Caller-supplied values (including function_call) take precedence.
        params.update(kwargs)

        # Make API call
        response = await self._client.chat.completions.create(
            messages=adapted_messages,
            **params
        )

        return self.adapter.adapt_response(response)
class AnthropicProvider(AsyncLLMProvider):
    """Anthropic Claude LLM provider.

    Supports latest Anthropic features including:
    - Native tools API (Claude 3+)
    - Vision capabilities (Claude 3+)
    - Streaming responses
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: Optional[AsyncPromptBuilder] = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration in any form accepted by
                ``normalize_llm_config``.
            prompt_builder: Optional prompt builder handed to the base class.
        """
        # Normalize config first
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)

    @staticmethod
    def _without_unset(params: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``params`` minus entries whose value is ``None``.

        Unset optional settings are left out of the request entirely instead
        of being sent as explicit nulls.
        """
        return {key: value for key, value in params.items() if value is not None}

    @staticmethod
    def _usage_from(response: Any) -> Optional[Dict[str, Any]]:
        """Build the usage dict from a response, or ``None`` if it has none."""
        usage = getattr(response, 'usage', None)
        if usage is None:
            return None
        return {
            'prompt_tokens': usage.input_tokens,
            'completion_tokens': usage.output_tokens,
            'total_tokens': usage.input_tokens + usage.output_tokens
        }

    async def initialize(self) -> None:
        """Initialize Anthropic client.

        The API key is taken from the config, falling back to the
        ``ANTHROPIC_API_KEY`` environment variable.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If no API key is available.
        """
        try:
            import anthropic
        except ImportError as e:
            raise ImportError("anthropic package not installed. Install with: pip install anthropic") from e

        api_key = self.config.api_key or os.environ.get('ANTHROPIC_API_KEY')
        if not api_key:
            raise ValueError("Anthropic API key not provided")

        self._client = anthropic.AsyncAnthropic(
            api_key=api_key,
            base_url=self.config.api_base,
            timeout=self.config.timeout
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close Anthropic client."""
        client = getattr(self, '_client', None)
        if client:
            await client.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability.

        No API call is made: the configured model name is checked against a
        fixed list of known name fragments.
        """
        valid_models = (
            'claude-3-opus', 'claude-3-sonnet', 'claude-3-haiku',
            'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
        )
        return any(m in self.config.model for m in valid_models)

    def get_capabilities(self) -> List[ModelCapability]:
        """Get Anthropic model capabilities.

        Capabilities are inferred from substrings of the configured model
        name.
        """
        model = self.config.model
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING,
            ModelCapability.CODE
        ]

        # Claude 3+ models support vision and tools
        if 'claude-3' in model or 'claude-sonnet' in model or 'claude-opus' in model:
            capabilities.extend([
                ModelCapability.VISION,
                ModelCapability.FUNCTION_CALLING
            ])

        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        A message list is flattened into a single prompt with
        ``_build_prompt`` and sent as one user message.

        Args:
            messages: Prompt string or message list.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            Response whose content is the concatenation of the text blocks
            in the reply.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Anthropic format
        prompt = messages if isinstance(messages, str) else self._build_prompt(messages)

        request = self._without_unset({
            'model': self.config.model,
            'messages': [{"role": "user", "content": prompt}],
            'max_tokens': self.config.max_tokens or 1024,
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
            'stop_sequences': self.config.stop_sequences
        })

        # Make API call
        response = await self._client.messages.create(**request)

        # Join text blocks instead of assuming the first block is text.
        content = ''.join(
            block.text
            for block in response.content
            if getattr(block, 'type', None) == 'text'
        )

        return LLMResponse(
            content=content,
            model=response.model,
            finish_reason=response.stop_reason,
            usage=self._usage_from(response)
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Args:
            messages: Prompt string or message list.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            One non-final ``LLMStreamResponse`` per text delta, followed by a
            final one with an empty delta and the stop reason.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Anthropic format
        prompt = messages if isinstance(messages, str) else self._build_prompt(messages)

        request = self._without_unset({
            'model': self.config.model,
            'messages': [{"role": "user", "content": prompt}],
            'max_tokens': self.config.max_tokens or 1024,
            'temperature': self.config.temperature
        })

        # Stream API call
        async with self._client.messages.stream(**request) as stream:
            async for chunk in stream:
                if chunk.type != 'content_block_delta':
                    continue
                # Skip deltas that carry no text attribute.
                text = getattr(chunk.delta, 'text', None)
                if text is None:
                    continue
                yield LLMStreamResponse(
                    delta=text,
                    is_final=False
                )

            # Final message
            message = await stream.get_final_message()
            yield LLMStreamResponse(
                delta='',
                is_final=True,
                finish_reason=message.stop_reason
            )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Anthropic doesn't provide embeddings.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Anthropic doesn't provide embedding models")

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling with native Anthropic tools API (Claude 3+).

        If the native tools request fails for any reason, the call is retried
        through ``complete`` with a system prompt that describes the functions
        and asks for a ``FUNCTION_CALL:`` JSON reply.

        Args:
            messages: Conversation messages.  System messages are merged into
                the request's ``system`` parameter.
            functions: Function definitions (``name``, ``description``,
                ``parameters``).
            **kwargs: Forwarded to ``complete`` on the fallback path only.

        Returns:
            Response whose ``function_call`` holds ``name`` and ``arguments``
            when the model requested a call.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Anthropic message format
        anthropic_messages = []
        system_content = self.config.system_prompt or ''

        for msg in messages:
            if msg.role == 'system':
                # Anthropic uses system parameter, not system messages
                system_content = msg.content if not system_content else f"{system_content}\n\n{msg.content}"
            else:
                anthropic_messages.append({
                    'role': msg.role,
                    'content': msg.content
                })

        # Convert functions to Anthropic tools format
        tools = [
            {
                'name': func.get('name', ''),
                'description': func.get('description', ''),
                'input_schema': func.get('parameters', {
                    'type': 'object',
                    'properties': {},
                    'required': []
                })
            }
            for func in functions
        ]

        # Make API call with tools
        try:
            request = self._without_unset({
                'model': self.config.model,
                'messages': anthropic_messages,
                'system': system_content if system_content else None,
                'tools': tools,
                'max_tokens': self.config.max_tokens or 1024,
                'temperature': self.config.temperature,
                'top_p': self.config.top_p
            })
            response = await self._client.messages.create(**request)

            # Extract response content and tool use
            content = ''
            tool_use = None

            for block in response.content:
                if block.type == 'text':
                    content += block.text
                elif block.type == 'tool_use':
                    tool_use = {
                        'name': block.name,
                        'arguments': block.input
                    }

            return LLMResponse(
                content=content,
                model=response.model,
                finish_reason=response.stop_reason,
                usage={
                    'prompt_tokens': response.usage.input_tokens,
                    'completion_tokens': response.usage.output_tokens,
                    'total_tokens': response.usage.input_tokens + response.usage.output_tokens
                },
                function_call=tool_use
            )

        except Exception as e:
            # Fallback to prompt-based approach for older models
            import logging
            logging.warning("Anthropic native tools failed, falling back to prompt-based: %s", e)

            # .get() keeps a malformed definition from raising inside the handler.
            function_descriptions = "\n".join(
                f"- {f.get('name', '')}: {f.get('description', '')}"
                for f in functions
            )

            system_prompt = f"""You have access to the following functions:
{function_descriptions}

When you need to call a function, respond with:
FUNCTION_CALL: {{
    "name": "function_name",
    "arguments": {{...}}
}}"""

            messages_with_system = [
                LLMMessage(role='system', content=system_prompt)
            ] + list(messages)

            response = await self.complete(messages_with_system, **kwargs)

            # Parse function call from response
            if 'FUNCTION_CALL:' in response.content:
                func_json = response.content.split('FUNCTION_CALL:', 1)[1].strip()
                try:
                    # raw_decode reads the first JSON value and tolerates any
                    # text the model appended after it.
                    function_call, _ = json.JSONDecoder().raw_decode(func_json)
                except json.JSONDecodeError:
                    pass
                else:
                    response.function_call = function_call

            return response

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build Anthropic-style prompt from messages.

        System messages are placed in front of the prompt, user and assistant
        turns are appended as ``Human:`` / ``Assistant:`` sections, and the
        prompt ends with an open ``Assistant:`` turn.
        """
        prompt = ""
        for msg in messages:
            if msg.role == 'system':
                prompt = msg.content + "\n\n" + prompt
            elif msg.role == 'user':
                prompt += f"\n\nHuman: {msg.content}"
            elif msg.role == 'assistant':
                prompt += f"\n\nAssistant: {msg.content}"
        prompt += "\n\nAssistant:"
        return prompt
class OllamaProvider(AsyncLLMProvider):
    """Ollama local LLM provider.

    Supports latest Ollama features including:
    - Native tools/function calling (Ollama 0.1.17+)
    - Chat endpoint with message history
    - Streaming responses
    - Embeddings
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: Optional[AsyncPromptBuilder] = None
    ):
        """Create the provider and resolve the server base URL.

        The base URL is taken from the config's ``api_base``, then the
        ``OLLAMA_BASE_URL`` environment variable, then a default that
        depends on whether ``/.dockerenv`` exists.

        Args:
            config: Provider configuration in any form accepted by
                ``normalize_llm_config``.
            prompt_builder: Optional prompt builder handed to the base class.
        """
        # Normalize config first
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)

        # Check for Docker environment and adjust URL accordingly
        if os.path.exists('/.dockerenv'):
            # Running in Docker, use host.docker.internal
            default_url = 'http://host.docker.internal:11434'
        else:
            default_url = 'http://localhost:11434'

        # Allow environment variable override
        self.base_url = llm_config.api_base or os.environ.get('OLLAMA_BASE_URL', default_url)

    def _build_options(self) -> Dict[str, Any]:
        """Build options dict for Ollama API calls.

        Returns:
            Dictionary of options for the API request.
        """
        options: Dict[str, Any] = {
            'temperature': self.config.temperature,
            'top_p': self.config.top_p
        }

        if self.config.seed is not None:
            options['seed'] = self.config.seed

        if self.config.max_tokens:
            options['num_predict'] = self.config.max_tokens  # type: ignore

        if self.config.stop_sequences:
            options['stop'] = self.config.stop_sequences  # type: ignore

        return options

    def _messages_to_ollama(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert LLMMessage list to Ollama chat format.

        Args:
            messages: List of LLM messages

        Returns:
            List of message dicts in Ollama format
        """
        ollama_messages = []
        for msg in messages:
            message: Dict[str, Any] = {
                'role': msg.role,
                'content': msg.content
            }
            # Ollama supports images in messages for vision models.
            # Tolerate a message whose metadata is missing or None.
            images = (getattr(msg, 'metadata', None) or {}).get('images')
            if images:
                message['images'] = images
            ollama_messages.append(message)
        return ollama_messages

    def _adapt_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Adapt tools to Ollama format.

        Ollama uses a similar format to OpenAI for tools.

        Args:
            tools: List of tool definitions

        Returns:
            List of tools in Ollama format
        """
        return [
            {
                'type': 'function',
                'function': {
                    'name': tool.get('name'),
                    'description': tool.get('description', ''),
                    'parameters': tool.get('parameters', {})
                }
            }
            for tool in tools
        ]

    def _resolve_model(self, models: List[str]) -> None:
        """Reconcile the configured model with the models the server reports.

        When the configured name is not listed, the first listed model whose
        name starts with the configured base name (the part before ``:``) is
        adopted.  Otherwise a warning is logged and the config is left as is.
        """
        import logging

        if not models:
            logging.warning("Ollama: No models found. Please pull a model first.")
            return
        if self.config.model in models:
            return

        # Try without tag (e.g., 'llama2' instead of 'llama2:latest')
        base_model = self.config.model.split(':')[0]
        matching_models = [m for m in models if m.startswith(base_model)]
        if matching_models:
            # Use first matching model
            self.config.model = matching_models[0]
            logging.info("Ollama: Using model %s", self.config.model)
        else:
            logging.warning(
                "Ollama: Model %s not found. Available: %s",
                self.config.model, models
            )

    async def initialize(self) -> None:
        """Initialize Ollama client.

        Creates the HTTP session and probes ``/api/tags``.  Probe failures
        (connection errors, timeouts, non-200 status) are logged as warnings
        and do not prevent the provider from being marked initialized.

        Raises:
            ImportError: If the ``aiohttp`` package is not installed.
        """
        import asyncio
        import logging

        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        # Do not leak a session left over from an earlier initialize().
        stale = getattr(self, '_session', None)
        if stale is not None and not getattr(stale, 'closed', True):
            await stale.close()

        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.config.timeout or 30.0)
        )

        # Test connection and verify model availability
        try:
            async with self._session.get(f"{self.base_url}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    self._resolve_model([m['name'] for m in data.get('models', [])])
                else:
                    logging.warning("Ollama: API returned status %s", response.status)
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # A timeout is treated like any other failed probe.
            logging.warning("Ollama: Could not connect to %s: %s", self.base_url, e)

        self._is_initialized = True

    async def close(self) -> None:
        """Close Ollama client."""
        session = getattr(self, '_session', None)
        if session:
            await session.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability.

        Returns:
            True when the configured model, or a model starting with its base
            name, is listed by the server.  False when the provider is not
            initialized, the request fails, or the status is not 200.
        """
        if not self._is_initialized or not hasattr(self, '_session'):
            return False

        try:
            async with self._session.get(f"{self.base_url}/api/tags") as response:
                if response.status != 200:
                    return False
                data = await response.json()
                models = [m['name'] for m in data.get('models', [])]
                # Check exact match or base model match
                if self.config.model in models:
                    return True
                base_model = self.config.model.split(':')[0]
                return any(m.startswith(base_model) for m in models)
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get Ollama model capabilities.

        Capabilities are inferred from substrings of the lower-cased model
        name.
        """
        model = self.config.model.lower()
        # Capabilities depend on the specific model
        capabilities = [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.STREAMING
        ]

        # Most recent Ollama models support function calling
        if any(family in model for family in ('llama3', 'mistral', 'mixtral', 'qwen')):
            capabilities.append(ModelCapability.FUNCTION_CALLING)

        if 'llava' in model:
            capabilities.append(ModelCapability.VISION)

        if 'codellama' in model or 'codegemma' in model:
            capabilities.append(ModelCapability.CODE)

        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion using Ollama chat endpoint.

        Args:
            messages: Prompt string or message list.  The caller's list is
                not modified.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            Response built from the ``/api/chat`` reply.

        Raises:
            aiohttp.ClientResponseError: On a non-success HTTP status.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to message list
        if isinstance(messages, str):
            messages = [LLMMessage(role='user', content=messages)]

        # Add system prompt if configured
        if self.config.system_prompt and (not messages or messages[0].role != 'system'):
            messages = [LLMMessage(role='system', content=self.config.system_prompt)] + list(messages)

        # Build payload for chat endpoint
        payload: Dict[str, Any] = {
            'model': self.config.model,
            'messages': self._messages_to_ollama(messages),
            'stream': False,
            'options': self._build_options()
        }

        # Add format if JSON mode requested
        if self.config.response_format == 'json':
            payload['format'] = 'json'

        async with self._session.post(f"{self.base_url}/api/chat", json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        usage = None
        if 'eval_count' in data:
            prompt_tokens = data.get('prompt_eval_count', 0)
            completion_tokens = data.get('eval_count', 0)
            usage = {
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens,
                'total_tokens': prompt_tokens + completion_tokens
            }

        return LLMResponse(
            content=data.get('message', {}).get('content', ''),
            model=self.config.model,
            finish_reason='stop' if data.get('done') else 'length',
            usage=usage,
            metadata={
                'eval_duration': data.get('eval_duration'),
                'total_duration': data.get('total_duration'),
                'model_info': data.get('model', '')
            }
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion.

        Uses the ``/api/generate`` endpoint with a single flattened prompt
        and parses each non-blank response line as one JSON object.

        Args:
            messages: Prompt string or message list.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            One ``LLMStreamResponse`` per JSON line received.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to Ollama format
        prompt = messages if isinstance(messages, str) else self._build_prompt(messages)

        # Stream API call
        payload = {
            'model': self.config.model,
            'prompt': prompt,
            'stream': True,
            'options': self._build_options()
        }

        async with self._session.post(f"{self.base_url}/api/generate", json=payload) as response:
            response.raise_for_status()

            async for line in response.content:
                # Skip blank / whitespace-only lines; they are not JSON.
                text = line.decode('utf-8').strip()
                if not text:
                    continue
                data = json.loads(text)
                done = data.get('done', False)
                yield LLMStreamResponse(
                    delta=data.get('response', ''),
                    is_final=done,
                    finish_reason='stop' if done else None
                )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        One request is sent to ``/api/embeddings`` per input text.

        Args:
            texts: A single text or a list of texts.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            A single vector for a string input, otherwise a list of vectors.
        """
        if not self._is_initialized:
            await self.initialize()

        single = isinstance(texts, str)
        inputs = [texts] if single else texts

        embeddings = []
        for text in inputs:
            payload = {
                'model': self.config.model,
                'prompt': text
            }

            async with self._session.post(f"{self.base_url}/api/embeddings", json=payload) as response:
                response.raise_for_status()
                data = await response.json()
                embeddings.append(data['embedding'])

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling with native Ollama tools support.

        For Ollama 0.1.17+, uses native tools API.
        Falls back to prompt-based approach for older versions: if the tools
        request fails for any reason, the call is retried through
        ``complete`` with a system prompt that embeds the function
        definitions as JSON.

        Args:
            messages: Conversation messages.  The caller's list is not
                modified.
            functions: Function definitions offered to the model.
            **kwargs: Forwarded to ``complete`` on the fallback path only.

        Returns:
            Response whose ``function_call`` holds ``name`` and ``arguments``
            when the model requested a call.
        """
        if not self._is_initialized:
            await self.initialize()

        # Add system prompt if configured
        if self.config.system_prompt and (not messages or messages[0].role != 'system'):
            messages = [LLMMessage(role='system', content=self.config.system_prompt)] + list(messages)

        # Build payload with tools
        payload = {
            'model': self.config.model,
            'messages': self._messages_to_ollama(messages),
            'tools': self._adapt_tools(functions),
            'stream': False,
            'options': self._build_options()
        }

        try:
            async with self._session.post(f"{self.base_url}/api/chat", json=payload) as response:
                response.raise_for_status()
                data = await response.json()

            # Extract response and tool calls
            message = data.get('message', {})
            tool_calls = message.get('tool_calls', [])

            usage = None
            if 'eval_count' in data:
                prompt_tokens = data.get('prompt_eval_count', 0)
                completion_tokens = data.get('eval_count', 0)
                usage = {
                    'prompt_tokens': prompt_tokens,
                    'completion_tokens': completion_tokens,
                    'total_tokens': prompt_tokens + completion_tokens
                }

            # Build response
            llm_response = LLMResponse(
                content=message.get('content', ''),
                model=self.config.model,
                finish_reason='tool_calls' if tool_calls else 'stop',
                usage=usage
            )

            # Add tool call information if present
            if tool_calls:
                # Use first tool call (Ollama can return multiple)
                function = tool_calls[0].get('function', {})
                llm_response.function_call = {
                    'name': function.get('name', ''),
                    'arguments': function.get('arguments', {})
                }

            return llm_response

        except Exception as e:
            # Fallback to prompt-based approach if native tools not supported
            import logging
            logging.warning("Ollama native tools failed, falling back to prompt-based: %s", e)

            function_descriptions = json.dumps(functions, indent=2)

            system_prompt = f"""You have access to these functions:
{function_descriptions}

To call a function, respond with JSON:
{{"function": "name", "arguments": {{...}}}}"""

            messages_with_system = [
                LLMMessage(role='system', content=system_prompt)
            ] + list(messages)

            response = await self.complete(messages_with_system, **kwargs)

            # Try to parse function call
            try:
                func_data = json.loads(response.content)
            except json.JSONDecodeError:
                func_data = None

            # Only a JSON object can describe a call; a JSON string or list
            # must not be probed with `in` / indexed by key.
            if isinstance(func_data, dict) and 'function' in func_data:
                response.function_call = {
                    'name': func_data['function'],
                    'arguments': func_data.get('arguments', {})
                }

            return response

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build prompt from messages.

        Each system, user, or assistant message becomes a labelled paragraph;
        messages with any other role are omitted.
        """
        labels = {'system': 'System', 'user': 'User', 'assistant': 'Assistant'}
        return "".join(
            f"{labels[msg.role]}: {msg.content}\n\n"
            for msg in messages
            if msg.role in labels
        )
class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider.

    Requests are sent to ``<base_url>/<model>`` through an ``aiohttp``
    session created by :meth:`initialize`. ``base_url`` is taken from the
    config's ``api_base`` and falls back to the public inference endpoint.
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: Optional[AsyncPromptBuilder] = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration (LLMConfig, Config object, or dict);
                normalized with ``normalize_llm_config``.
            prompt_builder: Optional prompt builder handed to the base class.
        """
        # Normalize config first
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.base_url = llm_config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Create the HTTP session used for all API calls.

        The API key is read from the config, falling back to the
        ``HUGGINGFACE_API_KEY`` environment variable.

        Raises:
            ImportError: If ``aiohttp`` is not installed.
            ValueError: If no API key is available.
        """
        # Only the import is guarded, so unrelated errors are not relabelled
        # as a missing-dependency problem.
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
        if not api_key:
            raise ValueError("HuggingFace API key not provided")

        self._session = aiohttp.ClientSession(
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=aiohttp.ClientTimeout(total=self.config.timeout)
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the HTTP session (if one was created) and mark as uninitialized."""
        # ``_session`` only exists after initialize() has run.
        if hasattr(self, '_session') and self._session:
            await self._session.close()
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check whether the configured model endpoint answers with HTTP 200.

        Returns:
            True when a GET on the model URL returns status 200; False on any
            other status or on any exception (including a missing session).
        """
        try:
            url = f"{self.base_url}/{self.config.model}"
            async with self._session.get(url) as response:
                return response.status == 200
        except Exception:
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get HuggingFace model capabilities.

        Returns:
            ``TEXT_GENERATION`` always, plus ``EMBEDDINGS`` when the model name
            contains ``"embedding"``. The list never contains ``None``.
        """
        capabilities = [ModelCapability.TEXT_GENERATION]
        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate a completion.

        Args:
            messages: A raw prompt string, or a message list that is flattened
                with :meth:`_build_prompt`.
            **kwargs: Accepted for interface compatibility; not forwarded to
                the API.

        Returns:
            Response whose content is the ``generated_text`` of the first
            result, or ``str(data)`` when the API reply is not a non-empty list.

        Raises:
            aiohttp.ClientResponseError: On a non-success HTTP status
                (via ``raise_for_status``).
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to prompt
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Make API call
        url = f"{self.base_url}/{self.config.model}"
        payload = {
            'inputs': prompt,
            'parameters': {
                'temperature': self.config.temperature,
                'top_p': self.config.top_p,
                # Fall back to 100 new tokens when max_tokens is unset/zero.
                'max_new_tokens': self.config.max_tokens or 100,
                'return_full_text': False
            }
        }

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        # Parse response
        if isinstance(data, list) and len(data) > 0:
            text = data[0].get('generated_text', '')
        else:
            text = str(data)

        return LLMResponse(
            content=text,
            model=self.config.model,
            finish_reason='stop'
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Simulate streaming by yielding the full completion as one chunk.

        Args:
            messages: Prompt string or message list, passed to :meth:`complete`.
            **kwargs: Passed through to :meth:`complete`.

        Yields:
            A single final chunk carrying the whole completion text.
        """
        response = await self.complete(messages, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: One string or a list of strings.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            For a single string, the first element of the API reply; for a
            list, the API reply as returned.

        Raises:
            aiohttp.ClientResponseError: On a non-success HTTP status.
        """
        if not self._is_initialized:
            await self.initialize()

        # The API is always called with a list; remember whether to unwrap.
        single = isinstance(texts, str)
        if single:
            texts = [texts]

        url = f"{self.base_url}/{self.config.model}"
        payload = {'inputs': texts}

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            embeddings = await response.json()

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """HuggingFace doesn't have native function calling.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build a plain-text prompt from messages.

        System messages are emitted bare followed by a blank line; user and
        assistant messages are prefixed with ``User: `` / ``Assistant: `` and
        end with a single newline. Other roles are skipped.

        Args:
            messages: Conversation messages, in order.

        Returns:
            The concatenated prompt text.
        """
        parts = []
        for msg in messages:
            if msg.role == 'system':
                parts.append(f"{msg.content}\n\n")
            elif msg.role == 'user':
                parts.append(f"User: {msg.content}\n")
            elif msg.role == 'assistant':
                parts.append(f"Assistant: {msg.content}\n")
        return "".join(parts)
class EchoProvider(AsyncLLMProvider):
    """Echo provider for testing and debugging.

    This provider echoes back input messages and generates deterministic
    mock embeddings. Perfect for testing without real LLM API calls.

    Features:
    - Echoes back user messages with configurable prefix
    - Generates deterministic embeddings based on content hash
    - Supports streaming (character-by-character echo)
    - Mocks function calling with deterministic responses
    - Zero external dependencies
    - Instant responses
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: Optional[AsyncPromptBuilder] = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration (LLMConfig, Config object, or dict).
                Echo-specific settings are read from ``options``:
                ``echo_prefix`` (default ``'Echo: '``), ``embedding_dim``
                (default 768), ``mock_tokens`` (default True) and
                ``stream_delay`` in seconds per character (default 0.0).
            prompt_builder: Optional prompt builder handed to the base class.
        """
        # Normalize config first
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)

        # Echo-specific configuration from options
        self.echo_prefix = llm_config.options.get('echo_prefix', 'Echo: ')
        self.embedding_dim = llm_config.options.get('embedding_dim', 768)
        self.mock_tokens = llm_config.options.get('mock_tokens', True)
        self.stream_delay = llm_config.options.get('stream_delay', 0.0)  # seconds per char

    def _generate_embedding(self, text: str) -> List[float]:
        """Generate deterministic embedding vector from text.

        Uses SHA-256 hash to create a deterministic vector that:
        - Is always the same for the same input
        - Distributes values across [-1, 1] range
        - Has configurable dimensionality

        Args:
            text: Input text

        Returns:
            Embedding vector of size self.embedding_dim
        """
        embedding: List[float] = []
        current_hash = hashlib.sha256(text.encode('utf-8')).digest()

        while len(embedding) < self.embedding_dim:
            # Map each byte (0-255) onto [-1, 1]; take only what is still needed.
            remaining = self.embedding_dim - len(embedding)
            embedding.extend((byte / 127.5) - 1.0 for byte in current_hash[:remaining])

            # Rehash for next batch of values
            current_hash = hashlib.sha256(current_hash).digest()

        return embedding

    def _count_tokens(self, text: str) -> int:
        """Mock token counting (simple character-based estimate).

        Args:
            text: Input text

        Returns:
            Estimated token count (at least 1)
        """
        # Rough approximation: 1 token ~= 4 characters
        return max(1, len(text) // 4)

    def _mock_usage(self, messages: List[LLMMessage], content: str) -> Optional[Dict[str, int]]:
        """Build the mock usage dict, or None when ``mock_tokens`` is off."""
        if not self.mock_tokens:
            return None
        prompt_tokens = sum(self._count_tokens(msg.content) for msg in messages)
        completion_tokens = self._count_tokens(content)
        return {
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens,
            'total_tokens': prompt_tokens + completion_tokens
        }

    def _mock_argument(self, param_name: str, param_schema: Dict[str, Any], seed_text: str) -> Any:
        """Produce a deterministic mock value for one function parameter.

        Args:
            param_name: Name of the parameter (embedded in string mocks).
            param_schema: JSON-schema fragment; only ``type`` is consulted
                (defaults to ``'string'``).
            seed_text: Text hashed to derive numeric/boolean values.

        Returns:
            A mock value matching the declared type, or None for unknown types.
        """
        param_type = param_schema.get('type', 'string')

        if param_type == 'string':
            return f"mock_{param_name}_from_echo"
        if param_type == 'number' or param_type == 'integer':
            # MD5 is used purely for determinism here, not for security.
            return int(hashlib.md5(seed_text.encode()).hexdigest()[:8], 16) % 100
        if param_type == 'boolean':
            return int(hashlib.md5(seed_text.encode()).hexdigest()[:2], 16) % 2 == 0
        if param_type == 'array':
            return ["mock_item_1", "mock_item_2"]
        if param_type == 'object':
            return {"mock_key": "mock_value"}
        return None

    async def initialize(self) -> None:
        """Initialize echo provider (no-op)."""
        self._is_initialized = True

    async def close(self) -> None:
        """Close echo provider (no-op)."""
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model (always true for echo)."""
        return True

    def get_capabilities(self) -> List[ModelCapability]:
        """Get echo provider capabilities."""
        return [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.EMBEDDINGS,
            ModelCapability.FUNCTION_CALLING,
            ModelCapability.STREAMING,
            ModelCapability.JSON_MODE
        ]

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Echo back the input messages.

        The reply is the echo prefix followed by the last user message (or a
        ``(no user message)`` placeholder). When the ``echo_system`` option is
        set and a system prompt is configured, it is prepended in brackets.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters (ignored)

        Returns:
            Echo response
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to message list
        if isinstance(messages, str):
            messages = [LLMMessage(role='user', content=messages)]

        # Build echo response from last user message
        user_messages = [msg for msg in messages if msg.role == 'user']
        if user_messages:
            content = self.echo_prefix + user_messages[-1].content
        else:
            content = self.echo_prefix + "(no user message)"

        # Add system prompt if configured and in echo
        if self.config.system_prompt and self.config.options.get('echo_system', False):
            content = f"[System: {self.config.system_prompt}]\n{content}"

        return LLMResponse(
            content=content,
            model=self.config.model or 'echo-model',
            finish_reason='stop',
            usage=self._mock_usage(messages, content)
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Stream echo response character by character.

        The last chunk is flagged ``is_final`` and carries the finish reason
        and usage. If the echoed content is empty, a single empty final chunk
        is still emitted so consumers always observe the end of the stream.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters (ignored)

        Yields:
            Streaming response chunks
        """
        import asyncio

        if not self._is_initialized:
            await self.initialize()

        # Get full response
        response = await self.complete(messages, **kwargs)
        content = response.content

        if not content:
            # Nothing to iterate over: emit the terminating chunk explicitly.
            yield LLMStreamResponse(
                delta='',
                is_final=True,
                finish_reason='stop',
                usage=response.usage
            )
            return

        # Stream character by character
        last_index = len(content) - 1
        for i, char in enumerate(content):
            is_final = (i == last_index)

            yield LLMStreamResponse(
                delta=char,
                is_final=is_final,
                finish_reason='stop' if is_final else None,
                usage=response.usage if is_final else None
            )

            # Optional delay for realistic streaming
            if self.stream_delay > 0:
                await asyncio.sleep(self.stream_delay)

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate deterministic mock embeddings.

        Args:
            texts: Input text(s)
            **kwargs: Additional parameters (ignored)

        Returns:
            One vector for a single string, otherwise one vector per text
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            return self._generate_embedding(texts)
        return [self._generate_embedding(text) for text in texts]

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Mock function calling with deterministic response.

        Always "calls" the first function in ``functions``, filling each
        declared parameter with a mock value derived from its schema type and
        the last user message. With no functions, falls back to a plain echo.

        Args:
            messages: Conversation messages
            functions: Available functions
            **kwargs: Additional parameters (ignored)

        Returns:
            Response with mock function call
        """
        if not self._is_initialized:
            await self.initialize()

        if not functions:
            # No functions provided, just echo
            return await self.complete(messages, **kwargs)

        # Last user message seeds the numeric/boolean mock values
        user_messages = [msg for msg in messages if msg.role == 'user']
        user_content = user_messages[-1].content if user_messages else ""

        # Mock function call: use first function with mock arguments
        first_func = functions[0]
        func_name = first_func.get('name', 'unknown_function')
        properties = first_func.get('parameters', {}).get('properties', {})

        mock_args = {
            param_name: self._mock_argument(param_name, param_schema, user_content)
            for param_name, param_schema in properties.items()
        }

        # Build response with function call
        content = f"{self.echo_prefix}Calling function '{func_name}'"

        return LLMResponse(
            content=content,
            model=self.config.model or 'echo-model',
            finish_reason='function_call',
            usage=self._mock_usage(messages, content),
            function_call={
                'name': func_name,
                'arguments': mock_args
            }
        )
class LLMProviderFactory:
    """Factory for creating LLM providers from configuration.

    This factory class integrates with the dataknobs Config system,
    allowing providers to be instantiated via Config.get_factory().

    Example:
        >>> from dataknobs_config import Config
        >>> config = Config({
        ...     "llm": [{
        ...         "name": "gpt4",
        ...         "provider": "openai",
        ...         "model": "gpt-4",
        ...         "factory": "dataknobs_llm.LLMProviderFactory"
        ...     }]
        ... })
        >>> factory = config.get_factory("llm", "gpt4")
        >>> provider = factory.create(config.get("llm", "gpt4"))
    """

    # Registry of provider classes, shared by all factory instances.
    # Built-in slots start as None and are filled on first instantiation.
    _providers: Dict[str, Type[AsyncLLMProvider]] = {
        'openai': None,  # type: ignore  # Populated lazily
        'anthropic': None,  # type: ignore
        'ollama': None,  # type: ignore
        'huggingface': None,  # type: ignore
        'echo': None,  # type: ignore
    }

    def __init__(self, is_async: bool = True):
        """Initialize the factory.

        Args:
            is_async: Whether to create async providers (default: True)
        """
        self.is_async = is_async

        # Lazily populate provider registry. Each built-in slot is filled
        # independently, so a custom class registered under a built-in name
        # before the first factory is created is neither overwritten nor
        # able to block the remaining built-ins from being registered.
        builtin_providers = {
            'openai': OpenAIProvider,
            'anthropic': AnthropicProvider,
            'ollama': OllamaProvider,
            'huggingface': HuggingFaceProvider,
            'echo': EchoProvider,
        }
        registry = LLMProviderFactory._providers
        for name, provider_class in builtin_providers.items():
            if registry.get(name) is None:
                registry[name] = provider_class

    def create(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        **kwargs: Any
    ) -> Union[AsyncLLMProvider, SyncLLMProvider]:
        """Create an LLM provider from configuration.

        Args:
            config: Configuration (LLMConfig, Config object, or dict)
            **kwargs: Additional arguments passed to provider constructor

        Returns:
            LLM provider instance; wrapped in ``SyncProviderAdapter`` when the
            factory was created with ``is_async=False``

        Raises:
            ValueError: If provider type is unknown
        """
        # Normalize config to LLMConfig
        llm_config = normalize_llm_config(config)

        # Get provider class (lookup is case-insensitive)
        provider_class = self._providers.get(llm_config.provider.lower())
        if not provider_class:
            raise ValueError(
                f"Unknown provider: {llm_config.provider}. "
                f"Available providers: {list(self._providers.keys())}"
            )

        # Create provider instance, forwarding constructor extras as documented
        provider = provider_class(llm_config, **kwargs)
        if self.is_async:
            return provider

        # Wrap in sync adapter
        return SyncProviderAdapter(provider)  # type: ignore

    @classmethod
    def register_provider(
        cls,
        name: str,
        provider_class: Type[AsyncLLMProvider]
    ) -> None:
        """Register a custom provider class.

        Allows extending the factory with custom provider implementations.
        The name is stored lower-cased.

        Args:
            name: Provider name (e.g., 'custom')
            provider_class: Provider class (must inherit from AsyncLLMProvider)

        Example:
            >>> class CustomProvider(AsyncLLMProvider):
            ...     pass
            >>> LLMProviderFactory.register_provider('custom', CustomProvider)
        """
        cls._providers[name.lower()] = provider_class

    def __call__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        **kwargs: Any
    ) -> Union[AsyncLLMProvider, SyncLLMProvider]:
        """Allow factory to be called directly.

        Makes the factory callable for convenience; equivalent to
        :meth:`create`.

        Args:
            config: Configuration
            **kwargs: Additional arguments

        Returns:
            LLM provider instance
        """
        return self.create(config, **kwargs)
def create_llm_provider(
    config: Union[LLMConfig, "Config", Dict[str, Any]],
    is_async: bool = True
) -> Union[AsyncLLMProvider, SyncLLMProvider]:
    """Build an LLM provider for the given configuration.

    Shortcut that instantiates an ``LLMProviderFactory`` and asks it to
    create the provider. Accepts LLMConfig, Config objects, and plain
    dictionaries.

    Args:
        config: LLM configuration (LLMConfig, Config, or dict)
        is_async: Whether to create async provider

    Returns:
        LLM provider instance

    Example:
        >>> # Direct usage with dict
        >>> provider = create_llm_provider({
        ...     "provider": "openai",
        ...     "model": "gpt-4",
        ...     "api_key": "..."
        ... })

        >>> # With Config object
        >>> from dataknobs_config import Config
        >>> config = Config({"llm": [{"provider": "openai", "model": "gpt-4"}]})
        >>> provider = create_llm_provider(config)
    """
    return LLMProviderFactory(is_async=is_async).create(config)