Coverage for src/dataknobs_llm/conversations/flow/conditions.py: 81%

116 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-08 13:51 -0700

1"""Transition condition implementations. 

2 

3This module provides concrete implementations of TransitionCondition for 

4common conversation flow patterns. 

5""" 

6 

7import re 

8from typing import Dict, Any, List, Callable 

9from dataclasses import dataclass 

10 

11from .flow import TransitionCondition 

12 

13 

@dataclass
class AlwaysCondition(TransitionCondition):
    """Transition condition whose evaluation is unconditionally true.

    Intended for unconditional edges in a flow, or as the last-resort
    fallback transition when no other condition applies.

    Attributes:
        name: Identifier for this condition (defaults to ``"always"``).
    """

    name: str = "always"

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Report a match regardless of ``response`` and ``context``.

        Returns:
            Always ``True``.
        """
        # Neither argument is inspected: this condition cannot fail.
        return True

    def to_fsm_function(self) -> str:
        """Build the FSM registration name, made unique per instance via ``id()``."""
        return "always_" + str(id(self))

30 

31 

@dataclass
class KeywordCondition(TransitionCondition):
    """Condition based on keyword matching in the response.

    Evaluates to True if any of the specified keywords is found in the
    response (case-insensitive by default).

    Attributes:
        keywords: Keywords to look for; empty strings are ignored.
        case_sensitive: Whether matching should be case-sensitive.
        match_whole_word: If True, a keyword only matches where it is not
            directly adjacent to a word character (letter, digit or
            underscore) on either side.
    """

    keywords: List[str]
    case_sensitive: bool = False
    match_whole_word: bool = False

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Check whether any keyword occurs in the response.

        Args:
            response: Text to search.
            context: Conversation context (not used by this condition).

        Returns:
            True as soon as one keyword matches, False if none do.
        """
        text = response if self.case_sensitive else response.lower()

        for keyword in self.keywords:
            kw = keyword if self.case_sensitive else keyword.lower()
            if not kw:
                # An empty keyword would trivially match every response.
                continue

            if self.match_whole_word:
                # Lookarounds instead of \b: \b is only a boundary next to a
                # word character, so keywords that start or end with
                # punctuation (e.g. "c++", "?") could never match with \b.
                pattern = r'(?<!\w)' + re.escape(kw) + r'(?!\w)'
                if re.search(pattern, text):
                    return True
            elif kw in text:
                return True

        return False

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"keyword_{id(self)}"

70 

71 

@dataclass
class RegexCondition(TransitionCondition):
    """Transition condition driven by a regular expression.

    The pattern is compiled once, right after dataclass initialization, and
    is searched for anywhere in the response (``re.search`` semantics, not an
    anchored or full match).

    Attributes:
        pattern: Regular expression source.
        flags: ``re`` flags used at compile time (e.g. ``re.IGNORECASE``).
    """

    pattern: str
    flags: int = 0

    def __post_init__(self):
        """Pre-compile ``pattern`` with ``flags``; an invalid pattern raises ``re.error`` here."""
        compiled = re.compile(self.pattern, self.flags)
        self._compiled_pattern = compiled

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Report whether the compiled pattern occurs somewhere in ``response``."""
        found = self._compiled_pattern.search(response)
        return found is not None

    def to_fsm_function(self) -> str:
        """Build the per-instance FSM registration name."""
        return "regex_%d" % id(self)

97 

98 

@dataclass
class LLMClassifierCondition(TransitionCondition):
    """Condition based on LLM classification of the response.

    Uses an LLM to classify the response into one of several categories.
    Evaluates to True if the classification matches the expected value.
    The comparison ignores case and surrounding whitespace on both sides.

    Attributes:
        classifier_prompt: Prompt template for classification; every literal
            ``{{response}}`` placeholder is replaced with the response text.
        expected_value: Expected classification result.
        llm_config: Optional LLM configuration, used only when the context
            does not supply a provider under ``'_llm_provider'``.
    """

    classifier_prompt: str
    expected_value: str
    llm_config: Dict[str, Any] | None = None

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Use LLM to classify and check against expected value.

        Args:
            response: Text to classify.
            context: Conversation context; may carry an LLM provider under
                the ``'_llm_provider'`` key.

        Returns:
            True if the normalized LLM output equals the normalized
            ``expected_value``.

        Raises:
            ValueError: If no provider is in the context and no (non-empty)
                ``llm_config`` is set.
        """
        # Import here to avoid circular dependencies
        from dataknobs_llm.llm import create_llm_provider, LLMConfig

        # Get LLM provider from context or create new one
        llm = context.get('_llm_provider')
        if llm is None and self.llm_config:
            config = LLMConfig(**self.llm_config)
            # NOTE(review): a new provider is built on every evaluation when
            # the context has none; consider caching it if this path is hot.
            llm = create_llm_provider(config)

        if llm is None:
            raise ValueError(
                "LLMClassifierCondition requires an LLM provider in context "
                "or llm_config parameter"
            )

        # Format the classifier prompt with the response
        prompt = self.classifier_prompt.replace("{{response}}", response)

        # Get classification from LLM
        result = await llm.complete(prompt)
        classification = result.content.strip().lower()

        # Normalize both sides the same way so stray whitespace in a
        # configured expected_value (e.g. a trailing newline from a config
        # file) cannot make a match impossible.
        return classification == self.expected_value.strip().lower()

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"llm_classifier_{id(self)}"

146 

147 

@dataclass
class ContextCondition(TransitionCondition):
    """Transition condition decided entirely by the conversation context.

    The response text is ignored. The supplied predicate receives the
    context dictionary, and its result is returned unchanged.

    Attributes:
        condition_func: Predicate invoked with the context dict.
    """

    condition_func: Callable[[Dict[str, Any]], bool]

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Delegate to ``condition_func`` with ``context``; ``response`` is unused."""
        predicate = self.condition_func
        outcome = predicate(context)
        return outcome

    def to_fsm_function(self) -> str:
        """Build the per-instance FSM registration name."""
        return "context_{}".format(id(self))

167 

168 

@dataclass
class CompositeCondition(TransitionCondition):
    """Condition that combines multiple conditions with AND/OR logic.

    Sub-conditions are evaluated in order and short-circuit like Python's
    ``and``/``or``. For ``'and'``, evaluation stops at the first False
    result; for ``'or'``, it stops at the first True result. With no
    sub-conditions, ``'and'`` yields True and ``'or'`` yields False.

    Attributes:
        conditions: List of conditions to evaluate
        operator: 'and' or 'or'
    """

    conditions: List[TransitionCondition]
    operator: str = "and"  # 'and' or 'or'

    def __post_init__(self):
        """Validate operator.

        Raises:
            ValueError: If ``operator`` is neither ``'and'`` nor ``'or'``.
        """
        if self.operator not in ("and", "or"):
            raise ValueError("operator must be 'and' or 'or'")

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Evaluate sub-conditions with the configured operator, short-circuiting.

        Stopping early avoids needless work once the outcome is known, for
        example further LLM-backed sub-conditions.

        Args:
            response: Response text forwarded to each sub-condition.
            context: Conversation context forwarded to each sub-condition.

        Returns:
            The combined boolean result.
        """
        if self.operator == "and":
            for cond in self.conditions:
                if not await cond.evaluate(response, context):
                    return False
            return True

        # 'or' (any non-'and' operator is treated as 'or', as before)
        for cond in self.conditions:
            if await cond.evaluate(response, context):
                return True
        return False

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"composite_{self.operator}_{id(self)}"

201 

202 

@dataclass
class SentimentCondition(TransitionCondition):
    """Condition based on sentiment analysis.

    Evaluates to True if the response sentiment matches the expected
    sentiment with at least ``threshold`` confidence. The built-in scorer is
    a simple whole-word lexicon count.

    Attributes:
        expected_sentiment: Expected sentiment ('positive', 'negative', 'neutral')
        threshold: Confidence threshold (0.0 to 1.0)
    """

    expected_sentiment: str
    threshold: float = 0.5

    # Lexicons for the built-in scorer. These are unannotated class
    # attributes, so the dataclass machinery does not treat them as fields.
    _POSITIVE_WORDS = frozenset(
        {'happy', 'good', 'great', 'excellent', 'yes', 'sure', 'love', 'like'}
    )
    _NEGATIVE_WORDS = frozenset(
        {'sad', 'bad', 'terrible', 'no', 'hate', 'dislike', 'poor'}
    )

    def __post_init__(self):
        """Validate sentiment value.

        Raises:
            ValueError: If ``expected_sentiment`` is not a supported label.
        """
        valid_sentiments = ('positive', 'negative', 'neutral')
        if self.expected_sentiment not in valid_sentiments:
            raise ValueError(
                f"expected_sentiment must be one of {valid_sentiments}"
            )

    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Analyze sentiment and check against expected value.

        Confidence is the share of matched lexicon words on the winning
        side. A response with no matches is neutral with confidence 1.0,
        and a tie is neutral with confidence 0.5.
        """
        # Simple keyword-based sentiment analysis (can be replaced with ML model).
        # Tokenize into whole words so lexicon entries only match complete
        # words. Plain substring tests counted "no" inside "know"/"not",
        # "sure" inside "unsure", "like" inside "unlikely", and so on.
        words = set(re.findall(r"\w+", response.lower()))

        # Each distinct lexicon word counts at most once.
        positive_count = len(words & self._POSITIVE_WORDS)
        negative_count = len(words & self._NEGATIVE_WORDS)

        total = positive_count + negative_count
        if total == 0:
            sentiment = 'neutral'
            confidence = 1.0
        elif positive_count > negative_count:
            sentiment = 'positive'
            confidence = positive_count / total
        elif negative_count > positive_count:
            sentiment = 'negative'
            confidence = negative_count / total
        else:
            sentiment = 'neutral'
            confidence = 0.5

        return sentiment == self.expected_sentiment and confidence >= self.threshold

    def to_fsm_function(self) -> str:
        """Return function name for FSM registration."""
        return f"sentiment_{self.expected_sentiment}_{id(self)}"

256 

257 

258# Factory functions for common conditions 

259 

def keyword_condition(keywords: List[str], **kwargs) -> KeywordCondition:
    """Build a KeywordCondition for the given keywords.

    Args:
        keywords: Keywords to look for in the response.
        **kwargs: Remaining KeywordCondition options, forwarded unchanged.

    Returns:
        A new KeywordCondition instance.
    """
    options = dict(kwargs)
    options["keywords"] = keywords
    return KeywordCondition(**options)

271 

272 

def regex_condition(pattern: str, **kwargs) -> RegexCondition:
    """Build a RegexCondition from a pattern string.

    Args:
        pattern: Regular expression source.
        **kwargs: Remaining RegexCondition options (such as ``flags``),
            forwarded unchanged.

    Returns:
        A new RegexCondition instance.
    """
    params = {"pattern": pattern, **kwargs}
    return RegexCondition(**params)

284 

285 

def always() -> AlwaysCondition:
    """Build a condition that matches unconditionally.

    Returns:
        A fresh AlwaysCondition with its default settings.
    """
    condition = AlwaysCondition()
    return condition

293 

294 

def context_condition(func: Callable[[Dict[str, Any]], bool]) -> ContextCondition:
    """Wrap a context predicate in a ContextCondition.

    Args:
        func: Predicate receiving the conversation context dict.

    Returns:
        A new ContextCondition delegating to ``func``.
    """
    wrapped = ContextCondition(condition_func=func)
    return wrapped