Coverage for src/dataknobs_llm/llm/providers/echo.py: 94%

95 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""Echo provider for testing and debugging.""" 

2 

3import hashlib 

4from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator 

5 

6from ..base import ( 

7 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

8 AsyncLLMProvider, ModelCapability, 

9 normalize_llm_config 

10) 

11from dataknobs_llm.prompts import AsyncPromptBuilder 

12 

13if TYPE_CHECKING: 

14 from dataknobs_config.config import Config 

15 

16 

class EchoProvider(AsyncLLMProvider):
    """Echo provider for testing and debugging.

    This provider echoes back input messages and generates deterministic
    mock embeddings. Perfect for testing without real LLM API calls.

    Features:
    - Echoes back user messages with configurable prefix
    - Generates deterministic embeddings based on content hash
    - Supports streaming (character-by-character echo)
    - Mocks function calling with deterministic responses
    - Zero external dependencies
    - Instant responses
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        """Create an echo provider.

        Args:
            config: Provider configuration in any form accepted by
                ``normalize_llm_config``.
            prompt_builder: Optional prompt builder forwarded to the base
                class constructor.

        The following keys are read from the normalized config's ``options``:
            - ``echo_prefix`` (default ``'Echo: '``): text prepended to echoes.
            - ``embedding_dim`` (default ``768``): length of mock embeddings.
            - ``mock_tokens`` (default ``True``): whether responses carry a
              mock ``usage`` dict.
            - ``stream_delay`` (default ``0.0``): seconds to sleep after each
              streamed character.
        """
        # Normalize config first so the base class always receives an LLMConfig.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)

        # Echo-specific configuration from options
        options = llm_config.options
        self.echo_prefix = options.get('echo_prefix', 'Echo: ')
        self.embedding_dim = options.get('embedding_dim', 768)
        self.mock_tokens = options.get('mock_tokens', True)
        self.stream_delay = options.get('stream_delay', 0.0)  # seconds per char

    def _generate_embedding(self, text: str) -> List[float]:
        """Generate deterministic embedding vector from text.

        Uses SHA-256 hash to create a deterministic vector that:
        - Is always the same for the same input
        - Distributes values across [-1, 1] range
        - Has configurable dimensionality

        Args:
            text: Input text

        Returns:
            Embedding vector of size self.embedding_dim
        """
        embedding: List[float] = []
        current_hash = hashlib.sha256(text.encode('utf-8')).digest()

        while len(embedding) < self.embedding_dim:
            # Map each digest byte (0-255) linearly onto [-1, 1].
            embedding.extend((byte / 127.5) - 1.0 for byte in current_hash)
            # Chain the hash to obtain the next batch of 32 values.
            current_hash = hashlib.sha256(current_hash).digest()

        # The last batch may overshoot the requested dimensionality.
        return embedding[:self.embedding_dim]

    def _count_tokens(self, text: str) -> int:
        """Mock token counting (simple character-based estimate).

        Args:
            text: Input text

        Returns:
            Estimated token count; never less than 1, even for empty text.
        """
        # Rough approximation: 1 token ~= 4 characters
        return max(1, len(text) // 4)

    def _mock_usage(
        self,
        messages: List[LLMMessage],
        content: str
    ) -> Union[Dict[str, int], None]:
        """Build the mock token-usage dict shared by all response builders.

        Args:
            messages: Prompt messages; each message's ``content`` is counted.
            content: Generated response text.

        Returns:
            Dict with ``prompt_tokens``, ``completion_tokens`` and
            ``total_tokens``, or ``None`` when ``mock_tokens`` is disabled.
        """
        if not self.mock_tokens:
            return None

        prompt_tokens = sum(self._count_tokens(msg.content) for msg in messages)
        completion_tokens = self._count_tokens(content)
        return {
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens,
            'total_tokens': prompt_tokens + completion_tokens,
        }

    async def initialize(self) -> None:
        """Initialize echo provider (no-op beyond marking it initialized)."""
        self._is_initialized = True

    async def close(self) -> None:
        """Close echo provider (no-op beyond marking it uninitialized)."""
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model (always true for echo)."""
        return True

    def get_capabilities(self) -> List[ModelCapability]:
        """Get echo provider capabilities."""
        return [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.EMBEDDINGS,
            ModelCapability.FUNCTION_CALLING,
            ModelCapability.STREAMING,
            ModelCapability.JSON_MODE,
        ]

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs: Any
    ) -> LLMResponse:
        """Echo back the input messages.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters (ignored)

        Returns:
            Response whose content is ``echo_prefix`` followed by the last
            user message, or by ``"(no user message)"`` when there is none.
            The configured system prompt is prepended when the ``echo_system``
            option is truthy.
        """
        if not self._is_initialized:
            await self.initialize()

        # Convert to message list
        if isinstance(messages, str):
            messages = [LLMMessage(role='user', content=messages)]

        # Build echo response from last user message
        user_messages = [msg for msg in messages if msg.role == 'user']
        if user_messages:
            content = self.echo_prefix + user_messages[-1].content
        else:
            content = self.echo_prefix + "(no user message)"

        # Add system prompt if configured and echoing it was requested
        if self.config.system_prompt and self.config.options.get('echo_system', False):
            content = f"[System: {self.config.system_prompt}]\n{content}"

        return LLMResponse(
            content=content,
            model=self.config.model or 'echo-model',
            finish_reason='stop',
            usage=self._mock_usage(messages, content)
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs: Any
    ) -> AsyncIterator[LLMStreamResponse]:
        """Stream echo response character by character.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters (forwarded to ``complete``)

        Yields:
            One chunk per character of the echoed content. The last chunk has
            ``is_final=True`` and carries ``finish_reason`` and ``usage``.
            When the echoed content is empty, a single final chunk with an
            empty ``delta`` is yielded so consumers always see a terminator.
        """
        # Kept function-local (as in the original), but hoisted out of the loop.
        import asyncio

        if not self._is_initialized:
            await self.initialize()

        # Get full response
        response = await self.complete(messages, **kwargs)
        text = response.content

        if not text:
            # Nothing to stream: still signal completion exactly once.
            yield LLMStreamResponse(
                delta='',
                is_final=True,
                finish_reason='stop',
                usage=response.usage
            )
            return

        # Stream character by character
        last_index = len(text) - 1
        for i, char in enumerate(text):
            is_final = (i == last_index)

            yield LLMStreamResponse(
                delta=char,
                is_final=is_final,
                finish_reason='stop' if is_final else None,
                usage=response.usage if is_final else None
            )

            # Optional delay for realistic streaming
            if self.stream_delay > 0:
                await asyncio.sleep(self.stream_delay)

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> Union[List[float], List[List[float]]]:
        """Generate deterministic mock embeddings.

        Args:
            texts: Input text(s)
            **kwargs: Additional parameters (ignored)

        Returns:
            A single vector for a ``str`` input, otherwise one vector per
            element of ``texts``.
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            return self._generate_embedding(texts)
        return [self._generate_embedding(text) for text in texts]

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs: Any
    ) -> LLMResponse:
        """Mock function calling with deterministic response.

        Always "calls" the first entry of ``functions``, filling each declared
        parameter with a mock value chosen by its schema ``type``.

        Args:
            messages: Conversation messages
            functions: Available functions
            **kwargs: Additional parameters (only forwarded to ``complete``
                when ``functions`` is empty)

        Returns:
            Response with mock function call, or a plain echo response when
            no functions are provided.
        """
        if not self._is_initialized:
            await self.initialize()

        if not functions:
            # No functions provided, just echo
            return await self.complete(messages, **kwargs)

        # Get last user message
        user_messages = [msg for msg in messages if msg.role == 'user']
        user_content = user_messages[-1].content if user_messages else ""

        # Mock function call: use first function with mock arguments
        first_func = functions[0]
        func_name = first_func.get('name', 'unknown_function')

        # `or {}` also covers schemas that spell out an explicit None.
        params = first_func.get('parameters') or {}
        properties = params.get('properties') or {}

        # The digest depends only on the user message, so compute it once.
        # MD5 is used purely as a deterministic value source, not for security;
        # the flag keeps it usable on FIPS-restricted Python builds.
        digest_hex = hashlib.md5(
            user_content.encode(), usedforsecurity=False
        ).hexdigest()

        mock_args: Dict[str, Any] = {}
        for param_name, param_schema in properties.items():
            param_type = param_schema.get('type', 'string')

            # Generate mock value based on type
            if param_type == 'string':
                mock_args[param_name] = f"mock_{param_name}_from_echo"
            elif param_type in ('number', 'integer'):
                # Deterministic number in [0, 100) from the first 8 hex digits
                mock_args[param_name] = int(digest_hex[:8], 16) % 100
            elif param_type == 'boolean':
                # Deterministic boolean from the first 2 hex digits
                mock_args[param_name] = int(digest_hex[:2], 16) % 2 == 0
            elif param_type == 'array':
                mock_args[param_name] = ["mock_item_1", "mock_item_2"]
            elif param_type == 'object':
                mock_args[param_name] = {"mock_key": "mock_value"}
            else:
                mock_args[param_name] = None

        # Build response with function call
        content = f"{self.echo_prefix}Calling function '{func_name}'"

        return LLMResponse(
            content=content,
            model=self.config.model or 'echo-model',
            finish_reason='function_call',
            usage=self._mock_usage(messages, content),
            function_call={
                'name': func_name,
                'arguments': mock_args
            }
        )