Coverage for src / dataknobs_llm / llm / providers / huggingface.py: 20%

75 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:29 -0700

1"""HuggingFace Inference API provider implementation.""" 

2 

3import os 

4from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator 

5 

6from ..base import ( 

7 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

8 AsyncLLMProvider, ModelCapability, 

9 normalize_llm_config 

10) 

11from dataknobs_llm.prompts import AsyncPromptBuilder 

12 

13if TYPE_CHECKING: 

14 from dataknobs_config.config import Config 

15 

16 

class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider.

    Requests go to ``{base_url}/{model}`` through an ``aiohttp.ClientSession``.
    :meth:`initialize` creates that session, and the request methods call it
    lazily. Text generation and embeddings are plain POST requests. Streaming
    is simulated with a single final chunk, and function calling is not
    supported.
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        """Create the provider.

        Args:
            config: An ``LLMConfig``, a dataknobs ``Config`` or a plain dict;
                normalized with ``normalize_llm_config``.
            prompt_builder: Optional prompt builder forwarded to the base class.
        """
        # Normalize config first so the base class always receives an LLMConfig.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.base_url = llm_config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Initialize the HuggingFace HTTP client session.

        The API key comes from ``config.api_key``, falling back to the
        ``HUGGINGFACE_API_KEY`` environment variable.

        Raises:
            ImportError: If the ``aiohttp`` package is not installed.
            ValueError: If no API key is available.
        """
        # Keep the try narrow: only the import is expected to raise ImportError.
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError("aiohttp package not installed. Install with: pip install aiohttp") from e

        api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
        if not api_key:
            raise ValueError("HuggingFace API key not provided")

        # Re-initializing must not leak a previously opened session.
        previous = getattr(self, '_session', None)
        if previous is not None and not previous.closed:
            await previous.close()

        self._session = aiohttp.ClientSession(
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=aiohttp.ClientTimeout(total=self.config.timeout)
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the HTTP session, if one exists, and mark the provider uninitialized."""
        session = getattr(self, '_session', None)
        if session is not None:
            await session.close()
            # Drop the reference so a later initialize() starts from a clean slate.
            self._session = None
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check that the configured model endpoint answers with HTTP 200.

        The client is initialized first if needed, so this works on a
        freshly constructed provider.

        Returns:
            ``True`` if a GET on ``{base_url}/{model}`` returns status 200.
            ``False`` on any other status or on any error, including a
            missing API key or a network failure.
        """
        try:
            if not self._is_initialized:
                await self.initialize()
            url = f"{self.base_url}/{self.config.model}"
            async with self._session.get(url) as response:
                return response.status == 200
        except Exception:
            # Deliberate best-effort probe: any failure means "not usable".
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Get HuggingFace model capabilities.

        Text generation is always reported. Embeddings are added only when the
        configured model name contains ``'embedding'``.

        Returns:
            A list of ``ModelCapability`` values with no ``None`` placeholders.
        """
        capabilities = [ModelCapability.TEXT_GENERATION]
        if 'embedding' in (self.config.model or ''):
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs
    ) -> LLMResponse:
        """Generate completion.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional provider-specific parameters (currently unused)

        Returns:
            An ``LLMResponse`` whose content is the generated text. If the
            payload has an unrecognized shape, the content is the payload's
            string form.

        Raises:
            aiohttp.ClientResponseError: On a non-success HTTP status.
        """
        if not self._is_initialized:
            await self.initialize()

        # Get runtime config (with overrides applied if provided).
        runtime_config = self._get_runtime_config(config_overrides)

        # Convert to prompt.
        if isinstance(messages, str):
            prompt = messages
        else:
            prompt = self._build_prompt(messages)

        # NOTE(review): stop_sequences and seed overrides are not forwarded to
        # the API payload; confirm the endpoint's parameter names before adding them.
        parameters = {
            'temperature': runtime_config.temperature,
            'top_p': runtime_config.top_p,
            'max_new_tokens': runtime_config.max_tokens or 100,
            'return_full_text': False
        }
        # Omit unset values rather than sending JSON nulls; the server then
        # applies its own defaults.
        payload = {
            'inputs': prompt,
            'parameters': {k: v for k, v in parameters.items() if v is not None}
        }

        url = f"{self.base_url}/{runtime_config.model}"
        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        return LLMResponse(
            content=self._extract_generated_text(data),
            model=runtime_config.model,
            finish_reason='stop'
        )

    @staticmethod
    def _extract_generated_text(data: Any) -> str:
        """Pull the generated text out of an Inference API JSON payload.

        Handles a list of ``{'generated_text': ...}`` objects, a single such
        object, or falls back to ``str(data)`` for anything else.
        """
        if isinstance(data, list) and data:
            first = data[0]
            if isinstance(first, dict):
                return first.get('generated_text', '')
            return str(first)
        if isinstance(data, dict) and 'generated_text' in data:
            return data['generated_text']
        return str(data)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """HuggingFace Inference API doesn't support streaming.

        Runs :meth:`complete` and yields its whole result as one final chunk.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional provider-specific parameters
        """
        # Simulate streaming by yielding the complete response once.
        response = await self.complete(messages, config_overrides=config_overrides, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings.

        Args:
            texts: A single text or a list of texts.
            **kwargs: Currently unused.

        Returns:
            The first element of the response for a single string. For a list
            input, the full JSON response. An empty list input returns ``[]``
            without making a request.

        Raises:
            aiohttp.ClientResponseError: On a non-success HTTP status.
        """
        single = isinstance(texts, str)
        batch = [texts] if single else list(texts)
        # Nothing to embed: skip the network round-trip entirely.
        if not batch:
            return []

        if not self._is_initialized:
            await self.initialize()

        url = f"{self.base_url}/{self.config.model}"
        payload = {'inputs': batch}

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            embeddings = await response.json()

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """HuggingFace doesn't have native function calling.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Flatten chat messages into a single plain-text prompt.

        System content is emitted verbatim followed by a blank line. User and
        assistant turns are prefixed with ``User:`` / ``Assistant:``.
        """
        # NOTE(review): messages with any other role are silently skipped, and
        # no trailing "Assistant:" cue is appended; confirm this is intended.
        parts = []
        for msg in messages:
            if msg.role == 'system':
                parts.append(f"{msg.content}\n\n")
            elif msg.role == 'user':
                parts.append(f"User: {msg.content}\n")
            elif msg.role == 'assistant':
                parts.append(f"Assistant: {msg.content}\n")
        return "".join(parts)