Coverage for src/dataknobs_llm/conversations/flow/conditions.py: 81%
116 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
1"""Transition condition implementations.
3This module provides concrete implementations of TransitionCondition for
4common conversation flow patterns.
5"""
7import re
8from typing import Dict, Any, List, Callable
9from dataclasses import dataclass
11from .flow import TransitionCondition
@dataclass
class AlwaysCondition(TransitionCondition):
    """Transition condition that is satisfied unconditionally.

    Attach it to fallback or unconditional transitions so the flow always
    proceeds along them, whatever the response or context contains.

    Attributes:
        name: Identifier for the condition; defaults to ``"always"``.
    """

    name: str = "always"

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Report a match without inspecting ``response`` or ``context``."""
        satisfied = True
        return satisfied

    def to_fsm_function(self) -> str:
        """Build the FSM registration name: ``"always_"`` plus ``id(self)``."""
        return "always_" + str(id(self))
@dataclass
class KeywordCondition(TransitionCondition):
    """Condition based on keyword matching in the response.

    Evaluates to True if any of the configured keywords occurs in the
    response. Matching is case-insensitive unless ``case_sensitive`` is set
    (both the response and the keywords are lowercased).

    Attributes:
        keywords: Keywords to look for. Empty strings are ignored, since an
            empty keyword would otherwise match every response.
        case_sensitive: Whether matching should be case-sensitive.
        match_whole_word: If True, a keyword only matches when it is not
            immediately preceded or followed by a word character.
    """

    keywords: List[str]
    case_sensitive: bool = False
    match_whole_word: bool = False

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Return True if any non-empty keyword is found in ``response``.

        Args:
            response: Text to search.
            context: Conversation context (not used by this condition).

        Returns:
            True on the first matching keyword, otherwise False.
        """
        text = response if self.case_sensitive else response.lower()

        for keyword in self.keywords:
            # An empty keyword is a substring of every text; skip it rather
            # than letting it make the condition always true.
            if not keyword:
                continue
            kw = keyword if self.case_sensitive else keyword.lower()

            if self.match_whole_word:
                # Lookarounds instead of \b: for keywords that start/end with
                # a word character this is identical to \b, but it also lets
                # keywords with punctuation at the edges (e.g. "c++") match
                # as whole words.
                pattern = r'(?<!\w)' + re.escape(kw) + r'(?!\w)'
                if re.search(pattern, text):
                    return True
            elif kw in text:
                return True

        return False

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"keyword_{id(self)}"
@dataclass
class RegexCondition(TransitionCondition):
    """Condition satisfied when a regular expression is found in the response.

    The pattern is compiled once in ``__post_init__`` and then applied with
    ``search``, so a match anywhere in the response counts.

    Attributes:
        pattern: Regular expression source text.
        flags: ``re`` flags passed to ``re.compile`` (e.g. ``re.IGNORECASE``).
    """

    pattern: str
    flags: int = 0

    def __post_init__(self):
        """Compile ``pattern`` with ``flags`` and cache the result."""
        compiled = re.compile(self.pattern, self.flags)
        self._compiled_pattern = compiled

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Return whether the cached pattern occurs anywhere in ``response``."""
        found = self._compiled_pattern.search(response)
        return found is not None

    def to_fsm_function(self) -> str:
        """Build the FSM registration name: ``"regex_"`` plus ``id(self)``."""
        return "regex_%d" % id(self)
@dataclass
class LLMClassifierCondition(TransitionCondition):
    """Condition based on LLM classification of the response.

    Sends a prompt built from ``classifier_prompt`` to an LLM and compares
    the (stripped, lowercased) reply with ``expected_value``.

    Attributes:
        classifier_prompt: Prompt template; every ``{{response}}``
            placeholder is replaced with the response text.
        expected_value: Classification that makes the condition true;
            compared after stripping whitespace and lowercasing.
        llm_config: Optional keyword arguments for ``LLMConfig``, used to
            create a provider when the context has no ``'_llm_provider'``.
    """

    classifier_prompt: str
    expected_value: str
    llm_config: Dict[str, Any] | None = None

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Use an LLM to classify ``response`` and check the expected value.

        Args:
            response: Text to classify.
            context: Conversation context; an ``'_llm_provider'`` entry is
                used as the LLM if present.

        Returns:
            True if the LLM's reply equals ``expected_value`` (both stripped
            and case-insensitive).

        Raises:
            ValueError: If no provider is in the context and no
                ``llm_config`` is set.
        """
        # Import here to avoid circular dependencies
        from dataknobs_llm.llm import create_llm_provider, LLMConfig

        # Prefer a provider supplied via the context; otherwise build one.
        llm = context.get('_llm_provider')
        if llm is None and self.llm_config:
            config = LLMConfig(**self.llm_config)
            llm = create_llm_provider(config)

        if llm is None:
            raise ValueError(
                "LLMClassifierCondition requires an LLM provider in context "
                "or llm_config parameter"
            )

        # Format the classifier prompt with the response
        prompt = self.classifier_prompt.replace("{{response}}", response)

        # Get classification from LLM (assumes the completion result exposes
        # a string ``content`` attribute — per the provider API)
        result = await llm.complete(prompt)
        classification = result.content.strip().lower()

        # Normalize the expected value exactly like the LLM output so stray
        # whitespace in configuration cannot prevent a match.
        return classification == self.expected_value.strip().lower()

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"llm_classifier_{id(self)}"
@dataclass
class ContextCondition(TransitionCondition):
    """Condition decided purely by the conversation context.

    The response text is ignored; the wrapped predicate receives the
    context dict and its return value is the result.

    Attributes:
        condition_func: Predicate called with the context dict.
    """

    condition_func: Callable[[Dict[str, Any]], bool]

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Delegate to ``condition_func``; ``response`` is not consulted."""
        predicate = self.condition_func
        return predicate(context)

    def to_fsm_function(self) -> str:
        """Build the FSM registration name: ``"context_"`` plus ``id(self)``."""
        return "context_" + str(id(self))
@dataclass
class CompositeCondition(TransitionCondition):
    """Condition that combines multiple conditions with AND/OR logic.

    Attributes:
        conditions: Conditions to evaluate, in order.
        operator: ``'and'`` or ``'or'``; any other value raises ValueError
            at construction time.
    """

    conditions: List[TransitionCondition]
    operator: str = "and"  # 'and' or 'or'

    def __post_init__(self):
        """Validate operator."""
        if self.operator not in ("and", "or"):
            raise ValueError("operator must be 'and' or 'or'")

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Evaluate child conditions in order, short-circuiting.

        With ``'and'`` evaluation stops at the first false child; with
        ``'or'`` it stops at the first true child. An empty ``conditions``
        list yields True for ``'and'`` and False for ``'or'`` (the same as
        ``all([])`` / ``any([])``).
        """
        # Short-circuit so potentially expensive children (e.g. LLM-backed
        # conditions) are not awaited once the outcome is already decided.
        if self.operator == "and":
            for cond in self.conditions:
                if not await cond.evaluate(response, context):
                    return False
            return True

        # operator == "or" (validated in __post_init__)
        for cond in self.conditions:
            if await cond.evaluate(response, context):
                return True
        return False

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"composite_{self.operator}_{id(self)}"
@dataclass
class SentimentCondition(TransitionCondition):
    """Condition based on sentiment analysis.

    Evaluates to True if the response sentiment matches the expected
    sentiment with at least ``threshold`` confidence. Sentiment comes from a
    simple keyword lexicon (can be replaced with an ML model).

    Attributes:
        expected_sentiment: Expected sentiment ('positive', 'negative', 'neutral')
        threshold: Minimum confidence required for a match (0.0 to 1.0)
    """

    expected_sentiment: str
    threshold: float = 0.5

    # Lexicons for the keyword scorer. Left unannotated so dataclass treats
    # them as plain class attributes, not fields; built once, not per call.
    _POSITIVE_WORDS = frozenset(
        {'happy', 'good', 'great', 'excellent', 'yes', 'sure', 'love', 'like'}
    )
    _NEGATIVE_WORDS = frozenset(
        {'sad', 'bad', 'terrible', 'no', 'hate', 'dislike', 'poor'}
    )

    def __post_init__(self):
        """Validate sentiment value."""
        valid_sentiments = ('positive', 'negative', 'neutral')
        if self.expected_sentiment not in valid_sentiments:
            raise ValueError(
                f"expected_sentiment must be one of {valid_sentiments}"
            )

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Analyze sentiment and check against expected value.

        Each lexicon word present in the response counts once. No hits gives
        'neutral' with confidence 1.0; a tie gives 'neutral' with confidence
        0.5; otherwise the majority side wins with confidence equal to its
        share of the hits.
        """
        # Tokenize on word characters so lexicon entries only match whole
        # words: "no" must not match inside "know"/"not", nor "like" inside
        # "dislike", nor "yes" inside "eyes".
        tokens = set(re.findall(r"\w+", response.lower()))

        positive_count = len(self._POSITIVE_WORDS & tokens)
        negative_count = len(self._NEGATIVE_WORDS & tokens)

        total = positive_count + negative_count
        if total == 0:
            sentiment = 'neutral'
            confidence = 1.0
        elif positive_count > negative_count:
            sentiment = 'positive'
            confidence = positive_count / total
        elif negative_count > positive_count:
            sentiment = 'negative'
            confidence = negative_count / total
        else:
            sentiment = 'neutral'
            confidence = 0.5

        return sentiment == self.expected_sentiment and confidence >= self.threshold

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"sentiment_{self.expected_sentiment}_{id(self)}"
# Factory functions for common conditions

def keyword_condition(keywords: List[str], **kwargs) -> KeywordCondition:
    """Build a KeywordCondition that looks for ``keywords``.

    Args:
        keywords: Keywords that should trigger the condition.
        **kwargs: Forwarded unchanged to KeywordCondition (for example
            ``case_sensitive`` or ``match_whole_word``).

    Returns:
        The newly constructed KeywordCondition.
    """
    condition = KeywordCondition(keywords=keywords, **kwargs)
    return condition
def regex_condition(pattern: str, **kwargs) -> RegexCondition:
    """Build a RegexCondition for ``pattern``.

    Args:
        pattern: Regular expression source text.
        **kwargs: Forwarded unchanged to RegexCondition (e.g. ``flags``).

    Returns:
        The newly constructed RegexCondition.
    """
    condition = RegexCondition(pattern=pattern, **kwargs)
    return condition
def always() -> AlwaysCondition:
    """Build a condition that matches every response.

    Returns:
        A new AlwaysCondition with its default name.
    """
    condition = AlwaysCondition()
    return condition
def context_condition(func: Callable[[Dict[str, Any]], bool]) -> ContextCondition:
    """Wrap a context predicate in a ContextCondition.

    Args:
        func: Predicate that receives the context dict and returns a bool.

    Returns:
        A ContextCondition whose ``condition_func`` is ``func``.
    """
    wrapped = ContextCondition(condition_func=func)
    return wrapped