Coverage for src / invariant / graph.py: 98.68%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 10:21 +0100

1"""GraphResolver for parsing, validating, and sorting DAGs.""" 

2 

3from collections import deque 

4from typing import TYPE_CHECKING 

5 

6from invariant.node import Node, SubGraphNode 

7 

8if TYPE_CHECKING: 

9 from invariant.registry import OpRegistry 

10 

# Graph may contain regular nodes or subgraph nodes (internal DAGs).
# A single vertex of a graph definition: either a plain Node or a SubGraphNode.
GraphVertex = Node | SubGraphNode
# A graph definition: maps each node ID to its vertex.
Graph = dict[str, GraphVertex]

14 

15 

16class GraphResolver: 

17 """Responsible for parsing graph definitions and ensuring valid DAGs. 

18 

19 Handles: 

20 - Cycle detection 

21 - Validation (missing dependencies, missing ops) 

22 - Topological sorting 

23 """ 

24 

25 def __init__(self, registry: "OpRegistry | None" = None) -> None: 

26 """Initialize GraphResolver. 

27 

28 Args: 

29 registry: Optional OpRegistry for validating that ops exist. 

30 If None, op validation is skipped. 

31 """ 

32 self.registry = registry 

33 

34 def validate(self, graph: Graph, context_keys: set[str] | None = None) -> None: 

35 """Validate a graph definition. 

36 

37 Checks: 

38 - All node dependencies exist in the graph or in context 

39 - All referenced ops are registered (if registry provided; Node only) 

40 - No cycles exist 

41 

42 Args: 

43 graph: Dictionary mapping node IDs to Node or SubGraphNode objects. 

44 context_keys: Optional set of external dependency keys (from context). 

45 Dependencies not in the graph are allowed if they're in context. 

46 

47 Raises: 

48 ValueError: If validation fails (missing deps, missing ops, cycles). 

49 """ 

50 # Check all dependencies exist 

51 node_ids = set(graph.keys()) 

52 context_keys = context_keys or set() 

53 for node_id, node in graph.items(): 

54 for dep in node.deps: 

55 if dep not in node_ids and dep not in context_keys: 

56 raise ValueError( 

57 f"Node '{node_id}' has dependency '{dep}' that doesn't exist in graph " 

58 f"or context. Available: graph={sorted(node_ids)}, " 

59 f"context={sorted(context_keys)}" 

60 ) 

61 

62 # Check all ops are registered (if registry provided); only Node has op_name 

63 if self.registry: 

64 for node_id, node in graph.items(): 

65 if isinstance(node, Node): 

66 if not self.registry.has(node.op_name): 

67 raise ValueError( 

68 f"Node '{node_id}' references unregistered op '{node.op_name}'" 

69 ) 

70 

71 # Check for cycles (excluding context dependencies) 

72 if self._has_cycle(graph, context_keys=context_keys): 

73 raise ValueError("Graph contains cycles") 

74 

75 def _has_cycle(self, graph: Graph, context_keys: set[str] | None = None) -> bool: 

76 """Detect cycles in the graph using DFS. 

77 

78 Args: 

79 graph: Dictionary mapping node IDs to GraphVertex (Node or SubGraphNode). 

80 context_keys: Optional set of external dependency keys (from context). 

81 These are excluded from cycle detection. 

82 

83 Returns: 

84 True if cycle exists, False otherwise. 

85 """ 

86 node_ids = set(graph.keys()) 

87 context_keys = context_keys or set() 

88 WHITE = 0 # Unvisited 

89 GRAY = 1 # Currently in DFS path 

90 BLACK = 2 # Fully processed 

91 

92 color: dict[str, int] = {node_id: WHITE for node_id in node_ids} 

93 

94 def dfs(node_id: str) -> bool: 

95 """DFS helper that returns True if cycle found.""" 

96 if node_id not in node_ids: 

97 # This is a context dependency, not part of the graph - no cycle possible 

98 return False 

99 if color[node_id] == GRAY: 

100 # Back edge found - cycle detected 

101 return True 

102 if color[node_id] == BLACK: 

103 # Already processed 

104 return False 

105 

106 color[node_id] = GRAY 

107 node = graph[node_id] 

108 for dep in node.deps: 

109 # Only check dependencies that are in the graph (not context) 

110 if dep in node_ids: 

111 if dfs(dep): 

112 return True 

113 

114 color[node_id] = BLACK 

115 return False 

116 

117 # Check all nodes (handles disconnected components) 

118 for node_id in node_ids: 

119 if color[node_id] == WHITE: 

120 if dfs(node_id): 

121 return True 

122 

123 return False 

124 

125 def topological_sort( 

126 self, graph: Graph, context_keys: set[str] | None = None 

127 ) -> list[str]: 

128 """Topologically sort the graph using Kahn's algorithm. 

129 

130 Args: 

131 graph: Dictionary mapping node IDs to GraphVertex (Node or SubGraphNode). 

132 context_keys: Optional set of external dependency keys (from context). 

133 These are excluded from topological sorting. 

134 

135 Returns: 

136 List of node IDs in topological order (dependencies before dependents). 

137 

138 Raises: 

139 ValueError: If graph contains cycles. 

140 """ 

141 node_ids = set(graph.keys()) 

142 context_keys = context_keys or set() 

143 

144 # Build reverse dependency map: which nodes depend on each node 

145 # Only include dependencies that are in the graph (not context) 

146 dependents: dict[str, list[str]] = {node_id: [] for node_id in node_ids} 

147 for node_id, node in graph.items(): 

148 for dep in node.deps: 

149 # Only track dependencies that are in the graph 

150 if dep in node_ids: 

151 dependents[dep].append(node_id) 

152 

153 # Calculate in-degree for each node (number of graph dependencies it has) 

154 # Context dependencies don't count toward in-degree 

155 in_degree: dict[str, int] = {} 

156 for node_id, node in graph.items(): 

157 # Count only graph dependencies (not context) 

158 graph_deps = [d for d in node.deps if d in node_ids] 

159 in_degree[node_id] = len(graph_deps) 

160 

161 # Find all nodes with in-degree 0 (no graph dependencies) 

162 queue = deque([node_id for node_id in node_ids if in_degree[node_id] == 0]) 

163 result: list[str] = [] 

164 

165 while queue: 

166 node_id = queue.popleft() 

167 result.append(node_id) 

168 

169 # Reduce in-degree of nodes that depend on this node 

170 for dependent in dependents[node_id]: 

171 in_degree[dependent] -= 1 

172 if in_degree[dependent] == 0: 

173 queue.append(dependent) 

174 

175 # If we didn't process all nodes, there's a cycle 

176 if len(result) != len(node_ids): 

177 raise ValueError("Graph contains cycles (topological sort impossible)") 

178 

179 return result 

180 

181 def resolve(self, graph: Graph, context_keys: set[str] | None = None) -> list[str]: 

182 """Validate and topologically sort a graph. 

183 

184 Convenience method that validates then sorts. 

185 

186 Args: 

187 graph: Dictionary mapping node IDs to GraphVertex (Node or SubGraphNode). 

188 context_keys: Optional set of external dependency keys (from context). 

189 

190 Returns: 

191 List of node IDs in topological order. 

192 

193 Raises: 

194 ValueError: If validation fails or cycles exist. 

195 """ 

196 self.validate(graph, context_keys=context_keys) 

197 return self.topological_sort(graph, context_keys=context_keys)