Coverage for src/dataknobs_llm/llm/providers/huggingface.py: 20%
74 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:48 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:48 -0700
1"""HuggingFace Inference API provider implementation."""
3import os
4from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator
6from ..base import (
7 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse,
8 AsyncLLMProvider, ModelCapability,
9 normalize_llm_config
10)
11from dataknobs_llm.prompts import AsyncPromptBuilder
13if TYPE_CHECKING:
14 from dataknobs_config.config import Config
class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider.

    Requests are sent to ``{base_url}/{model}`` through an ``aiohttp``
    session. The session is created by :meth:`initialize`, which
    :meth:`complete`, :meth:`embed` and :meth:`validate_model` call on demand
    when the provider has not been initialized yet.
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration; passed through
                ``normalize_llm_config`` before use.
            prompt_builder: Optional prompt builder forwarded to the base
                class.
        """
        # Normalize config first so the base class always receives an LLMConfig.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.base_url = llm_config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Create the authenticated HTTP session.

        The API key is taken from the config, falling back to the
        ``HUGGINGFACE_API_KEY`` environment variable.

        Raises:
            ImportError: If ``aiohttp`` is not installed.
            ValueError: If no API key is available.
        """
        # Only the import itself can legitimately fail with ImportError, so
        # keep the try body limited to it.
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
        if not api_key:
            raise ValueError("HuggingFace API key not provided")

        # Re-initializing must not leak a session that is still open.
        stale_session = getattr(self, '_session', None)
        if stale_session is not None and not stale_session.closed:
            await stale_session.close()

        self._session = aiohttp.ClientSession(
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=aiohttp.ClientTimeout(total=self.config.timeout)
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the HTTP session (if any) and mark the provider uninitialized."""
        session = getattr(self, '_session', None)
        try:
            if session:
                await session.close()
        finally:
            # Drop the reference so a closed session is never reused; the next
            # request re-initializes because the flag is cleared.
            self._session = None
            self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check whether the configured model endpoint answers with HTTP 200.

        Returns:
            True when a GET on the model URL returns status 200; False on any
            other status or on any error (including initialization failure).
        """
        try:
            # Without this the session attribute may not exist yet and the
            # check would always report False before the first request.
            if not self._is_initialized:
                await self.initialize()
            async with self._session.get(self._model_url()) as response:
                return response.status == 200
        except Exception:
            # Deliberate best-effort probe: any failure means "not available".
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities inferred from the configured model name.

        Text generation is always reported; embeddings are added when the
        model name contains "embedding" (case-insensitive).
        """
        capabilities = [ModelCapability.TEXT_GENERATION]
        if 'embedding' in self.config.model.lower():
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate a completion for a prompt string or a message list.

        Args:
            messages: A raw prompt, or messages flattened via
                :meth:`_build_prompt`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An ``LLMResponse`` with the generated text and
            ``finish_reason='stop'``.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to prompt
        prompt = messages if isinstance(messages, str) else self._build_prompt(messages)

        sampling = {
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
            'max_new_tokens': self.config.max_tokens or 100,
            'return_full_text': False
        }
        payload = {
            'inputs': prompt,
            # Unset options are omitted rather than sent as JSON null.
            'parameters': {key: value for key, value in sampling.items() if value is not None}
        }

        data = await self._post_json(payload)

        return LLMResponse(
            content=self._extract_generated_text(data),
            model=self.config.model,
            finish_reason='stop'
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Simulate streaming by yielding the whole completion as one chunk.

        The full response is obtained from :meth:`complete` and emitted as a
        single final ``LLMStreamResponse``.
        """
        response = await self.complete(messages, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: One string, or a list of strings.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The first element of the API response for a single string input;
            the API response as-is for a list input. An empty list input
            returns ``[]`` without contacting the API.
        """
        single = isinstance(texts, str)
        batch = [texts] if single else list(texts)
        if not batch:
            return []

        if not self._is_initialized:
            await self.initialize()

        embeddings = await self._post_json({'inputs': batch})
        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Not supported: always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _model_url(self) -> str:
        """Return the endpoint URL for the configured model."""
        # Tolerate an api_base configured with a trailing slash.
        return f"{self.base_url.rstrip('/')}/{self.config.model}"

    async def _post_json(self, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` to the model URL and return the decoded JSON body.

        Raises whatever ``response.raise_for_status()`` raises on an error
        status.
        """
        async with self._session.post(self._model_url(), json=payload) as response:
            response.raise_for_status()
            return await response.json()

    @staticmethod
    def _extract_generated_text(data: Any) -> Any:
        """Pull ``generated_text`` out of a decoded API response.

        A non-empty list whose first item is a dict yields that item's
        ``generated_text`` (``''`` when the key is missing); a dict carrying
        ``generated_text`` yields that value; anything else is stringified.
        """
        if isinstance(data, list) and data and isinstance(data[0], dict):
            return data[0].get('generated_text', '')
        if isinstance(data, dict) and 'generated_text' in data:
            return data['generated_text']
        return str(data)

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Flatten messages into a single prompt string.

        System content is followed by a blank line; user and assistant
        content are prefixed with ``User: `` / ``Assistant: `` and end with a
        newline. Messages with any other role are skipped.
        """
        parts = []
        for msg in messages:
            if msg.role == 'system':
                parts.append(f"{msg.content}\n\n")
            elif msg.role == 'user':
                parts.append(f"User: {msg.content}\n")
            elif msg.role == 'assistant':
                parts.append(f"Assistant: {msg.content}\n")
        return "".join(parts)