Coverage for src / dataknobs_llm / fsm_integration / resources.py: 0%
301 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:28 -0700
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-15 10:28 -0700
1"""LLM resource provider for language model interactions.
3Note: This module was migrated from dataknobs_fsm.resources.llm to
4consolidate all LLM functionality in the dataknobs-llm package.
5"""
7import json
8import os
9import time
10from dataclasses import dataclass, field as dataclass_field
11from typing import Any, Dict, List, Union
12from enum import Enum
14from dataknobs_fsm.functions.base import ResourceError
15from dataknobs_fsm.resources.base import (
16 BaseResourceProvider,
17 ResourceHealth,
18 ResourceStatus,
19)
class LLMProvider(Enum):
    """Enumeration of the LLM backends this module can route requests to.

    Each member's value is the lowercase string identifier of the backend.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    OLLAMA = "ollama"
    HUGGINGFACE = "huggingface"
    # Hosted HuggingFace Inference API, as opposed to HUGGINGFACE above
    HUGGINGFACE_INFERENCE = "huggingface_inference"
    CUSTOM = "custom"
@dataclass
class LLMSession:
    """Per-session LLM settings plus rate-limit and token accounting state."""

    # Backend identity and connection details
    provider: LLMProvider
    model_name: str
    api_key: str | None = None
    endpoint: str | None = None

    # Sampling parameters
    temperature: float = 0.7
    max_tokens: int = 1000
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Rate limiting over a 60-second window (mainly for commercial APIs)
    requests_per_minute: int = 60
    tokens_per_minute: int = 90000
    request_count: int = 0
    token_count: int = 0
    window_start: float = dataclass_field(default_factory=time.time)

    # Running totals, never reset by the rate-limit window
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_requests: int = 0

    # Free-form provider-specific settings
    provider_config: Dict[str, Any] = dataclass_field(default_factory=dict)

    def check_rate_limits(self, estimated_tokens: int = 0) -> bool:
        """Report whether a request may proceed under the rate limits.

        Args:
            estimated_tokens: Estimated tokens for the request.

        Returns:
            True if the request can proceed, False if it is rate limited.
        """
        # Ollama and local HuggingFace sessions are never throttled.
        if self.provider in (LLMProvider.OLLAMA, LLMProvider.HUGGINGFACE):
            return True

        now = time.time()

        # Once the window is at least 60 seconds old, start a fresh one and
        # let the request through without consulting the counters.
        if now - self.window_start >= 60:
            self.request_count = 0
            self.token_count = 0
            self.window_start = now
            return True

        requests_exhausted = self.request_count >= self.requests_per_minute
        tokens_exhausted = (
            self.token_count + estimated_tokens > self.tokens_per_minute
        )
        return not (requests_exhausted or tokens_exhausted)

    def record_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Add one request and its token counts to the session counters.

        Args:
            prompt_tokens: Number of prompt tokens used.
            completion_tokens: Number of completion tokens generated.
        """
        # Running totals
        self.total_requests += 1
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        # Current rate-limit window
        self.request_count += 1
        self.token_count += prompt_tokens + completion_tokens
110class LLMResource(BaseResourceProvider):
111 """LLM resource provider for language model operations.
113 Supports multiple providers:
114 - OpenAI: GPT models via OpenAI API
115 - Anthropic: Claude models via Anthropic API
116 - Ollama: Local models via Ollama
117 - HuggingFace: Local transformers or Inference API
118 """
120 def __init__(
121 self,
122 name: str,
123 provider: Union[str, LLMProvider] = "ollama",
124 model: str = "llama2",
125 api_key: str | None = None,
126 endpoint: str | None = None,
127 **config
128 ):
129 """Initialize LLM resource.
131 Args:
132 name: Resource name.
133 provider: LLM provider (ollama, openai, anthropic, huggingface, etc).
134 model: Model name/identifier.
135 api_key: API key for commercial providers.
136 endpoint: Custom endpoint URL.
137 **config: Additional configuration.
138 """
139 super().__init__(name, config)
141 # Convert string to enum
142 if isinstance(provider, str):
143 try:
144 self.provider = LLMProvider(provider.lower())
145 except ValueError:
146 self.provider = LLMProvider.CUSTOM
147 else:
148 self.provider = provider
150 self.model = model
151 self.api_key = api_key
152 self.endpoint = endpoint or self._get_default_endpoint()
154 # Initialize provider-specific clients
155 self._client = None
156 self._initialize_client()
158 self._sessions = {}
159 self.status = ResourceStatus.IDLE
161 def _get_default_endpoint(self) -> str | None:
162 """Get default endpoint for provider.
164 Returns:
165 Default endpoint URL or None.
166 """
167 defaults = {
168 LLMProvider.OPENAI: "https://api.openai.com/v1",
169 LLMProvider.ANTHROPIC: "https://api.anthropic.com/v1",
170 LLMProvider.OLLAMA: "http://localhost:11434",
171 LLMProvider.HUGGINGFACE_INFERENCE: "https://api-inference.huggingface.co/models",
172 }
173 return defaults.get(self.provider)
175 def _initialize_client(self) -> None:
176 """Initialize provider-specific client."""
177 try:
178 if self.provider == LLMProvider.OLLAMA:
179 # Ollama uses HTTP API, no special client needed
180 # Just verify endpoint is accessible
181 import urllib.request
182 try:
183 req = urllib.request.Request(f"{self.endpoint}/api/tags")
184 with urllib.request.urlopen(req, timeout=5) as response:
185 if response.status == 200:
186 self.status = ResourceStatus.IDLE
187 except Exception:
188 # Ollama might not be running yet, that's ok
189 self.status = ResourceStatus.IDLE
191 elif self.provider == LLMProvider.HUGGINGFACE:
192 # For local HuggingFace transformers
193 # We'll lazy-load the model when needed
194 self.status = ResourceStatus.IDLE
196 elif self.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]:
197 # Commercial APIs - just verify we have API key
198 if not self.api_key:
199 raise ResourceError(
200 f"{self.provider.value} requires an API key",
201 resource_name=self.name,
202 operation="initialize"
203 )
204 self.status = ResourceStatus.IDLE
206 else:
207 self.status = ResourceStatus.IDLE
209 except Exception as e:
210 self.status = ResourceStatus.ERROR
211 raise ResourceError(
212 f"Failed to initialize {self.provider.value} client: {e}",
213 resource_name=self.name,
214 operation="initialize"
215 ) from e
217 def acquire(self, **kwargs) -> LLMSession:
218 """Acquire an LLM session.
220 Args:
221 **kwargs: Session configuration overrides.
223 Returns:
224 LLMSession instance.
226 Raises:
227 ResourceError: If acquisition fails.
228 """
229 try:
230 # Set provider-specific defaults
231 if self.provider == LLMProvider.OLLAMA:
232 # Ollama defaults
233 kwargs.setdefault("temperature", 0.8)
234 kwargs.setdefault("requests_per_minute", 0) # No limit
235 kwargs.setdefault("tokens_per_minute", 0) # No limit
237 elif self.provider == LLMProvider.HUGGINGFACE:
238 # HuggingFace local defaults
239 kwargs.setdefault("device", "cpu") # or "cuda" if available
240 kwargs.setdefault("requests_per_minute", 0) # No limit
242 session = LLMSession(
243 provider=self.provider,
244 model_name=kwargs.get("model", self.model),
245 api_key=kwargs.get("api_key", self.api_key),
246 endpoint=kwargs.get("endpoint", self.endpoint),
247 temperature=kwargs.get("temperature", 0.7),
248 max_tokens=kwargs.get("max_tokens", 1000),
249 top_p=kwargs.get("top_p", 1.0),
250 frequency_penalty=kwargs.get("frequency_penalty", 0.0),
251 presence_penalty=kwargs.get("presence_penalty", 0.0),
252 requests_per_minute=kwargs.get("requests_per_minute", 60),
253 tokens_per_minute=kwargs.get("tokens_per_minute", 90000),
254 provider_config=kwargs.get("provider_config", {})
255 )
257 session_id = id(session)
258 self._sessions[session_id] = session
259 self._resources.append(session)
261 self.status = ResourceStatus.ACTIVE
262 return session
264 except Exception as e:
265 self.status = ResourceStatus.ERROR
266 raise ResourceError(
267 f"Failed to acquire LLM session: {e}",
268 resource_name=self.name,
269 operation="acquire"
270 ) from e
272 def release(self, resource: Any) -> None:
273 """Release an LLM session.
275 Args:
276 resource: The LLMSession to release.
277 """
278 if isinstance(resource, LLMSession):
279 session_id = id(resource)
280 if session_id in self._sessions:
281 del self._sessions[session_id]
283 if resource in self._resources:
284 self._resources.remove(resource)
286 if not self._resources:
287 self.status = ResourceStatus.IDLE
289 def validate(self, resource: Any) -> bool:
290 """Validate an LLM session.
292 Args:
293 resource: The LLMSession to validate.
295 Returns:
296 True if the session is valid.
297 """
298 if not isinstance(resource, LLMSession):
299 return False
301 # Check if API key is set for commercial providers
302 if resource.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC,
303 LLMProvider.HUGGINGFACE_INFERENCE]:
304 if not resource.api_key:
305 return False
307 return True
309 def health_check(self) -> ResourceHealth:
310 """Check LLM service health.
312 Returns:
313 Health status.
314 """
315 session = None
316 try:
317 session = self.acquire()
319 if session.provider == LLMProvider.OLLAMA:
320 # Check Ollama API
321 import urllib.request
322 req = urllib.request.Request(f"{session.endpoint}/api/tags")
323 with urllib.request.urlopen(req, timeout=5) as response:
324 if response.status == 200:
325 self.metrics.record_health_check(True)
326 return ResourceHealth.HEALTHY
328 elif session.provider == LLMProvider.HUGGINGFACE:
329 # For local HF, just check if transformers is available
330 try:
331 import importlib.util
332 if importlib.util.find_spec('transformers'):
333 self.metrics.record_health_check(True)
334 return ResourceHealth.HEALTHY
335 else:
336 self.metrics.record_health_check(False)
337 return ResourceHealth.UNHEALTHY
338 except ImportError:
339 self.metrics.record_health_check(False)
340 return ResourceHealth.UNHEALTHY
342 else:
343 # For commercial APIs, assume healthy if session is valid
344 if self.validate(session):
345 self.metrics.record_health_check(True)
346 return ResourceHealth.HEALTHY
348 except Exception:
349 self.metrics.record_health_check(False)
350 return ResourceHealth.UNHEALTHY
351 finally:
352 if session:
353 self.release(session)
355 return ResourceHealth.UNKNOWN
357 def complete(
358 self,
359 prompt: str,
360 session: LLMSession | None = None,
361 **kwargs
362 ) -> Dict[str, Any]:
363 """Generate a completion for the given prompt.
365 Args:
366 prompt: Input prompt.
367 session: Optional session to use.
368 **kwargs: Additional parameters.
370 Returns:
371 Completion response with text and metadata.
372 """
373 if session is None:
374 session = self.acquire()
375 should_release = True
376 else:
377 should_release = False
379 try:
380 # Route to appropriate provider
381 if session.provider == LLMProvider.OLLAMA:
382 response = self._ollama_complete(session, prompt, **kwargs)
383 elif session.provider == LLMProvider.HUGGINGFACE:
384 response = self._huggingface_complete(session, prompt, **kwargs)
385 elif session.provider == LLMProvider.OPENAI:
386 response = self._openai_complete(session, prompt, **kwargs)
387 elif session.provider == LLMProvider.ANTHROPIC:
388 response = self._anthropic_complete(session, prompt, **kwargs)
389 else:
390 response = self._custom_complete(session, prompt, **kwargs)
392 # Record usage if available
393 if "usage" in response:
394 prompt_tokens = response["usage"].get("prompt_tokens", 0)
395 completion_tokens = response["usage"].get("completion_tokens", 0)
396 session.record_usage(prompt_tokens, completion_tokens)
398 return response
400 finally:
401 if should_release:
402 self.release(session)
404 def _ollama_complete(
405 self,
406 session: LLMSession,
407 prompt: str,
408 **kwargs
409 ) -> Dict[str, Any]:
410 """Ollama completion.
412 Args:
413 session: LLM session.
414 prompt: Input prompt.
415 **kwargs: Additional parameters.
417 Returns:
418 Completion response.
419 """
420 import urllib.request
421 import urllib.parse
423 data = {
424 "model": session.model_name,
425 "prompt": prompt,
426 "temperature": kwargs.get("temperature", session.temperature),
427 "max_tokens": kwargs.get("max_tokens", session.max_tokens),
428 "stream": False
429 }
431 req = urllib.request.Request(
432 f"{session.endpoint}/api/generate",
433 data=json.dumps(data).encode("utf-8"),
434 headers={"Content-Type": "application/json"}
435 )
437 with urllib.request.urlopen(req) as response:
438 result = json.loads(response.read())
440 return {
441 "choices": [{
442 "text": result.get("response", ""),
443 "index": 0,
444 "finish_reason": "stop" if result.get("done") else "length"
445 }],
446 "model": session.model_name,
447 "usage": {
448 "prompt_tokens": result.get("prompt_eval_count", 0),
449 "completion_tokens": result.get("eval_count", 0),
450 "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
451 }
452 }
454 def _huggingface_complete(
455 self,
456 session: LLMSession,
457 prompt: str,
458 **kwargs
459 ) -> Dict[str, Any]:
460 """HuggingFace local completion.
462 Args:
463 session: LLM session.
464 prompt: Input prompt.
465 **kwargs: Additional parameters.
467 Returns:
468 Completion response.
469 """
470 # This would use transformers library for local inference
471 # Placeholder for now
472 try:
473 from transformers import pipeline
475 # Lazy load the model
476 pipe = pipeline(
477 "text-generation",
478 model=session.model_name,
479 device=session.provider_config.get("device", "cpu")
480 )
482 result = pipe(
483 prompt,
484 max_length=kwargs.get("max_tokens", session.max_tokens),
485 temperature=kwargs.get("temperature", session.temperature),
486 top_p=kwargs.get("top_p", session.top_p),
487 )
489 generated_text = result[0]["generated_text"]
490 # Remove the prompt from the output
491 if generated_text.startswith(prompt):
492 generated_text = generated_text[len(prompt):]
494 return {
495 "choices": [{
496 "text": generated_text,
497 "index": 0,
498 "finish_reason": "stop"
499 }],
500 "model": session.model_name
501 }
503 except ImportError as e:
504 raise ResourceError(
505 "HuggingFace transformers library not installed. "
506 "Install with: pip install transformers torch",
507 resource_name=self.name,
508 operation="complete"
509 ) from e
511 def _openai_complete(
512 self,
513 session: LLMSession,
514 prompt: str,
515 **kwargs
516 ) -> Dict[str, Any]:
517 """OpenAI completion using provider system."""
518 from dataknobs_llm.llm.base import LLMConfig, LLMMessage
519 from dataknobs_llm.llm.providers import create_llm_provider as create_provider
521 # Create config from session
522 config = LLMConfig(
523 provider="openai",
524 model=session.model_name,
525 api_key=kwargs.get('api_key', os.getenv('OPENAI_API_KEY')),
526 temperature=kwargs.get('temperature', 0.7),
527 max_tokens=kwargs.get('max_tokens', 1000)
528 )
530 try:
531 # Create provider and execute
532 provider = create_provider(config, is_async=False)
533 provider.initialize()
535 # Convert prompt to message format
536 if isinstance(prompt, str):
537 messages = [LLMMessage(role="user", content=prompt)]
538 else:
539 messages = prompt # type: ignore[unreachable]
541 response = provider.complete(messages, **kwargs)
542 provider.close()
544 # Convert to expected format
545 return {
546 "choices": [{
547 "text": response.content,
548 "index": 0,
549 "finish_reason": response.finish_reason or "stop"
550 }],
551 "model": response.model,
552 "usage": response.usage
553 }
554 except Exception as e:
555 # Fallback to placeholder on error
556 return {
557 "choices": [{
558 "text": f"Error: {e!s}",
559 "index": 0,
560 "finish_reason": "error"
561 }],
562 "model": session.model_name
563 }
565 def _anthropic_complete(
566 self,
567 session: LLMSession,
568 prompt: str,
569 **kwargs
570 ) -> Dict[str, Any]:
571 """Anthropic completion using provider system."""
572 from dataknobs_llm.llm.base import LLMConfig, LLMMessage
573 from dataknobs_llm.llm.providers import create_llm_provider as create_provider
575 # Create config from session
576 config = LLMConfig(
577 provider="anthropic",
578 model=session.model_name,
579 api_key=kwargs.get('api_key', os.getenv('ANTHROPIC_API_KEY')),
580 temperature=kwargs.get('temperature', 0.7),
581 max_tokens=kwargs.get('max_tokens', 1000)
582 )
584 try:
585 # Create provider and execute
586 provider = create_provider(config, is_async=False)
587 provider.initialize()
589 # Convert prompt to message format
590 if isinstance(prompt, str):
591 messages = [LLMMessage(role="user", content=prompt)]
592 else:
593 messages = prompt # type: ignore[unreachable]
595 response = provider.complete(messages, **kwargs)
596 provider.close()
598 # Convert to expected format
599 return {
600 "choices": [{
601 "text": response.content,
602 "index": 0,
603 "finish_reason": response.finish_reason or "stop"
604 }],
605 "model": response.model,
606 "usage": response.usage
607 }
608 except Exception as e:
609 # Fallback to placeholder on error
610 return {
611 "choices": [{
612 "text": f"Error: {e!s}",
613 "index": 0,
614 "finish_reason": "error"
615 }],
616 "model": session.model_name
617 }
619 def _custom_complete(
620 self,
621 session: LLMSession,
622 prompt: str,
623 **kwargs
624 ) -> Dict[str, Any]:
625 """Custom provider completion.
627 For custom/unknown providers.
628 """
629 raise NotImplementedError(
630 f"Custom provider {session.provider.value} not implemented"
631 )
633 def embed(
634 self,
635 text: Union[str, List[str]],
636 session: LLMSession | None = None,
637 **kwargs
638 ) -> List[List[float]]:
639 """Generate embeddings for text.
641 Args:
642 text: Text or list of texts to embed.
643 session: Optional session to use.
644 **kwargs: Additional parameters.
646 Returns:
647 List of embedding vectors.
648 """
649 if session is None:
650 session = self.acquire()
651 should_release = True
652 else:
653 should_release = False
655 try:
656 if isinstance(text, str):
657 texts = [text]
658 else:
659 texts = text
661 # Route to appropriate provider
662 if session.provider == LLMProvider.OLLAMA:
663 embeddings = self._ollama_embed(session, texts, **kwargs)
664 elif session.provider == LLMProvider.HUGGINGFACE:
665 embeddings = self._huggingface_embed(session, texts, **kwargs)
666 elif session.provider == LLMProvider.OPENAI:
667 embeddings = self._openai_embed(session, texts, **kwargs)
668 else:
669 # Fallback to fake embeddings
670 embeddings = [[0.1] * 768 for _ in texts]
672 return embeddings
674 finally:
675 if should_release:
676 self.release(session)
678 def _ollama_embed(
679 self,
680 session: LLMSession,
681 texts: List[str],
682 **kwargs
683 ) -> List[List[float]]:
684 """Generate embeddings using Ollama.
686 Args:
687 session: LLM session.
688 texts: Texts to embed.
689 **kwargs: Additional parameters.
691 Returns:
692 List of embeddings.
693 """
694 import urllib.request
696 embeddings = []
697 for text in texts:
698 data = {
699 "model": kwargs.get("embed_model", "nomic-embed-text"),
700 "prompt": text
701 }
703 req = urllib.request.Request(
704 f"{session.endpoint}/api/embeddings",
705 data=json.dumps(data).encode("utf-8"),
706 headers={"Content-Type": "application/json"}
707 )
709 with urllib.request.urlopen(req) as response:
710 result = json.loads(response.read())
711 embeddings.append(result.get("embedding", []))
713 return embeddings
715 def _huggingface_embed(
716 self,
717 session: LLMSession,
718 texts: List[str],
719 **kwargs
720 ) -> List[List[float]]:
721 """Generate embeddings using HuggingFace.
723 Args:
724 session: LLM session.
725 texts: Texts to embed.
726 **kwargs: Additional parameters.
728 Returns:
729 List of embeddings.
730 """
731 try:
732 from transformers import AutoTokenizer, AutoModel
733 import torch
735 model_name = kwargs.get("embed_model", "sentence-transformers/all-MiniLM-L6-v2")
736 tokenizer = AutoTokenizer.from_pretrained(model_name)
737 model = AutoModel.from_pretrained(model_name)
739 embeddings = []
740 for text in texts:
741 inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
742 with torch.no_grad():
743 outputs = model(**inputs)
744 # Use mean pooling
745 embedding = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
746 embeddings.append(embedding)
748 return embeddings
750 except ImportError as e:
751 raise ResourceError(
752 "HuggingFace transformers library not installed",
753 resource_name=self.name,
754 operation="embed"
755 ) from e
757 def _openai_embed(
758 self,
759 session: LLMSession,
760 texts: List[str],
761 **kwargs
762 ) -> List[List[float]]:
763 """Generate embeddings using OpenAI provider system."""
764 from dataknobs_fsm.llm.base import LLMConfig
765 from dataknobs_llm.llm.providers import create_llm_provider as create_provider
767 # Create config for embeddings
768 config = LLMConfig(
769 provider="openai",
770 model=kwargs.get('embed_model', 'text-embedding-ada-002'),
771 api_key=kwargs.get('api_key', os.getenv('OPENAI_API_KEY'))
772 )
774 try:
775 # Create provider and generate embeddings
776 provider = create_provider(config, is_async=False)
777 provider.initialize()
779 embeddings = provider.embed(texts, **kwargs)
780 provider.close()
782 # Ensure we return List[List[float]]
783 if isinstance(embeddings[0], list):
784 return embeddings
785 else:
786 return [embeddings] # Single text case
788 except Exception:
789 # Fallback to placeholder dimensions on error
790 return [[0.1] * 1536 for _ in texts] # OpenAI ada-002 dimension
792 def get_usage_stats(self, session: LLMSession) -> Dict[str, Any]:
793 """Get usage statistics for a session.
795 Args:
796 session: LLM session.
798 Returns:
799 Usage statistics.
800 """
801 stats = {
802 "provider": session.provider.value,
803 "model": session.model_name,
804 "total_requests": session.total_requests,
805 }
807 # Add token stats for providers that track them
808 if session.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC,
809 LLMProvider.OLLAMA]:
810 stats.update({
811 "total_prompt_tokens": session.total_prompt_tokens,
812 "total_completion_tokens": session.total_completion_tokens,
813 "total_tokens": session.total_prompt_tokens + session.total_completion_tokens,
814 })
816 # Add rate limit info for commercial providers
817 if session.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]:
818 stats["rate_limits"] = {
819 "requests_per_minute": session.requests_per_minute,
820 "tokens_per_minute": session.tokens_per_minute,
821 "current_window": {
822 "requests": session.request_count,
823 "tokens": session.token_count,
824 "window_start": session.window_start
825 }
826 }
828 return stats