Coverage for src / invariant / graph.py: 98.63%
73 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-20 16:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-20 16:05 +0000
1"""GraphResolver for parsing, validating, and sorting DAGs."""
3from collections import deque
4from typing import TYPE_CHECKING
6from invariant.node import Node
8if TYPE_CHECKING:
9 from invariant.registry import OpRegistry
class GraphResolver:
    """Parse graph definitions and ensure they form valid DAGs.

    Handles:
        - Cycle detection
        - Validation (missing dependencies, missing ops)
        - Topological sorting

    A graph is a mapping of node ID to a node object. Only two attributes of
    each node are read here: ``deps`` (an iterable of dependency IDs) and
    ``op_name`` (looked up in the registry during validation).
    """

    def __init__(self, registry: "OpRegistry | None" = None) -> None:
        """Initialize GraphResolver.

        Args:
            registry: Optional OpRegistry for validating that ops exist.
                If None, op validation is skipped.
        """
        self.registry = registry

    def validate(
        self, graph: "dict[str, Node]", context_keys: "set[str] | None" = None
    ) -> None:
        """Validate a graph definition.

        Checks, in this order:
            - All node dependencies exist in the graph or in context
            - All referenced ops are registered (if registry provided)
            - No cycles exist

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).
                Dependencies not in the graph are allowed if they're in context.

        Raises:
            ValueError: If validation fails (missing deps, missing ops, cycles).
        """
        context_keys = context_keys or set()

        # Every dependency must resolve to a graph node or a context key.
        for node_id, node in graph.items():
            for dep in node.deps:
                if dep not in graph and dep not in context_keys:
                    raise ValueError(
                        f"Node '{node_id}' has dependency '{dep}' that doesn't exist in graph "
                        f"or context. Available: graph={sorted(graph)}, "
                        f"context={sorted(context_keys)}"
                    )

        # Compare against None explicitly: a registry object that happens to
        # be falsy (e.g. an empty container-like registry) must still be
        # consulted, otherwise unregistered ops would pass silently.
        if self.registry is not None:
            for node_id, node in graph.items():
                if not self.registry.has(node.op_name):
                    raise ValueError(
                        f"Node '{node_id}' references unregistered op '{node.op_name}'"
                    )

        # Check for cycles (context dependencies are ignored by the detector).
        if self._has_cycle(graph, context_keys=context_keys):
            raise ValueError("Graph contains cycles")

    def _has_cycle(
        self, graph: "dict[str, Node]", context_keys: "set[str] | None" = None
    ) -> bool:
        """Detect cycles in the graph using an iterative three-colour DFS.

        The traversal keeps its own stack instead of recursing, so very deep
        dependency chains cannot exhaust the interpreter's recursion limit.

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).
                Accepted for interface symmetry; any dependency that is not a
                key of ``graph`` is ignored, whether or not it is listed here.

        Returns:
            True if cycle exists, False otherwise.
        """
        WHITE = 0  # Unvisited
        GRAY = 1  # On the current DFS path
        BLACK = 2  # Fully processed

        color: dict[str, int] = dict.fromkeys(graph, WHITE)

        # Start a traversal from every unvisited node (handles disconnected
        # components).
        for root in graph:
            if color[root] != WHITE:
                continue

            color[root] = GRAY
            # Each frame pairs a node with an iterator over its remaining deps,
            # so a node resumes where it left off after a child finishes.
            stack = [(root, iter(graph[root].deps))]
            while stack:
                current, remaining = stack[-1]
                for dep in remaining:
                    state = color.get(dep)
                    if state is None:
                        # Not a graph node (context dependency): no edge.
                        continue
                    if state == GRAY:
                        # Back edge to a node on the current path: cycle.
                        return True
                    if state == WHITE:
                        color[dep] = GRAY
                        stack.append((dep, iter(graph[dep].deps)))
                        break
                    # BLACK: already fully explored, nothing to do.
                else:
                    # All deps of ``current`` explored without finding a cycle.
                    color[current] = BLACK
                    stack.pop()

        return False

    def topological_sort(
        self, graph: "dict[str, Node]", context_keys: "set[str] | None" = None
    ) -> list[str]:
        """Topologically sort the graph using Kahn's algorithm.

        The result is deterministic: ties are broken by the insertion order of
        ``graph`` and the order of each node's ``deps``, never by set/hash
        iteration order.

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).
                Accepted for interface symmetry; any dependency that is not a
                key of ``graph`` is excluded from the sort.

        Returns:
            List of node IDs in topological order (dependencies before dependents).

        Raises:
            ValueError: If graph contains cycles.
        """
        # dependents[x] lists the nodes that depend on x; in_degree[x] counts
        # x's dependencies that are graph nodes. Both are built in one pass so
        # they always agree (duplicate deps are counted on both sides).
        dependents: dict[str, list[str]] = {node_id: [] for node_id in graph}
        in_degree: dict[str, int] = dict.fromkeys(graph, 0)
        for node_id, node in graph.items():
            for dep in node.deps:
                if dep in graph:
                    dependents[dep].append(node_id)
                    in_degree[node_id] += 1

        # Seed with nodes that have no graph dependencies, in insertion order.
        queue = deque(node_id for node_id in graph if in_degree[node_id] == 0)
        result: list[str] = []

        while queue:
            node_id = queue.popleft()
            result.append(node_id)

            # Release nodes whose last outstanding dependency was just emitted.
            for dependent in dependents[node_id]:
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        # Nodes on a cycle never reach in-degree 0 and are never emitted.
        if len(result) != len(graph):
            raise ValueError("Graph contains cycles (topological sort impossible)")

        return result

    def resolve(
        self, graph: "dict[str, Node]", context_keys: "set[str] | None" = None
    ) -> list[str]:
        """Validate and topologically sort a graph.

        Convenience method that validates then sorts.

        Args:
            graph: Dictionary mapping node IDs to Node objects.
            context_keys: Optional set of external dependency keys (from context).

        Returns:
            List of node IDs in topological order.

        Raises:
            ValueError: If validation fails or cycles exist.
        """
        self.validate(graph, context_keys=context_keys)
        return self.topological_sort(graph, context_keys=context_keys)