Coverage for src / invariant / graph.py: 98.68%
76 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 10:21 +0100
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 10:21 +0100
1"""GraphResolver for parsing, validating, and sorting DAGs."""
3from collections import deque
4from typing import TYPE_CHECKING
6from invariant.node import Node, SubGraphNode
8if TYPE_CHECKING:
9 from invariant.registry import OpRegistry
# A graph vertex is either a regular node or a subgraph node (an internal DAG).
GraphVertex = Node | SubGraphNode
# A graph maps each node ID (a string) to its vertex definition.
Graph = dict[str, GraphVertex]
16class GraphResolver:
17 """Responsible for parsing graph definitions and ensuring valid DAGs.
19 Handles:
20 - Cycle detection
21 - Validation (missing dependencies, missing ops)
22 - Topological sorting
23 """
25 def __init__(self, registry: "OpRegistry | None" = None) -> None:
26 """Initialize GraphResolver.
28 Args:
29 registry: Optional OpRegistry for validating that ops exist.
30 If None, op validation is skipped.
31 """
32 self.registry = registry
34 def validate(self, graph: Graph, context_keys: set[str] | None = None) -> None:
35 """Validate a graph definition.
37 Checks:
38 - All node dependencies exist in the graph or in context
39 - All referenced ops are registered (if registry provided; Node only)
40 - No cycles exist
42 Args:
43 graph: Dictionary mapping node IDs to Node or SubGraphNode objects.
44 context_keys: Optional set of external dependency keys (from context).
45 Dependencies not in the graph are allowed if they're in context.
47 Raises:
48 ValueError: If validation fails (missing deps, missing ops, cycles).
49 """
50 # Check all dependencies exist
51 node_ids = set(graph.keys())
52 context_keys = context_keys or set()
53 for node_id, node in graph.items():
54 for dep in node.deps:
55 if dep not in node_ids and dep not in context_keys:
56 raise ValueError(
57 f"Node '{node_id}' has dependency '{dep}' that doesn't exist in graph "
58 f"or context. Available: graph={sorted(node_ids)}, "
59 f"context={sorted(context_keys)}"
60 )
62 # Check all ops are registered (if registry provided); only Node has op_name
63 if self.registry:
64 for node_id, node in graph.items():
65 if isinstance(node, Node):
66 if not self.registry.has(node.op_name):
67 raise ValueError(
68 f"Node '{node_id}' references unregistered op '{node.op_name}'"
69 )
71 # Check for cycles (excluding context dependencies)
72 if self._has_cycle(graph, context_keys=context_keys):
73 raise ValueError("Graph contains cycles")
75 def _has_cycle(self, graph: Graph, context_keys: set[str] | None = None) -> bool:
76 """Detect cycles in the graph using DFS.
78 Args:
79 graph: Dictionary mapping node IDs to GraphVertex (Node or SubGraphNode).
80 context_keys: Optional set of external dependency keys (from context).
81 These are excluded from cycle detection.
83 Returns:
84 True if cycle exists, False otherwise.
85 """
86 node_ids = set(graph.keys())
87 context_keys = context_keys or set()
88 WHITE = 0 # Unvisited
89 GRAY = 1 # Currently in DFS path
90 BLACK = 2 # Fully processed
92 color: dict[str, int] = {node_id: WHITE for node_id in node_ids}
94 def dfs(node_id: str) -> bool:
95 """DFS helper that returns True if cycle found."""
96 if node_id not in node_ids:
97 # This is a context dependency, not part of the graph - no cycle possible
98 return False
99 if color[node_id] == GRAY:
100 # Back edge found - cycle detected
101 return True
102 if color[node_id] == BLACK:
103 # Already processed
104 return False
106 color[node_id] = GRAY
107 node = graph[node_id]
108 for dep in node.deps:
109 # Only check dependencies that are in the graph (not context)
110 if dep in node_ids:
111 if dfs(dep):
112 return True
114 color[node_id] = BLACK
115 return False
117 # Check all nodes (handles disconnected components)
118 for node_id in node_ids:
119 if color[node_id] == WHITE:
120 if dfs(node_id):
121 return True
123 return False
125 def topological_sort(
126 self, graph: Graph, context_keys: set[str] | None = None
127 ) -> list[str]:
128 """Topologically sort the graph using Kahn's algorithm.
130 Args:
131 graph: Dictionary mapping node IDs to GraphVertex (Node or SubGraphNode).
132 context_keys: Optional set of external dependency keys (from context).
133 These are excluded from topological sorting.
135 Returns:
136 List of node IDs in topological order (dependencies before dependents).
138 Raises:
139 ValueError: If graph contains cycles.
140 """
141 node_ids = set(graph.keys())
142 context_keys = context_keys or set()
144 # Build reverse dependency map: which nodes depend on each node
145 # Only include dependencies that are in the graph (not context)
146 dependents: dict[str, list[str]] = {node_id: [] for node_id in node_ids}
147 for node_id, node in graph.items():
148 for dep in node.deps:
149 # Only track dependencies that are in the graph
150 if dep in node_ids:
151 dependents[dep].append(node_id)
153 # Calculate in-degree for each node (number of graph dependencies it has)
154 # Context dependencies don't count toward in-degree
155 in_degree: dict[str, int] = {}
156 for node_id, node in graph.items():
157 # Count only graph dependencies (not context)
158 graph_deps = [d for d in node.deps if d in node_ids]
159 in_degree[node_id] = len(graph_deps)
161 # Find all nodes with in-degree 0 (no graph dependencies)
162 queue = deque([node_id for node_id in node_ids if in_degree[node_id] == 0])
163 result: list[str] = []
165 while queue:
166 node_id = queue.popleft()
167 result.append(node_id)
169 # Reduce in-degree of nodes that depend on this node
170 for dependent in dependents[node_id]:
171 in_degree[dependent] -= 1
172 if in_degree[dependent] == 0:
173 queue.append(dependent)
175 # If we didn't process all nodes, there's a cycle
176 if len(result) != len(node_ids):
177 raise ValueError("Graph contains cycles (topological sort impossible)")
179 return result
181 def resolve(self, graph: Graph, context_keys: set[str] | None = None) -> list[str]:
182 """Validate and topologically sort a graph.
184 Convenience method that validates then sorts.
186 Args:
187 graph: Dictionary mapping node IDs to GraphVertex (Node or SubGraphNode).
188 context_keys: Optional set of external dependency keys (from context).
190 Returns:
191 List of node IDs in topological order.
193 Raises:
194 ValueError: If validation fails or cycles exist.
195 """
196 self.validate(graph, context_keys=context_keys)
197 return self.topological_sort(graph, context_keys=context_keys)