Coverage for src / invariant / node.py: 98.25%

57 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 10:21 +0100

1"""Node and SubGraphNode classes representing vertices in the DAG.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass 

6from typing import Any 

7 

8from invariant.params import ref 

9 

10 

def _collect_refs(value: Any) -> list[ref]:
    """Gather every ref() marker reachable inside ``value``.

    A ``ref`` instance yields itself. Dicts are searched through their values
    and lists through their items, recursively and in iteration order. Any
    other kind of value contributes nothing.

    Args:
        value: Arbitrary parameter value to search.

    Returns:
        The ref markers found, in traversal order (empty if there are none).
    """
    # A marker is a leaf: report it and stop descending.
    if isinstance(value, ref):
        return [value]

    # Only dict values and list items are traversed; everything else is opaque.
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, list):
        children = value
    else:
        return []

    found: list[ref] = []
    for child in children:
        found.extend(_collect_refs(child))
    return found

23 

24 

@dataclass(frozen=True)
class Node:
    """A single operation vertex in the DAG.

    Attributes:
        op_name: Name of the operation to execute (must be registered).
        params: Static parameters for this node, keyed by parameter name.
            Values may contain ref() and cel() markers, and ${...} string
            interpolation.
        deps: IDs of the upstream nodes this node depends on.
        cache: When True (default), the node's result is cached. When False,
            the op is always executed and the result is never stored
            (ephemeral node).
    """

    op_name: str
    params: dict[str, Any]
    deps: list[str]
    cache: bool = True

    def __post_init__(self) -> None:
        """Validate the node right after construction.

        Raises:
            ValueError: If ``op_name`` is falsy, ``params`` is not a dict,
                ``deps`` is not a list, or a ref() marker in ``params`` names
                a dependency that is not listed in ``deps``.
        """
        if not self.op_name:
            raise ValueError("op_name cannot be empty")

        # Container checks run in field order so the first offender is reported.
        for candidate, expected, message in (
            (self.params, dict, "params must be a dictionary"),
            (self.deps, list, "deps must be a list"),
        ):
            if not isinstance(candidate, expected):
                raise ValueError(message)

        # Every ref() marker must point at a declared dependency.
        self._validate_refs()

    def _validate_refs(self) -> None:
        """Ensure each ref() marker found in ``params`` targets a declared dep.

        Raises:
            ValueError: On the first marker whose ``dep`` is not in ``deps``.
        """
        declared = set(self.deps)
        for marker in _collect_refs(self.params):
            if marker.dep in declared:
                continue
            raise ValueError(
                f"ref('{marker.dep}') in params references undeclared dependency. "
                f"Declared deps: {self.deps}. "
                f"Add '{marker.dep}' to deps list."
            )

67 

68 

@dataclass(frozen=True)
class SubGraphNode:
    """A vertex that expands to an internal DAG at execution time.

    It has deps and params like Node, but instead of an op_name it carries an
    internal graph plus the ID of the node whose result is the output. The
    executor runs the internal graph with resolved params as context and
    returns the designated output node's artifact.

    Attributes:
        params: Parameters for this vertex; may contain ref() markers, each of
            which must name an entry of ``deps``.
        deps: IDs of the upstream nodes this vertex depends on.
        graph: The internal graph, mapping a key to a Node or SubGraphNode.
        output: Key in ``graph`` identifying the output node.
    """

    params: dict[str, Any]
    deps: list[str]
    graph: dict[str, Node | SubGraphNode]
    output: str

    def __post_init__(self) -> None:
        """Validate the sub-graph vertex right after construction.

        Raises:
            ValueError: If ``params`` is not a dict, ``deps`` is not a list,
                ``graph`` is not a dict, ``output`` is not a key of ``graph``,
                or a ref() marker in ``params`` names an undeclared dependency.
        """
        # Container checks run in field order so the first offender is reported.
        for candidate, expected, message in (
            (self.params, dict, "params must be a dictionary"),
            (self.deps, list, "deps must be a list"),
            (self.graph, dict, "graph must be a dictionary"),
        ):
            if not isinstance(candidate, expected):
                raise ValueError(message)

        # Membership is tested only after graph is known to be a dict.
        if self.output not in self.graph:
            graph_keys = list(self.graph.keys())
            raise ValueError(
                f"output '{self.output}' must be a key in graph. "
                f"Graph keys: {graph_keys}."
            )

        self._validate_refs()

    def _validate_refs(self) -> None:
        """Ensure each ref() marker found in ``params`` targets a declared dep.

        Raises:
            ValueError: On the first marker whose ``dep`` is not in ``deps``.
        """
        declared = set(self.deps)
        for marker in _collect_refs(self.params):
            if marker.dep in declared:
                continue
            raise ValueError(
                f"ref('{marker.dep}') in params references undeclared dependency. "
                f"Declared deps: {self.deps}. "
                f"Add '{marker.dep}' to deps list."
            )