Coverage for src / invariant / executor.py: 91.18%

68 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 10:21 +0100

1"""Executor: The runtime engine for executing DAGs.""" 

2 

3import inspect 

4from typing import TYPE_CHECKING, Any 

5 

6from invariant.cacheable import is_cacheable 

7from invariant.expressions import resolve_params 

8from invariant.graph import Graph, GraphResolver 

9from invariant.hashing import hash_manifest 

10from invariant.node import Node, SubGraphNode 

11 

12if TYPE_CHECKING: 

13 from invariant.registry import OpRegistry 

14 from invariant.store.base import ArtifactStore 

15 

16 

17class Executor: 

18 """Runtime engine for executing DAGs. 

19 

20 Manages the two-phase execution: 

21 - Phase 1: Context Resolution (Graph -> Manifest) 

22 - Phase 2: Action Execution (Manifest -> Artifact) 

23 """ 

24 

25 def __init__( 

26 self, 

27 registry: "OpRegistry", 

28 store: "ArtifactStore", 

29 resolver: "GraphResolver | None" = None, 

30 ) -> None: 

31 """Initialize Executor. 

32 

33 Args: 

34 registry: OpRegistry for looking up operations. 

35 store: ArtifactStore for caching artifacts. 

36 resolver: Optional GraphResolver. If None, creates one with registry. 

37 """ 

38 self.registry = registry 

39 self.store = store 

40 self.resolver = resolver or GraphResolver(registry) 

41 

42 def execute( 

43 self, graph: Graph, context: dict[str, Any] | None = None 

44 ) -> dict[str, Any]: 

45 """Execute a graph and return artifacts for each node. 

46 

47 Args: 

48 graph: Dictionary mapping node IDs to Node or SubGraphNode objects. 

49 context: Optional dictionary of external dependencies (values not in graph). 

50 These are injected as artifacts available to any node that declares 

51 them in deps. 

52 

53 Returns: 

54 Dictionary mapping node IDs to their resulting artifacts. 

55 

56 Raises: 

57 ValueError: If graph validation fails or execution errors occur. 

58 """ 

59 # Validate and sort graph (pass context for validation) 

60 context = context or {} 

61 sorted_nodes = self.resolver.resolve(graph, context_keys=set(context.keys())) 

62 

63 # Track artifacts by node ID 

64 artifacts_by_node: dict[str, Any] = {} 

65 

66 # Inject context values into artifacts_by_node before execution 

67 # This makes external dependencies available to any node that declares them in deps 

68 for key, value in context.items(): 

69 # Context values must be cacheable 

70 if not is_cacheable(value): 

71 raise ValueError( 

72 f"Context value for '{key}' is not cacheable, got {type(value)}" 

73 ) 

74 # Store native types as-is (no wrapping) 

75 artifacts_by_node[key] = value 

76 

77 # Execute nodes in topological order 

78 for node_id in sorted_nodes: 

79 node = graph[node_id] 

80 

81 if isinstance(node, SubGraphNode): 

82 # SubGraphNode: run internal graph with resolved params as context 

83 manifest = self._build_manifest(node, node_id, graph, artifacts_by_node) 

84 inner_results = self.execute(node.graph, context=manifest) 

85 if node.output not in inner_results: 

86 raise ValueError( 

87 f"SubGraphNode '{node_id}' output '{node.output}' not in " 

88 f"internal results. Keys: {list(inner_results.keys())}." 

89 ) 

90 artifacts_by_node[node_id] = inner_results[node.output] 

91 else: 

92 # Node: Phase 1 build manifest, Phase 2 cache lookup or execute op 

93 manifest = self._build_manifest(node, node_id, graph, artifacts_by_node) 

94 if not node.cache: 

95 # Ephemeral node: always execute, never cache 

96 op = self.registry.get(node.op_name) 

97 artifact = self._invoke_op(op, node.op_name, manifest) 

98 else: 

99 digest = hash_manifest(manifest) 

100 if self.store.exists(node.op_name, digest): 

101 artifact = self.store.get(node.op_name, digest) 

102 else: 

103 op = self.registry.get(node.op_name) 

104 artifact = self._invoke_op(op, node.op_name, manifest) 

105 self.store.put(node.op_name, digest, artifact) 

106 artifacts_by_node[node_id] = artifact 

107 

108 return artifacts_by_node 

109 

110 def _build_manifest( 

111 self, 

112 node: Node | SubGraphNode, 

113 node_id: str, 

114 graph: Graph, 

115 artifacts_by_node: dict[str, Any], 

116 ) -> dict[str, Any]: 

117 """Build the input manifest for a node (Phase 1). 

118 

119 The manifest is built entirely from resolved params. Dependencies are NOT 

120 injected into the manifest directly - they are only available for ref()/cel() 

121 resolution within params. 

122 

123 Args: 

124 node: The node to build manifest for. 

125 node_id: The ID of the node. 

126 graph: The full graph (for reference). 

127 artifacts_by_node: Already computed artifacts for upstream nodes. 

128 

129 Returns: 

130 The manifest dictionary mapping parameter names to resolved values. 

131 """ 

132 # Collect dependency artifacts for ref()/cel() resolution 

133 dependencies: dict[str, Any] = {} 

134 for dep_id in node.deps: 

135 if dep_id not in artifacts_by_node: 

136 raise ValueError( 

137 f"Node '{node_id}' depends on '{dep_id}' but artifact not found. " 

138 f"This should not happen if graph is topologically sorted or " 

139 f"if '{dep_id}' is provided in context." 

140 ) 

141 dependencies[dep_id] = artifacts_by_node[dep_id] 

142 

143 # Manifest = resolved params only. No dependency injection. 

144 # ref() and cel() markers in params are resolved using dependencies. 

145 return resolve_params(node.params, dependencies) 

146 

147 def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any: 

148 """Invoke an operation with kwargs dispatch and return validation. 

149 

150 Args: 

151 op: The callable operation to invoke. 

152 op_name: The name of the operation (for error messages). 

153 manifest: The manifest dictionary mapping parameter names to values. 

154 

155 Returns: 

156 The operation result (native type or ICacheable domain type). 

157 

158 Raises: 

159 ValueError: If required parameters are missing. 

160 TypeError: If return value is not cacheable. 

161 """ 

162 # Inspect function signature to map manifest keys to function parameters 

163 sig = inspect.signature(op) 

164 kwargs: dict[str, Any] = {} 

165 

166 # Map manifest keys to function parameters by name 

167 for name, param in sig.parameters.items(): 

168 if name in manifest: 

169 value = manifest[name] 

170 kwargs[name] = value 

171 elif param.default is not inspect.Parameter.empty: 

172 # Parameter has a default value, skip it 

173 pass 

174 elif param.kind == inspect.Parameter.VAR_KEYWORD: 

175 # Function accepts **kwargs, will handle below 

176 pass 

177 else: 

178 # Required parameter missing 

179 raise ValueError(f"Op '{op_name}': missing required parameter '{name}'") 

180 

181 # If function has **kwargs, pass remaining manifest keys 

182 has_var_kwargs = any( 

183 p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() 

184 ) 

185 if has_var_kwargs: 

186 for key, val in manifest.items(): 

187 if key not in kwargs: 

188 kwargs[key] = val 

189 

190 # Invoke the operation 

191 result = op(**kwargs) 

192 

193 # Validate return value is cacheable 

194 if not is_cacheable(result): 

195 raise TypeError( 

196 f"Op '{op_name}' returned {type(result).__name__}, " 

197 f"which is not a cacheable type" 

198 ) 

199 

200 # Return as-is (no wrapping needed) 

201 return result