Coverage for src / dataknobs_llm / conversations / flow / flow.py: 44%

68 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:28 -0700

1"""Conversation flow definitions. 

2 

3This module provides high-level abstractions for defining conversation flows 

4that are executed using the FSM engine. 

5""" 

6 

7from typing import Dict, List, Any, Callable 

8from dataclasses import dataclass, field 

9from abc import ABC, abstractmethod 

10 

11 

class TransitionCondition(ABC):
    """Abstract contract for deciding whether a flow transition fires.

    Implementations look at the LLM response together with the conversation
    context and report whether their transition should be taken.
    """

    @abstractmethod
    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Decide whether this transition applies.

        Args:
            response: Text of the LLM response.
            context: Current conversation context.

        Returns:
            True when the transition should be taken.
        """

    @abstractmethod
    def to_fsm_function(self) -> str:
        """Give the name under which this condition is registered.

        Returns:
            Function name to register in the FSM function registry.
        """

44 

45 

@dataclass
class FlowState:
    """One node of a conversation flow.

    A state marks a point in the conversation where the system produces a
    response from a prompt stored in the prompt library.

    Attributes:
        prompt_name: Name of the prompt in the prompt library.
        transitions: Condition name -> name of the next state.
        transition_conditions: Condition name -> TransitionCondition object.
        max_loops: How many times this state may be revisited
            (None = unlimited).
        prompt_params: Static parameters handed to the prompt.
        on_enter: Hook called when the state is entered.
        on_exit: Hook called when the state is exited.
        metadata: Additional metadata for this state.
    """

    prompt_name: str
    transitions: Dict[str, str] = field(default_factory=dict)
    transition_conditions: Dict[str, TransitionCondition] = field(default_factory=dict)

    # Loop detection
    max_loops: int | None = None

    # Prompt parameters
    prompt_params: Dict[str, Any] = field(default_factory=dict)

    # Hooks
    on_enter: Callable | None = None
    on_exit: Callable | None = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Check the state configuration right after construction.

        Raises:
            ValueError: If ``prompt_name`` is empty, or if a key of
                ``transitions`` has no entry in ``transition_conditions``.
        """
        if not self.prompt_name:
            raise ValueError("prompt_name is required")

        # Every named transition needs a condition object to decide it.
        for name in self.transitions:
            if name in self.transition_conditions:
                continue
            raise ValueError(
                f"Transition '{name}' has no corresponding condition"
            )

92 

93 

@dataclass
class ConversationFlow:
    """Complete conversation flow definition.

    A conversation flow defines the states and transitions that make up
    a conversation, managed by the FSM engine.

    The transition target ``"end"`` is always accepted, even when no state
    of that name is defined. If ``states`` does contain a state named
    ``"end"``, it is treated like any other state.

    Attributes:
        name: Unique name for this flow
        initial_state: Name of the starting state
        states: Map of state names to FlowState objects
        max_total_loops: Maximum total transitions (prevents infinite loops)
        timeout_seconds: Maximum execution time (None = no timeout)
        initial_context: Initial context variables
        description: Human-readable description
        version: Semantic version string
        metadata: Additional metadata
    """

    name: str
    initial_state: str
    states: Dict[str, FlowState]

    # Global settings
    max_total_loops: int = 10
    timeout_seconds: float | None = None

    # Context
    initial_context: Dict[str, Any] = field(default_factory=dict)

    # Metadata
    description: str | None = None
    version: str = "1.0.0"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate flow configuration.

        Raises:
            ValueError: If ``name`` or ``initial_state`` is empty, if
                ``states`` is empty, if ``initial_state`` is not a key of
                ``states``, or if any transition targets a state that is
                neither defined nor the ``"end"`` target.
        """
        if not self.name:
            raise ValueError("Flow name is required")

        if not self.initial_state:
            raise ValueError("initial_state is required")

        if not self.states:
            raise ValueError("Flow must have at least one state")

        if self.initial_state not in self.states:
            raise ValueError(
                f"initial_state '{self.initial_state}' not found in states"
            )

        # Validate all transition targets exist
        for state_name, state in self.states.items():
            for target_state in state.transitions.values():
                if target_state not in self.states and target_state != "end":
                    raise ValueError(
                        f"State '{state_name}' transitions to unknown state '{target_state}'"
                    )

    def get_state(self, state_name: str) -> FlowState:
        """Get a state by name.

        Args:
            state_name: Name of the state

        Returns:
            FlowState object

        Raises:
            KeyError: If state not found
        """
        if state_name not in self.states:
            raise KeyError(f"State '{state_name}' not found")
        return self.states[state_name]

    def get_reachable_states(self, from_state: str) -> List[str]:
        """Get the direct successors of a state.

        Only targets one transition away are returned; the result is not
        a transitive closure. It may contain the ``"end"`` target and may
        repeat a name when several conditions lead to the same state.

        Args:
            from_state: Starting state name

        Returns:
            List of target state names, in transition order

        Raises:
            KeyError: If ``from_state`` is not a defined state
        """
        state = self.get_state(from_state)
        return list(state.transitions.values())

    def validate_flow(self) -> List[str]:
        """Validate the flow and return any warnings.

        Reports states that cannot be reached from ``initial_state``,
        states without exit transitions, and transitions whose target is
        neither a defined state nor ``"end"`` (only possible if the flow
        was modified after construction).

        Returns:
            List of warning messages (empty if no issues)

        Raises:
            KeyError: If ``initial_state`` is no longer a defined state
        """
        issues: List[str] = []

        # Walk the transition graph from the initial state.
        reachable = {self.initial_state}
        pending = [self.initial_state]

        while pending:
            current = pending.pop()
            for next_state in self.get_reachable_states(current):
                # Follow every target that is a defined state. A state that
                # happens to be named "end" must be traversed too; an "end"
                # target with no matching state simply terminates the walk.
                if next_state in self.states and next_state not in reachable:
                    reachable.add(next_state)
                    pending.append(next_state)

        issues.extend(
            f"State '{state_name}' is unreachable"
            for state_name in self.states
            if state_name not in reachable
        )

        # Check for states with no exit transitions
        issues.extend(
            f"State '{state_name}' has no exit transitions (potential dead end)"
            for state_name, state in self.states.items()
            if not state.transitions
        )

        # Targets that are not defined states were skipped during the walk;
        # report them here instead of failing mid-traversal.
        for state_name, state in self.states.items():
            for target_state in state.transitions.values():
                if target_state not in self.states and target_state != "end":
                    issues.append(
                        f"State '{state_name}' transitions to unknown state '{target_state}'"
                    )

        return issues