Coverage for src/invariant/executor.py: 91.38%
58 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-20 16:05 +0000
1"""Executor: The runtime engine for executing DAGs."""
3import inspect
4from typing import TYPE_CHECKING, Any
6from invariant.cacheable import is_cacheable
7from invariant.expressions import resolve_params
8from invariant.graph import GraphResolver
9from invariant.hashing import hash_manifest
11if TYPE_CHECKING:
12 from invariant.node import Node
13 from invariant.registry import OpRegistry
14 from invariant.store.base import ArtifactStore
17class Executor:
18 """Runtime engine for executing DAGs.
20 Manages the two-phase execution:
21 - Phase 1: Context Resolution (Graph -> Manifest)
22 - Phase 2: Action Execution (Manifest -> Artifact)
23 """
25 def __init__(
26 self,
27 registry: "OpRegistry",
28 store: "ArtifactStore",
29 resolver: "GraphResolver | None" = None,
30 ) -> None:
31 """Initialize Executor.
33 Args:
34 registry: OpRegistry for looking up operations.
35 store: ArtifactStore for caching artifacts.
36 resolver: Optional GraphResolver. If None, creates one with registry.
37 """
38 self.registry = registry
39 self.store = store
40 self.resolver = resolver or GraphResolver(registry)
42 def execute(
43 self, graph: dict[str, "Node"], context: dict[str, Any] | None = None
44 ) -> dict[str, Any]:
45 """Execute a graph and return artifacts for each node.
47 Args:
48 graph: Dictionary mapping node IDs to Node objects.
49 context: Optional dictionary of external dependencies (values not in graph).
50 These are injected as artifacts available to any node that declares
51 them in deps.
53 Returns:
54 Dictionary mapping node IDs to their resulting artifacts.
56 Raises:
57 ValueError: If graph validation fails or execution errors occur.
58 """
59 # Validate and sort graph (pass context for validation)
60 context = context or {}
61 sorted_nodes = self.resolver.resolve(graph, context_keys=set(context.keys()))
63 # Track artifacts by node ID
64 artifacts_by_node: dict[str, Any] = {}
66 # Inject context values into artifacts_by_node before execution
67 # This makes external dependencies available to any node that declares them in deps
68 for key, value in context.items():
69 # Context values must be cacheable
70 if not is_cacheable(value):
71 raise ValueError(
72 f"Context value for '{key}' is not cacheable, got {type(value)}"
73 )
74 # Store native types as-is (no wrapping)
75 artifacts_by_node[key] = value
77 # Execute nodes in topological order
78 for node_id in sorted_nodes:
79 node = graph[node_id]
81 # Phase 1: Build manifest
82 manifest = self._build_manifest(node, node_id, graph, artifacts_by_node)
83 digest = hash_manifest(manifest)
85 # Phase 2: Execute or retrieve from cache
86 if self.store.exists(node.op_name, digest):
87 # Cache hit: retrieve from store
88 artifact = self.store.get(node.op_name, digest)
89 else:
90 # Cache miss: execute operation
91 op = self.registry.get(node.op_name)
92 artifact = self._invoke_op(op, node.op_name, manifest)
94 # Persist to store
95 self.store.put(node.op_name, digest, artifact)
97 artifacts_by_node[node_id] = artifact
99 return artifacts_by_node
101 def _build_manifest(
102 self,
103 node: "Node",
104 node_id: str,
105 graph: dict[str, "Node"],
106 artifacts_by_node: dict[str, Any],
107 ) -> dict[str, Any]:
108 """Build the input manifest for a node (Phase 1).
110 The manifest is built entirely from resolved params. Dependencies are NOT
111 injected into the manifest directly - they are only available for ref()/cel()
112 resolution within params.
114 Args:
115 node: The node to build manifest for.
116 node_id: The ID of the node.
117 graph: The full graph (for reference).
118 artifacts_by_node: Already computed artifacts for upstream nodes.
120 Returns:
121 The manifest dictionary mapping parameter names to resolved values.
122 """
123 # Collect dependency artifacts for ref()/cel() resolution
124 dependencies: dict[str, Any] = {}
125 for dep_id in node.deps:
126 if dep_id not in artifacts_by_node:
127 raise ValueError(
128 f"Node '{node_id}' depends on '{dep_id}' but artifact not found. "
129 f"This should not happen if graph is topologically sorted or "
130 f"if '{dep_id}' is provided in context."
131 )
132 dependencies[dep_id] = artifacts_by_node[dep_id]
134 # Manifest = resolved params only. No dependency injection.
135 # ref() and cel() markers in params are resolved using dependencies.
136 return resolve_params(node.params, dependencies)
138 def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any:
139 """Invoke an operation with kwargs dispatch and return validation.
141 Args:
142 op: The callable operation to invoke.
143 op_name: The name of the operation (for error messages).
144 manifest: The manifest dictionary mapping parameter names to values.
146 Returns:
147 The operation result (native type or ICacheable domain type).
149 Raises:
150 ValueError: If required parameters are missing.
151 TypeError: If return value is not cacheable.
152 """
153 # Inspect function signature to map manifest keys to function parameters
154 sig = inspect.signature(op)
155 kwargs: dict[str, Any] = {}
157 # Map manifest keys to function parameters by name
158 for name, param in sig.parameters.items():
159 if name in manifest:
160 value = manifest[name]
161 kwargs[name] = value
162 elif param.default is not inspect.Parameter.empty:
163 # Parameter has a default value, skip it
164 pass
165 elif param.kind == inspect.Parameter.VAR_KEYWORD:
166 # Function accepts **kwargs, will handle below
167 pass
168 else:
169 # Required parameter missing
170 raise ValueError(f"Op '{op_name}': missing required parameter '{name}'")
172 # If function has **kwargs, pass remaining manifest keys
173 has_var_kwargs = any(
174 p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
175 )
176 if has_var_kwargs:
177 for key, val in manifest.items():
178 if key not in kwargs:
179 kwargs[key] = val
181 # Invoke the operation
182 result = op(**kwargs)
184 # Validate return value is cacheable
185 if not is_cacheable(result):
186 raise TypeError(
187 f"Op '{op_name}' returned {type(result).__name__}, "
188 f"which is not a cacheable type"
189 )
191 # Return as-is (no wrapping needed)
192 return result