Coverage for src / invariant / graph.py: 98.63%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-20 16:05 +0000

1"""GraphResolver for parsing, validating, and sorting DAGs.""" 

2 

3from collections import deque 

4from typing import TYPE_CHECKING 

5 

6from invariant.node import Node 

7 

8if TYPE_CHECKING: 

9 from invariant.registry import OpRegistry 

10 

11 

class GraphResolver:
    """Parse graph definitions and ensure they form a valid DAG.

    Handles:
    - Cycle detection
    - Validation (missing dependencies, missing ops)
    - Topological sorting

    A graph is a mapping of node ID to ``Node``. Only two attributes of each
    node are read here: ``deps`` (an iterable of dependency keys) and
    ``op_name`` (looked up in the registry when one is configured).
    """

    def __init__(self, registry: "OpRegistry | None" = None) -> None:
        """Initialize GraphResolver.

        Args:
            registry: Optional OpRegistry for validating that ops exist.
                If None, op validation is skipped.
        """
        self.registry = registry

    def validate(
        self, graph: dict[str, Node], context_keys: set[str] | None = None
    ) -> None:
        """Validate a graph definition.

        Checks, in this order:
        - All node dependencies exist in the graph or in context
        - All referenced ops are registered (if registry provided)
        - No cycles exist

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).
                Dependencies not in the graph are allowed if they're in context.

        Raises:
            ValueError: If validation fails (missing deps, missing ops, cycles).
        """
        context_keys = context_keys or set()

        # Every dependency must resolve to a graph node or a context key.
        for node_id, node in graph.items():
            for dep in node.deps:
                if dep not in graph and dep not in context_keys:
                    raise ValueError(
                        f"Node '{node_id}' has dependency '{dep}' that doesn't exist in graph "
                        f"or context. Available: graph={sorted(graph)}, "
                        f"context={sorted(context_keys)}"
                    )

        # Compare against None explicitly: a registry object that happens to
        # be falsy (e.g. an empty container-like registry) must still be
        # consulted, otherwise unregistered ops would pass silently.
        if self.registry is not None:
            for node_id, node in graph.items():
                if not self.registry.has(node.op_name):
                    raise ValueError(
                        f"Node '{node_id}' references unregistered op '{node.op_name}'"
                    )

        # Check for cycles (context dependencies never take part in a cycle).
        if self._has_cycle(graph, context_keys=context_keys):
            raise ValueError("Graph contains cycles")

    def _has_cycle(
        self, graph: dict[str, Node], context_keys: set[str] | None = None
    ) -> bool:
        """Detect cycles in the graph using an iterative three-colour DFS.

        The traversal keeps its own stack instead of recursing, so long
        dependency chains cannot exhaust the interpreter's recursion limit.

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).
                Accepted for interface compatibility; any dependency that is
                not a key of ``graph`` is skipped, whether or not it is
                listed here.

        Returns:
            True if cycle exists, False otherwise.
        """
        WHITE = 0  # Unvisited
        GRAY = 1  # Currently on the DFS path
        BLACK = 2  # Fully processed

        color: dict[str, int] = dict.fromkeys(graph, WHITE)

        # Start a traversal from every unvisited node so that disconnected
        # components are covered as well.
        for root in graph:
            if color[root] != WHITE:
                continue

            color[root] = GRAY
            # Each frame pairs a node with an iterator over its remaining
            # dependencies, which lets the walk resume where it left off.
            stack = [(root, iter(graph[root].deps))]
            while stack:
                current, remaining = stack[-1]
                for dep in remaining:
                    if dep not in color:
                        # Not a graph node (context dependency): no edge.
                        continue
                    if color[dep] == GRAY:
                        # Back edge onto the active path: cycle detected.
                        return True
                    if color[dep] == WHITE:
                        color[dep] = GRAY
                        stack.append((dep, iter(graph[dep].deps)))
                        break
                else:
                    # All dependencies explored without finding a back edge.
                    color[current] = BLACK
                    stack.pop()

        return False

    def topological_sort(
        self, graph: dict[str, Node], context_keys: set[str] | None = None
    ) -> list[str]:
        """Topologically sort the graph using Kahn's algorithm.

        The result is deterministic: nodes that are ready at the same time
        are emitted in graph insertion order.

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).
                Accepted for interface compatibility; any dependency that is
                not a key of ``graph`` is excluded from sorting.

        Returns:
            List of node IDs in topological order (dependencies before dependents).

        Raises:
            ValueError: If graph contains cycles.
        """
        # dependents[x] lists the nodes that depend on x; in_degree[x] counts
        # the in-graph dependencies of x. Both are built in a single pass and
        # ignore dependencies that are not graph nodes.
        dependents: dict[str, list[str]] = {node_id: [] for node_id in graph}
        in_degree: dict[str, int] = dict.fromkeys(graph, 0)
        for node_id, node in graph.items():
            for dep in node.deps:
                if dep in dependents:
                    dependents[dep].append(node_id)
                    in_degree[node_id] += 1

        # Seed with nodes that have no in-graph dependencies, iterating the
        # dict (insertion order) rather than a set so the output is stable.
        queue = deque(
            node_id for node_id, degree in in_degree.items() if degree == 0
        )
        result: list[str] = []

        while queue:
            node_id = queue.popleft()
            result.append(node_id)

            # Releasing this node may make its dependents ready.
            for dependent in dependents[node_id]:
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        # Nodes on a cycle never reach in-degree 0 and are never emitted.
        if len(result) != len(graph):
            raise ValueError("Graph contains cycles (topological sort impossible)")

        return result

    def resolve(
        self, graph: dict[str, Node], context_keys: set[str] | None = None
    ) -> list[str]:
        """Validate and topologically sort a graph.

        Convenience method that validates then sorts.

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).

        Returns:
            List of node IDs in topological order.

        Raises:
            ValueError: If validation fails or cycles exist.
        """
        self.validate(graph, context_keys=context_keys)
        return self.topological_sort(graph, context_keys=context_keys)

198 return self.topological_sort(graph, context_keys=context_keys)