Coverage for src / invariant / executor.py: 91.38%

58 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-20 16:05 +0000

1"""Executor: The runtime engine for executing DAGs.""" 

2 

3import inspect 

4from typing import TYPE_CHECKING, Any 

5 

6from invariant.cacheable import is_cacheable 

7from invariant.expressions import resolve_params 

8from invariant.graph import GraphResolver 

9from invariant.hashing import hash_manifest 

10 

11if TYPE_CHECKING: 

12 from invariant.node import Node 

13 from invariant.registry import OpRegistry 

14 from invariant.store.base import ArtifactStore 

15 

16 

17class Executor: 

18 """Runtime engine for executing DAGs. 

19 

20 Manages the two-phase execution: 

21 - Phase 1: Context Resolution (Graph -> Manifest) 

22 - Phase 2: Action Execution (Manifest -> Artifact) 

23 """ 

24 

25 def __init__( 

26 self, 

27 registry: "OpRegistry", 

28 store: "ArtifactStore", 

29 resolver: "GraphResolver | None" = None, 

30 ) -> None: 

31 """Initialize Executor. 

32 

33 Args: 

34 registry: OpRegistry for looking up operations. 

35 store: ArtifactStore for caching artifacts. 

36 resolver: Optional GraphResolver. If None, creates one with registry. 

37 """ 

38 self.registry = registry 

39 self.store = store 

40 self.resolver = resolver or GraphResolver(registry) 

41 

42 def execute( 

43 self, graph: dict[str, "Node"], context: dict[str, Any] | None = None 

44 ) -> dict[str, Any]: 

45 """Execute a graph and return artifacts for each node. 

46 

47 Args: 

48 graph: Dictionary mapping node IDs to Node objects. 

49 context: Optional dictionary of external dependencies (values not in graph). 

50 These are injected as artifacts available to any node that declares 

51 them in deps. 

52 

53 Returns: 

54 Dictionary mapping node IDs to their resulting artifacts. 

55 

56 Raises: 

57 ValueError: If graph validation fails or execution errors occur. 

58 """ 

59 # Validate and sort graph (pass context for validation) 

60 context = context or {} 

61 sorted_nodes = self.resolver.resolve(graph, context_keys=set(context.keys())) 

62 

63 # Track artifacts by node ID 

64 artifacts_by_node: dict[str, Any] = {} 

65 

66 # Inject context values into artifacts_by_node before execution 

67 # This makes external dependencies available to any node that declares them in deps 

68 for key, value in context.items(): 

69 # Context values must be cacheable 

70 if not is_cacheable(value): 

71 raise ValueError( 

72 f"Context value for '{key}' is not cacheable, got {type(value)}" 

73 ) 

74 # Store native types as-is (no wrapping) 

75 artifacts_by_node[key] = value 

76 

77 # Execute nodes in topological order 

78 for node_id in sorted_nodes: 

79 node = graph[node_id] 

80 

81 # Phase 1: Build manifest 

82 manifest = self._build_manifest(node, node_id, graph, artifacts_by_node) 

83 digest = hash_manifest(manifest) 

84 

85 # Phase 2: Execute or retrieve from cache 

86 if self.store.exists(node.op_name, digest): 

87 # Cache hit: retrieve from store 

88 artifact = self.store.get(node.op_name, digest) 

89 else: 

90 # Cache miss: execute operation 

91 op = self.registry.get(node.op_name) 

92 artifact = self._invoke_op(op, node.op_name, manifest) 

93 

94 # Persist to store 

95 self.store.put(node.op_name, digest, artifact) 

96 

97 artifacts_by_node[node_id] = artifact 

98 

99 return artifacts_by_node 

100 

101 def _build_manifest( 

102 self, 

103 node: "Node", 

104 node_id: str, 

105 graph: dict[str, "Node"], 

106 artifacts_by_node: dict[str, Any], 

107 ) -> dict[str, Any]: 

108 """Build the input manifest for a node (Phase 1). 

109 

110 The manifest is built entirely from resolved params. Dependencies are NOT 

111 injected into the manifest directly - they are only available for ref()/cel() 

112 resolution within params. 

113 

114 Args: 

115 node: The node to build manifest for. 

116 node_id: The ID of the node. 

117 graph: The full graph (for reference). 

118 artifacts_by_node: Already computed artifacts for upstream nodes. 

119 

120 Returns: 

121 The manifest dictionary mapping parameter names to resolved values. 

122 """ 

123 # Collect dependency artifacts for ref()/cel() resolution 

124 dependencies: dict[str, Any] = {} 

125 for dep_id in node.deps: 

126 if dep_id not in artifacts_by_node: 

127 raise ValueError( 

128 f"Node '{node_id}' depends on '{dep_id}' but artifact not found. " 

129 f"This should not happen if graph is topologically sorted or " 

130 f"if '{dep_id}' is provided in context." 

131 ) 

132 dependencies[dep_id] = artifacts_by_node[dep_id] 

133 

134 # Manifest = resolved params only. No dependency injection. 

135 # ref() and cel() markers in params are resolved using dependencies. 

136 return resolve_params(node.params, dependencies) 

137 

138 def _invoke_op(self, op: Any, op_name: str, manifest: dict[str, Any]) -> Any: 

139 """Invoke an operation with kwargs dispatch and return validation. 

140 

141 Args: 

142 op: The callable operation to invoke. 

143 op_name: The name of the operation (for error messages). 

144 manifest: The manifest dictionary mapping parameter names to values. 

145 

146 Returns: 

147 The operation result (native type or ICacheable domain type). 

148 

149 Raises: 

150 ValueError: If required parameters are missing. 

151 TypeError: If return value is not cacheable. 

152 """ 

153 # Inspect function signature to map manifest keys to function parameters 

154 sig = inspect.signature(op) 

155 kwargs: dict[str, Any] = {} 

156 

157 # Map manifest keys to function parameters by name 

158 for name, param in sig.parameters.items(): 

159 if name in manifest: 

160 value = manifest[name] 

161 kwargs[name] = value 

162 elif param.default is not inspect.Parameter.empty: 

163 # Parameter has a default value, skip it 

164 pass 

165 elif param.kind == inspect.Parameter.VAR_KEYWORD: 

166 # Function accepts **kwargs, will handle below 

167 pass 

168 else: 

169 # Required parameter missing 

170 raise ValueError(f"Op '{op_name}': missing required parameter '{name}'") 

171 

172 # If function has **kwargs, pass remaining manifest keys 

173 has_var_kwargs = any( 

174 p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() 

175 ) 

176 if has_var_kwargs: 

177 for key, val in manifest.items(): 

178 if key not in kwargs: 

179 kwargs[key] = val 

180 

181 # Invoke the operation 

182 result = op(**kwargs) 

183 

184 # Validate return value is cacheable 

185 if not is_cacheable(result): 

186 raise TypeError( 

187 f"Op '{op_name}' returned {type(result).__name__}, " 

188 f"which is not a cacheable type" 

189 ) 

190 

191 # Return as-is (no wrapping needed) 

192 return result