Coverage for src/dataknobs_llm/llm/providers/huggingface.py: 20%
75 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:29 -0700
1"""HuggingFace Inference API provider implementation."""
3import os
4from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator
6from ..base import (
7 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse,
8 AsyncLLMProvider, ModelCapability,
9 normalize_llm_config
10)
11from dataknobs_llm.prompts import AsyncPromptBuilder
13if TYPE_CHECKING:
14 from dataknobs_config.config import Config
class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider.

    Sends HTTP requests to ``{base_url}/{model}`` through an ``aiohttp``
    client session. ``base_url`` comes from the config's ``api_base`` and
    falls back to the public HuggingFace inference endpoint.
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration; passed through
                ``normalize_llm_config`` before use.
            prompt_builder: Optional prompt builder forwarded to the base class.
        """
        # Normalize config first so the base class always receives an LLMConfig.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.base_url = llm_config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Initialize HuggingFace client.

        Creates the ``aiohttp`` session carrying the bearer token and the
        configured timeout, then marks the provider as initialized.

        Raises:
            ImportError: If ``aiohttp`` is not installed.
            ValueError: If no API key is found in the config or in the
                ``HUGGINGFACE_API_KEY`` environment variable.
        """
        # Keep the try body minimal: only a failed aiohttp import should be
        # reported as "aiohttp package not installed".
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
        if not api_key:
            raise ValueError("HuggingFace API key not provided")

        # Re-initializing must not leak a previously opened session.
        previous = getattr(self, '_session', None)
        if previous:
            await previous.close()

        self._session = aiohttp.ClientSession(
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=aiohttp.ClientTimeout(total=self.config.timeout)
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close HuggingFace client.

        Safe to call when the provider was never initialized.
        """
        session = getattr(self, '_session', None)
        if session:
            await session.close()
        # Drop the closed session so it cannot be reused by accident.
        self._session = None
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model availability.

        Returns:
            True when a GET on the model URL answers with HTTP 200; False on
            any other status or on any error (including initialization
            failures).
        """
        try:
            # Initialize lazily, like complete() and embed(), so a fresh
            # provider does not report False merely because no session exists.
            if not self._is_initialized:
                await self.initialize()

            url = f"{self.base_url}/{self.config.model}"
            async with self._session.get(url) as response:
                return response.status == 200
        except Exception:
            # Deliberate best-effort: any failure means "not available".
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get HuggingFace model capabilities.

        Returns:
            Always ``TEXT_GENERATION``; ``EMBEDDINGS`` is added when the
            configured model name contains ``'embedding'``.
        """
        # Basic capabilities for text generation models
        capabilities = [ModelCapability.TEXT_GENERATION]
        # Append conditionally so the list never contains a None placeholder.
        if 'embedding' in (self.config.model or ''):
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional provider-specific parameters (accepted for
                interface compatibility; not forwarded to the API)

        Returns:
            An ``LLMResponse`` whose content is the generated text, or the
            stringified payload when the response has an unexpected shape.
        """
        if not self._is_initialized:
            await self.initialize()

        # Get runtime config (with overrides applied if provided)
        runtime_config = self._get_runtime_config(config_overrides)

        # Convert to prompt
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # Make API call
        url = f"{self.base_url}/{runtime_config.model}"
        payload = {
            'inputs': prompt,
            'parameters': {
                'temperature': runtime_config.temperature,
                'top_p': runtime_config.top_p,
                'max_new_tokens': runtime_config.max_tokens or 100,
                'return_full_text': False
            }
        }

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        return LLMResponse(
            content=self._extract_generated_text(data),
            model=runtime_config.model,
            finish_reason='stop'
        )

    @staticmethod
    def _extract_generated_text(data: Any) -> str:
        """Pull the generated text out of a decoded API response.

        Accepts a non-empty list whose first item is a dict, or a bare dict
        carrying ``generated_text``; anything else is returned as ``str(data)``.
        """
        if isinstance(data, list) and data and isinstance(data[0], dict):
            return data[0].get('generated_text', '')
        if isinstance(data, dict) and 'generated_text' in data:
            return data['generated_text']
        return str(data)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """HuggingFace Inference API doesn't support streaming.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional provider-specific parameters

        Yields:
            A single final ``LLMStreamResponse`` holding the whole completion.
        """
        # Simulate streaming by yielding complete response
        response = await self.complete(messages, config_overrides=config_overrides, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: A single string or a list of strings to embed.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The first element of the decoded response for a single string
            input, otherwise the decoded response as returned by the API.
        """
        if not self._is_initialized:
            await self.initialize()

        # The API is always called with a list; remember whether to unwrap.
        single = isinstance(texts, str)
        inputs = [texts] if single else texts

        url = f"{self.base_url}/{self.config.model}"
        payload = {'inputs': inputs}

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            embeddings = await response.json()

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """HuggingFace doesn't have native function calling.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Build prompt from messages.

        System messages are emitted verbatim followed by a blank line; user
        and assistant messages are prefixed with ``User: `` / ``Assistant: ``.
        Messages with any other role are skipped.
        """
        parts: List[str] = []
        for msg in messages:
            if msg.role == 'system':
                parts.append(f"{msg.content}\n\n")
            elif msg.role == 'user':
                parts.append(f"User: {msg.content}\n")
            elif msg.role == 'assistant':
                parts.append(f"Assistant: {msg.content}\n")
        # Join once instead of repeated string concatenation.
        return "".join(parts)