Coverage for src/dataknobs_llm/conversations/flow/flow.py: 94%
68 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-08 13:51 -0700
1"""Conversation flow definitions.
3This module provides high-level abstractions for defining conversation flows
4that are executed using the FSM engine.
5"""
7from typing import Dict, List, Any, Callable
8from dataclasses import dataclass, field
9from abc import ABC, abstractmethod
class TransitionCondition(ABC):
    """Abstract interface for conditions that gate a flow transition.

    A condition looks at the LLM response together with the conversation
    context and decides whether the transition it guards should be followed.
    """

    @abstractmethod
    async def evaluate(self, response: str, context: Dict[str, Any]) -> bool:
        """Decide whether the transition guarded by this condition fires.

        Args:
            response: The LLM response text
            context: Current conversation context

        Returns:
            True if this transition should be taken
        """
        ...

    @abstractmethod
    def to_fsm_function(self) -> str:
        """Name under which this condition is registered with the FSM.

        Returns:
            Function name that will be registered in FSM function registry
        """
        ...
@dataclass
class FlowState:
    """One step of a conversation flow.

    A state is a point in the conversation at which the system produces a
    response from a prompt held in the prompt library.

    Attributes:
        prompt_name: Name of the prompt in the prompt library
        transitions: Map of condition names to next state names
        transition_conditions: Map of condition names to TransitionCondition objects
        max_loops: Maximum times this state can be revisited (None = unlimited)
        prompt_params: Static parameters to pass to the prompt
        on_enter: Hook called when entering this state
        on_exit: Hook called when exiting this state
        metadata: Additional metadata for this state
    """

    prompt_name: str
    transitions: Dict[str, str] = field(default_factory=dict)
    transition_conditions: Dict[str, TransitionCondition] = field(default_factory=dict)

    # Loop detection
    max_loops: int | None = None

    # Prompt parameters
    prompt_params: Dict[str, Any] = field(default_factory=dict)

    # Hooks
    on_enter: Callable | None = None
    on_exit: Callable | None = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Check the state configuration right after construction.

        Raises:
            ValueError: If ``prompt_name`` is empty, or if a key of
                ``transitions`` has no entry in ``transition_conditions``.
        """
        if not self.prompt_name:
            raise ValueError("prompt_name is required")

        # Every named transition needs a condition object under the same
        # name; the first uncovered one (in dict order) is reported.
        conditions = self.transition_conditions
        uncovered = [name for name in self.transitions if name not in conditions]
        if uncovered:
            first_missing = uncovered[0]
            raise ValueError(
                f"Transition '{first_missing}' has no corresponding condition"
            )
@dataclass
class ConversationFlow:
    """Complete conversation flow definition.

    A conversation flow defines the states and transitions that make up
    a conversation, managed by the FSM engine.

    The target name ``"end"`` is accepted as a transition target even when no
    state of that name is defined.

    Attributes:
        name: Unique name for this flow
        initial_state: Name of the starting state
        states: Map of state names to FlowState objects
        max_total_loops: Maximum total transitions (prevents infinite loops)
        timeout_seconds: Maximum execution time (None = no timeout)
        initial_context: Initial context variables
        description: Human-readable description
        version: Semantic version string
        metadata: Additional metadata
    """

    name: str
    initial_state: str
    states: Dict[str, FlowState]

    # Global settings
    max_total_loops: int = 10
    timeout_seconds: float | None = None

    # Context
    initial_context: Dict[str, Any] = field(default_factory=dict)

    # Metadata
    description: str | None = None
    version: str = "1.0.0"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate flow configuration.

        Raises:
            ValueError: If ``name`` or ``initial_state`` is empty, if
                ``states`` is empty, if ``initial_state`` is not a key of
                ``states``, or if any transition targets a name that is
                neither a key of ``states`` nor ``"end"``.
        """
        if not self.name:
            raise ValueError("Flow name is required")

        if not self.initial_state:
            raise ValueError("initial_state is required")

        if not self.states:
            raise ValueError("Flow must have at least one state")

        if self.initial_state not in self.states:
            raise ValueError(
                f"initial_state '{self.initial_state}' not found in states"
            )

        # Validate all transition targets exist
        for state_name, state in self.states.items():
            for target_state in state.transitions.values():
                if target_state not in self.states and target_state != "end":
                    raise ValueError(
                        f"State '{state_name}' transitions to unknown state '{target_state}'"
                    )

    def get_state(self, state_name: str) -> FlowState:
        """Get a state by name.

        Args:
            state_name: Name of the state

        Returns:
            FlowState object

        Raises:
            KeyError: If state not found
        """
        if state_name not in self.states:
            raise KeyError(f"State '{state_name}' not found")
        return self.states[state_name]

    def get_reachable_states(self, from_state: str) -> List[str]:
        """Get the direct successors of a state.

        Only one hop is followed: the result is the list of transition
        targets of ``from_state``, in transition order. It may contain
        ``"end"`` and may contain the same name more than once when
        several conditions lead to the same target.

        Args:
            from_state: Starting state name

        Returns:
            List of directly reachable state names

        Raises:
            KeyError: If ``from_state`` is not a defined state
        """
        state = self.get_state(from_state)
        return list(state.transitions.values())

    def validate_flow(self) -> List[str]:
        """Validate the flow and return any warnings.

        Two checks are performed: states that cannot be reached from
        ``initial_state`` by following transitions, and states that have no
        outgoing transitions.

        Returns:
            List of warning messages (empty if no issues)
        """
        # Check for unreachable states
        reachable = {self.initial_state}
        to_visit = [self.initial_state]

        while to_visit:
            current = to_visit.pop()
            for next_state in self.get_reachable_states(current):
                # Follow only targets that are real states. This skips "end"
                # when it is not a defined state, still visits a state that
                # happens to be named "end", and avoids a KeyError on any
                # target that is missing from ``states``.
                if next_state in self.states and next_state not in reachable:
                    reachable.add(next_state)
                    to_visit.append(next_state)

        warnings = [
            f"State '{state_name}' is unreachable"
            for state_name in self.states
            if state_name not in reachable
        ]

        # Check for states with no exit transitions
        warnings.extend(
            f"State '{state_name}' has no exit transitions (potential dead end)"
            for state_name, state in self.states.items()
            if not state.transitions
        )

        return warnings