Coverage for src / dataknobs_llm / llm / providers / echo.py: 45%

96 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:29 -0700

1"""Echo provider for testing and debugging.""" 

2 

3import hashlib 

4from typing import TYPE_CHECKING, Any, Dict, List, Union, AsyncIterator 

5 

6from ..base import ( 

7 LLMConfig, LLMMessage, LLMResponse, LLMStreamResponse, 

8 AsyncLLMProvider, ModelCapability, 

9 normalize_llm_config 

10) 

11from dataknobs_llm.prompts import AsyncPromptBuilder 

12 

13if TYPE_CHECKING: 

14 from dataknobs_config.config import Config 

15 

16 

class EchoProvider(AsyncLLMProvider):
    """Echo provider for testing and debugging.

    This provider echoes back input messages and generates deterministic
    mock embeddings. Perfect for testing without real LLM API calls.

    Features:
    - Echoes back user messages with configurable prefix
    - Generates deterministic embeddings based on content hash
    - Supports streaming (character-by-character echo)
    - Mocks function calling with deterministic responses
    - Zero external dependencies
    - Instant responses

    Options read from the provider config's ``options`` mapping:
        echo_prefix: Text prepended to every echoed reply (default ``'Echo: '``).
        embedding_dim: Length of each vector returned by ``embed`` (default 768).
        mock_tokens: When true, responses carry a mock ``usage`` dict;
            otherwise ``usage`` is ``None`` (default True).
        stream_delay: Seconds to sleep after each streamed character
            (default 0.0, i.e. no delay).

    ``complete`` additionally reads ``echo_system`` from the *runtime* config
    options; when true and a system prompt is set, the system prompt is
    prepended to the echoed reply.
    """

    def __init__(
        self,
        config: Union[LLMConfig, "Config", Dict[str, Any]],
        prompt_builder: AsyncPromptBuilder | None = None
    ):
        """Create an echo provider.

        Args:
            config: Provider configuration in any form accepted by
                ``normalize_llm_config``.
            prompt_builder: Optional prompt builder, forwarded to the base class.
        """
        # Normalize the accepted config forms into an LLMConfig first.
        llm_config = normalize_llm_config(config)
        super().__init__(llm_config, prompt_builder=prompt_builder)

        # Echo-specific configuration from options
        self.echo_prefix = llm_config.options.get('echo_prefix', 'Echo: ')
        self.embedding_dim = llm_config.options.get('embedding_dim', 768)
        self.mock_tokens = llm_config.options.get('mock_tokens', True)
        self.stream_delay = llm_config.options.get('stream_delay', 0.0)  # seconds per char

    def _generate_embedding(self, text: str) -> List[float]:
        """Generate a deterministic embedding vector from text.

        Uses SHA-256 to create a deterministic vector that:
        - Is always the same for the same input
        - Has values in the [-1, 1] range
        - Has configurable dimensionality

        Args:
            text: Input text

        Returns:
            Embedding vector of length ``self.embedding_dim``
        """
        # Seed with the hash of the text itself.
        current_hash = hashlib.sha256(text.encode('utf-8')).digest()

        embedding: List[float] = []
        while len(embedding) < self.embedding_dim:
            # Each 32-byte digest contributes up to 32 values; map each
            # byte (0-255) linearly onto [-1, 1].
            for byte in current_hash:
                if len(embedding) >= self.embedding_dim:
                    break
                embedding.append((byte / 127.5) - 1.0)

            # Rehash the previous digest to obtain the next batch of values.
            current_hash = hashlib.sha256(current_hash).digest()

        return embedding[:self.embedding_dim]

    def _count_tokens(self, text: str) -> int:
        """Mock token counting (simple character-based estimate).

        Args:
            text: Input text

        Returns:
            Estimated token count, never less than 1
        """
        # Rough approximation: 1 token ~= 4 characters
        return max(1, len(text) // 4)

    async def initialize(self) -> None:
        """Initialize echo provider (no-op besides marking it initialized)."""
        self._is_initialized = True

    async def close(self) -> None:
        """Close echo provider (no-op besides marking it uninitialized)."""
        self._is_initialized = False

    async def validate_model(self) -> bool:
        """Validate model (always true for echo)."""
        return True

    def get_capabilities(self) -> List[ModelCapability]:
        """Get echo provider capabilities."""
        return [
            ModelCapability.TEXT_GENERATION,
            ModelCapability.CHAT,
            ModelCapability.EMBEDDINGS,
            ModelCapability.FUNCTION_CALLING,
            ModelCapability.STREAMING,
            ModelCapability.JSON_MODE
        ]

    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs: Any
    ) -> LLMResponse:
        """Echo back the last user message.

        Args:
            messages: Input messages or a bare prompt string (treated as a
                single user message)
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional parameters (ignored)

        Returns:
            Echo response whose content is ``echo_prefix`` plus the last user
            message, or ``echo_prefix + "(no user message)"`` when there is none.
            ``usage`` is ``None`` when ``mock_tokens`` is disabled.
        """
        if not self._is_initialized:
            await self.initialize()

        # Get runtime config (with overrides applied if provided)
        runtime_config = self._get_runtime_config(config_overrides)

        # Convert to message list
        if isinstance(messages, str):
            messages = [LLMMessage(role='user', content=messages)]

        # Build echo response from last user message
        user_messages = [msg for msg in messages if msg.role == 'user']
        if user_messages:
            content = self.echo_prefix + user_messages[-1].content
        else:
            content = self.echo_prefix + "(no user message)"

        # Prepend the system prompt only when explicitly requested.
        if runtime_config.system_prompt and runtime_config.options.get('echo_system', False):
            content = f"[System: {runtime_config.system_prompt}]\n{content}"

        # Mock token usage
        prompt_tokens = sum(self._count_tokens(msg.content) for msg in messages)
        completion_tokens = self._count_tokens(content)

        return LLMResponse(
            content=content,
            model=runtime_config.model or 'echo-model',
            finish_reason='stop',
            usage={
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens,
                'total_tokens': prompt_tokens + completion_tokens
            } if self.mock_tokens else None
        )

    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        config_overrides: Dict[str, Any] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMStreamResponse]:
        """Stream the echo response character by character.

        Args:
            messages: Input messages or prompt
            config_overrides: Optional dict to override config fields (model,
                temperature, max_tokens, top_p, stop_sequences, seed)
            **kwargs: Additional parameters (forwarded to ``complete``)

        Yields:
            One chunk per character. Only the last chunk has ``is_final=True``,
            ``finish_reason='stop'`` and the response ``usage``. If the echoed
            content is empty, a single empty final chunk is yielded so
            consumers always observe stream completion.
        """
        import asyncio  # only needed for the optional per-character delay

        if not self._is_initialized:
            await self.initialize()

        # Get full response
        response = await self.complete(messages, config_overrides=config_overrides, **kwargs)
        content = response.content

        # Without this, an empty echo (e.g. empty prefix and empty user
        # message) would yield no chunks at all, and no final chunk.
        if not content:
            yield LLMStreamResponse(
                delta='',
                is_final=True,
                finish_reason='stop',
                usage=response.usage
            )
            return

        last_index = len(content) - 1
        for index, char in enumerate(content):
            is_final = index == last_index

            yield LLMStreamResponse(
                delta=char,
                is_final=is_final,
                finish_reason='stop' if is_final else None,
                usage=response.usage if is_final else None
            )

            # Optional delay for realistic streaming
            if self.stream_delay > 0:
                await asyncio.sleep(self.stream_delay)

    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> Union[List[float], List[List[float]]]:
        """Generate deterministic mock embeddings.

        Args:
            texts: A single text or a list of texts
            **kwargs: Additional parameters (ignored)

        Returns:
            One vector for a single string input, or a list of vectors (one
            per text) for a list input
        """
        if not self._is_initialized:
            await self.initialize()

        if isinstance(texts, str):
            return self._generate_embedding(texts)
        return [self._generate_embedding(text) for text in texts]

    @staticmethod
    def _mock_argument_value(param_name: str, param_type: Any, user_content: str) -> Any:
        """Return a deterministic mock value for one JSON-schema parameter.

        Args:
            param_name: Parameter name (used to label string mocks)
            param_type: The parameter schema's ``type`` entry
            user_content: Last user message text; seeds numeric/boolean values

        Returns:
            A mock value for ``param_type``, or ``None`` for unrecognized types
        """
        if param_type == 'string':
            return f"mock_{param_name}_from_echo"
        if param_type in ('number', 'integer'):
            # MD5 is only a deterministic seed here, not a security control;
            # usedforsecurity=False keeps it usable on FIPS-restricted builds.
            digest = hashlib.md5(user_content.encode(), usedforsecurity=False).hexdigest()
            return int(digest[:8], 16) % 100
        if param_type == 'boolean':
            digest = hashlib.md5(user_content.encode(), usedforsecurity=False).hexdigest()
            return int(digest[:2], 16) % 2 == 0
        if param_type == 'array':
            return ["mock_item_1", "mock_item_2"]
        if param_type == 'object':
            return {"mock_key": "mock_value"}
        return None

    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs: Any
    ) -> LLMResponse:
        """Mock function calling with a deterministic response.

        Always "calls" the first function in ``functions`` with mock arguments
        derived from its ``parameters.properties`` schema. With no functions,
        falls back to a plain echo via ``complete``.

        Args:
            messages: Conversation messages
            functions: Available functions (JSON-schema style definitions)
            **kwargs: Additional parameters (forwarded to ``complete`` only in
                the no-functions fallback)

        Returns:
            Response with ``finish_reason='function_call'`` and a
            ``function_call`` dict of ``name`` and ``arguments``
        """
        if not self._is_initialized:
            await self.initialize()

        if not functions:
            # No functions provided, just echo
            return await self.complete(messages, **kwargs)

        # Get last user message
        user_messages = [msg for msg in messages if msg.role == 'user']
        user_content = user_messages[-1].content if user_messages else ""

        # Mock function call: use first function with mock arguments
        first_func = functions[0]
        func_name = first_func.get('name', 'unknown_function')

        # Tolerate explicit None for 'parameters' / 'properties'.
        params = first_func.get('parameters') or {}
        properties = params.get('properties') or {}

        mock_args = {
            param_name: self._mock_argument_value(
                param_name, param_schema.get('type', 'string'), user_content
            )
            for param_name, param_schema in properties.items()
        }

        # Build response with function call
        content = f"{self.echo_prefix}Calling function '{func_name}'"

        prompt_tokens = sum(self._count_tokens(msg.content) for msg in messages)
        completion_tokens = self._count_tokens(content)

        return LLMResponse(
            content=content,
            model=self.config.model or 'echo-model',
            finish_reason='function_call',
            usage={
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens,
                'total_tokens': prompt_tokens + completion_tokens
            } if self.mock_tokens else None,
            function_call={
                'name': func_name,
                'arguments': mock_args
            }
        )