Coverage for src/dataknobs_llm/llm/providers/huggingface.py: 20%

74 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:48 -0700

1"""HuggingFace Inference API provider implementation.""" 

2 

3import os 

4from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator 

5 

6from ..base import ( 

7 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

8 AsyncLLMProvider, ModelCapability, 

9 normalize_llm_config 

10) 

11from dataknobs_llm.prompts import AsyncPromptBuilder 

12 

13if TYPE_CHECKING: 

14 from dataknobs_config.config import Config 

15 

16 

class HuggingFaceProvider(AsyncLLMProvider):
    """HuggingFace Inference API provider.

    Requests are sent to ``{base_url}/{model}`` through an ``aiohttp``
    session authenticated with a bearer token. ``base_url`` is taken from
    ``config.api_base`` and defaults to
    ``https://api-inference.huggingface.co/models``.
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        """Create the provider.

        Args:
            config: Provider configuration in any form accepted by
                ``normalize_llm_config``.
            prompt_builder: Optional prompt builder, forwarded to the base class.
        """
        # Normalize config first so every accepted config form is handled uniformly.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)
        self.base_url = llm_config.api_base or 'https://api-inference.huggingface.co/models'

    async def initialize(self) -> None:
        """Open the authenticated HTTP session and mark the provider initialized.

        The API key comes from ``config.api_key`` or, if that is empty, the
        ``HUGGINGFACE_API_KEY`` environment variable.

        Raises:
            ImportError: If the ``aiohttp`` package is not installed.
            ValueError: If no API key is available.
        """
        # Keep the try minimal: only the optional-dependency import is translated.
        try:
            import aiohttp
        except ImportError as e:
            raise ImportError(
                "aiohttp package not installed. Install with: pip install aiohttp"
            ) from e

        api_key = self.config.api_key or os.environ.get('HUGGINGFACE_API_KEY')
        if not api_key:
            raise ValueError("HuggingFace API key not provided")

        self._session = aiohttp.ClientSession(
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=aiohttp.ClientTimeout(total=self.config.timeout)
        )
        self._is_initialized = True

    async def close(self) -> None:
        """Close the HTTP session, if one exists, and mark the provider uninitialized."""
        # ``_session`` only exists after ``initialize`` has run.
        session = getattr(self, '_session', None)
        if session:
            await session.close()
            # Drop the reference so a closed session is never reused.
            self._session = None
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Check whether the configured model endpoint is reachable.

        Returns:
            True if a GET on ``{base_url}/{model}`` answers HTTP 200. Returns
            False on any other status and on any error, including calling this
            before ``initialize`` or after ``close``.
        """
        try:
            url = f"{self.base_url}/{self.config.model}"
            async with self._session.get(url) as response:
                return response.status == 200
        except Exception:
            # Deliberately best-effort: this is an availability probe and never raises.
            return False

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities assumed for the configured model.

        Text generation is always reported. Embeddings are reported only when
        the model id contains ``"embedding"``, which is a name-based heuristic.
        """
        capabilities = [ModelCapability.TEXT_GENERATION]
        # Append conditionally so the list never contains ``None`` placeholders.
        if 'embedding' in self.config.model:
            capabilities.append(ModelCapability.EMBEDDINGS)
        return capabilities

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs: Any
    ) -> LLMResponse:
        """Generate a completion for a raw prompt or a list of chat messages.

        Args:
            messages: A raw prompt string, or messages flattened via ``_build_prompt``.
            **kwargs: Accepted for interface compatibility; not forwarded to the API.

        Returns:
            An ``LLMResponse`` holding the generated text, with ``finish_reason='stop'``.

        Raises:
            aiohttp.ClientResponseError: On a non-2xx HTTP status
                (via ``raise_for_status``).
        """
        if not self._is_initialized:
            await self.initialize()

        prompt = messages if isinstance(messages, str) else self._build_prompt(messages)

        url = f"{self.base_url}/{self.config.model}"
        parameters = {
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
            'max_new_tokens': self.config.max_tokens or 100,
            'return_full_text': False,
        }
        payload = {
            'inputs': prompt,
            # Omit unset sampling options rather than sending JSON null.
            'parameters': {k: v for k, v in parameters.items() if v is not None},
        }

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()

        return LLMResponse(
            content=self._extract_generated_text(data),
            model=self.config.model,
            finish_reason='stop'
        )

    @staticmethod
    def _extract_generated_text(data: Any) -> str:
        """Extract the generated text from a decoded text-generation response.

        If ``data`` is a non-empty list whose first item is a dict, return that
        item's ``generated_text`` (or ``''`` if absent). Otherwise return
        ``str(data)``, so unexpected payloads are surfaced instead of raising.
        """
        if isinstance(data, list) and data and isinstance(data[0], dict):
            return data[0].get('generated_text', '')
        return str(data)

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs: Any
    ) -> AsyncIterator[LLMStreamResponse]:
        """Simulate streaming by yielding the full completion as a single final chunk.

        This provider does not stream from the API. It awaits ``complete`` and
        yields one ``LLMStreamResponse`` with ``is_final=True``.
        """
        response = await self.complete(messages, **kwargs)
        yield LLMStreamResponse(
            delta=response.content,
            is_final=True,
            finish_reason=response.finish_reason
        )

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings by posting ``{'inputs': [...]}`` to the model endpoint.

        Args:
            texts: A single string or a list of strings.
            **kwargs: Accepted for interface compatibility; not forwarded to the API.

        Returns:
            For a single string, the first element of the decoded response.
            For a list, the whole decoded response. An empty list returns ``[]``
            without an HTTP call.

        Raises:
            aiohttp.ClientResponseError: On a non-2xx HTTP status
                (via ``raise_for_status``).
        """
        if not self._is_initialized:
            await self.initialize()

        single = isinstance(texts, str)
        if single:
            texts = [texts]
        elif not texts:
            # Nothing to embed; skip the network round trip.
            return []

        url = f"{self.base_url}/{self.config.model}"
        payload = {'inputs': texts}

        async with self._session.post(url, json=payload) as response:
            response.raise_for_status()
            embeddings = await response.json()

        return embeddings[0] if single else embeddings

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs: Any
    ) -> LLMResponse:
        """Not supported: this provider has no native function calling.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Function calling not supported for HuggingFace models")

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Flatten chat messages into a single plain-text prompt.

        System content is emitted bare, followed by a blank line. User and
        assistant turns are prefixed with ``User: `` / ``Assistant: ``.
        Messages with any other role are skipped.

        NOTE(review): the prompt does not end with an ``Assistant:`` cue;
        confirm whether the target models need one.
        """
        parts: List[str] = []
        # Compare with == (not a dict lookup) so str-Enum roles still match.
        for msg in messages:
            if msg.role == 'system':
                parts.append(f"{msg.content}\n\n")
            elif msg.role == 'user':
                parts.append(f"User: {msg.content}\n")
            elif msg.role == 'assistant':
                parts.append(f"Assistant: {msg.content}\n")
        return "".join(parts)