Coverage for src/invariant/registry.py: 85.94%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 10:21 +0100

1"""OpRegistry for mapping operation names to callables.""" 

2 

3import types 

4from importlib.metadata import entry_points 

5from typing import Any, Callable 

6 

# Type alias for an "op package": a dict mapping short (unprefixed) op names
# to the callables that implement them. OpRegistry.register_package() stores
# each entry under the namespaced name "<prefix>:<short name>".
OpPackage = dict[str, Callable[..., Any]]

9 

10 

11class OpRegistry: 

12 """Singleton registry mapping string identifiers to executable Python callables. 

13 

14 Decouples the "string" name in the graph definition from the actual Python code. 

15 """ 

16 

17 _instance: "OpRegistry | None" = None 

18 _initialized: bool = False 

19 

20 def __new__(cls) -> "OpRegistry": 

21 """Ensure singleton pattern.""" 

22 if cls._instance is None: 

23 cls._instance = super().__new__(cls) 

24 return cls._instance 

25 

26 def __init__(self) -> None: 

27 """Initialize the registry (only once).""" 

28 if not OpRegistry._initialized: 

29 self._ops: dict[str, Callable[..., Any]] = {} 

30 OpRegistry._initialized = True 

31 

32 def register(self, name: str, op: Callable[..., Any]) -> None: 

33 """Register an operation. 

34 

35 Args: 

36 name: The string identifier for the operation. 

37 op: The callable that implements the operation. 

38 Should be a plain Python function with typed parameters. 

39 

40 Raises: 

41 ValueError: If name is empty or already registered. 

42 """ 

43 if not name: 

44 raise ValueError("Operation name cannot be empty") 

45 if name in self._ops: 

46 raise ValueError(f"Operation '{name}' is already registered") 

47 self._ops[name] = op 

48 

49 def get(self, name: str) -> Callable[..., Any]: 

50 """Get an operation by name. 

51 

52 Args: 

53 name: The string identifier for the operation. 

54 

55 Returns: 

56 The callable that implements the operation. 

57 

58 Raises: 

59 KeyError: If operation is not registered. 

60 """ 

61 if name not in self._ops: 

62 raise KeyError(f"Operation '{name}' is not registered") 

63 return self._ops[name] 

64 

65 def has(self, name: str) -> bool: 

66 """Check if an operation is registered. 

67 

68 Args: 

69 name: The string identifier for the operation. 

70 

71 Returns: 

72 True if registered, False otherwise. 

73 """ 

74 return name in self._ops 

75 

76 def clear(self) -> None: 

77 """Clear all registered operations (mainly for testing).""" 

78 self._ops.clear() 

79 

80 def register_package(self, prefix: str, ops: OpPackage | Any) -> None: 

81 """Register all ops from a package under a common prefix. 

82 

83 Args: 

84 prefix: The namespace prefix (e.g. "poly"). 

85 ops: Either a dict mapping short names to callables (OpPackage), 

86 or a Python module that has an OPS dict attribute. 

87 

88 Raises: 

89 ValueError: If prefix is empty, ops is invalid, or any operation 

90 name is already registered. 

91 AttributeError: If ops is a module but doesn't have an OPS attribute. 

92 """ 

93 if not prefix: 

94 raise ValueError("Package prefix cannot be empty") 

95 

96 # Extract the ops dict from the input 

97 ops_dict: OpPackage 

98 if isinstance(ops, dict): 

99 ops_dict = ops 

100 elif isinstance(ops, types.ModuleType): 

101 # It's a module - check for OPS attribute 

102 if not hasattr(ops, "OPS"): 

103 raise AttributeError( 

104 f"Module {ops.__name__} does not have an OPS attribute" 

105 ) 

106 ops_dict = ops.OPS 

107 if not isinstance(ops_dict, dict): 

108 raise ValueError(f"OPS attribute must be a dict, got {type(ops_dict)}") 

109 elif hasattr(ops, "OPS"): 

110 # Object with OPS attribute (not a module) 

111 ops_dict = ops.OPS 

112 if not isinstance(ops_dict, dict): 

113 raise ValueError(f"OPS attribute must be a dict, got {type(ops_dict)}") 

114 else: 

115 raise ValueError( 

116 f"ops must be a dict or module with OPS attribute, got {type(ops)}" 

117 ) 

118 

119 # Register each op with the prefix 

120 for name, op in ops_dict.items(): 

121 full_name = f"{prefix}:{name}" 

122 self.register(full_name, op) 

123 

124 def auto_discover(self) -> None: 

125 """Discover and register op packages from entry points. 

126 

127 Scans the 'invariant.ops' entry point group. Each entry point 

128 should resolve to either: 

129 - A dict[str, Callable] (the OPS dict directly) 

130 - A callable that returns such a dict 

131 

132 The entry point name becomes the package prefix. 

133 

134 Raises: 

135 ValueError: If any operation name is already registered (via register_package). 

136 """ 

137 eps = entry_points(group="invariant.ops") 

138 

139 for ep in eps: 

140 try: 

141 # Load the entry point 

142 loaded = ep.load() 

143 

144 # Extract the ops dict 

145 ops_dict: OpPackage 

146 if isinstance(loaded, dict): 

147 ops_dict = loaded 

148 elif callable(loaded): 

149 # Callable that returns the dict 

150 result = loaded() 

151 if not isinstance(result, dict): 

152 continue # Skip invalid entry points 

153 ops_dict = result 

154 else: 

155 continue # Skip invalid entry points 

156 

157 # Register the package using the entry point name as prefix 

158 self.register_package(ep.name, ops_dict) 

159 except Exception: 

160 # Skip invalid entry points silently 

161 continue