Coverage for src / invariant / node.py: 96.97%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-20 16:05 +0000

1"""Node class representing a vertex in the DAG.""" 

2 

3from dataclasses import dataclass 

4from typing import Any 

5 

6from invariant.params import ref 

7 

8 

@dataclass(frozen=True)
class Node:
    """One vertex of the DAG: the operation to run and the inputs it needs.

    Attributes:
        op_name: Name of the operation to execute (must be registered).
        params: Static parameters for this node, keyed by parameter name.
            Values may contain ref() and cel() markers, and ${...} string interpolation.
        deps: IDs of the upstream nodes this node depends on.
    """

    op_name: str
    params: dict[str, Any]
    deps: list[str]

    def __post_init__(self) -> None:
        """Check the field values, then check ref() markers against ``deps``.

        Raises:
            ValueError: If ``op_name`` is falsy, ``params`` is not a dict,
                ``deps`` is not a list, or a ref() marker found in ``params``
                names a dependency that is not listed in ``deps``.
        """
        # Evaluated in order; the first failing check decides the message.
        checks = (
            (bool(self.op_name), "op_name cannot be empty"),
            (isinstance(self.params, dict), "params must be a dictionary"),
            (isinstance(self.deps, list), "deps must be a list"),
        )
        for passed, message in checks:
            if not passed:
                raise ValueError(message)

        # Field shapes are fine; now make sure every ref() has a declared dep.
        self._validate_refs()

    def _validate_refs(self) -> None:
        """Ensure each ref() marker inside ``params`` points at a declared dep.

        Raises:
            ValueError: On the first marker whose ``dep`` is missing from ``deps``.
        """
        declared = set(self.deps)
        for marker in self._collect_refs(self.params):
            target = marker.dep
            if target in declared:
                continue
            raise ValueError(
                f"ref('{target}') in params references undeclared dependency. "
                f"Declared deps: {self.deps}. "
                f"Add '{target}' to deps list."
            )

    def _collect_refs(self, value: Any) -> list[ref]:
        """Return every ref() marker reachable from ``value``, in traversal order.

        Only dict values and list items are descended into; any other
        non-ref value contributes nothing.
        """
        if isinstance(value, ref):
            return [value]

        if isinstance(value, dict):
            children = value.values()
        elif isinstance(value, list):
            children = value
        else:
            return []

        return [found for child in children for found in self._collect_refs(child)]