Coverage for src / invariant / graph_serialization.py: 84.65%

215 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 10:21 +0100

1"""Graph serialization: JSON wire format for Invariant graphs. 

2 

3Encodes graphs (Node, SubGraphNode) and params (ref, cel, Decimal, tuple, 

4ICacheable) for storage and transmission. Distinct from artifact serialization 

5in store/codec.py. 

6""" 

7 

8import base64 

9import importlib 

10import json 

11from decimal import Decimal 

12from io import BytesIO 

13from typing import Any 

14 

15from invariant.graph import Graph 

16from invariant.node import Node, SubGraphNode 

17from invariant.params import cel, ref 

18from invariant.protocol import ICacheable 

19 

# Envelope "version" values accepted by _validate_envelope.
SUPPORTED_VERSIONS = {1}
# Expected value of the envelope's "format" field.
FORMAT_ID = "invariant-graph"

# Marker keys of the param wire format. A plain single-key dict using one of
# these keys is wrapped in "$literal" by _encode_param_value so that it is not
# mistaken for a marker when decoded.
RESERVED_KEYS = frozenset(
    ("$ref", "$cel", "$decimal", "$tuple", "$literal", "$icacheable")
)

26 

27 

def _encode_param_value(value: Any) -> Any:
    """Convert one parameter value into its JSON-serializable wire form.

    Special values become single-key marker objects:

    * ``ref``        -> ``{"$ref": dep}``
    * ``cel``        -> ``{"$cel": expr}``
    * ``Decimal``    -> ``{"$decimal": "<str(value)>"}``
    * ``tuple``      -> ``{"$tuple": [...]}`` (items encoded recursively)
    * ``ICacheable`` -> ``{"$icacheable": {"type": ..., "value" | "payload_b64": ...}}``

    Dicts and lists are encoded recursively; anything else is returned as is.
    """
    if isinstance(value, ref):
        return {"$ref": value.dep}

    if isinstance(value, cel):
        return {"$cel": value.expr}

    if isinstance(value, Decimal):
        return {"$decimal": str(value)}

    if isinstance(value, tuple):
        return {"$tuple": [_encode_param_value(element) for element in value]}

    if isinstance(value, ICacheable):
        klass = value.__class__
        qualified_name = f"{klass.__module__}.{klass.__name__}"
        # Prefer the JSON-native form when the object offers one; otherwise
        # fall back to its binary stream, base64-encoded.
        if callable(getattr(value, "to_json_value", None)):
            body = {"type": qualified_name, "value": value.to_json_value()}
        else:
            buffer = BytesIO()
            value.to_stream(buffer)
            body = {
                "type": qualified_name,
                "payload_b64": base64.b64encode(buffer.getvalue()).decode("ascii"),
            }
        return {"$icacheable": body}

    if isinstance(value, dict):
        encoded = {key: _encode_param_value(item) for key, item in value.items()}
        # A plain dict whose only key is a reserved marker key would be read
        # back as that marker, so it is escaped with "$literal".
        looks_like_marker = len(encoded) == 1 and next(iter(encoded)) in RESERVED_KEYS
        if looks_like_marker:
            return {"$literal": encoded}
        return encoded

    if isinstance(value, list):
        return [_encode_param_value(element) for element in value]

    # Everything else (None, bool, int, str, ...) passes through unchanged.
    return value

74 

75 

76def _decode_param_value(obj: Any, literal_mode: bool = False) -> Any: 

77 """Recursively decode a JSON value to Python parameter value.""" 

78 # In literal mode, never treat dicts as markers 

79 if literal_mode: 

80 if isinstance(obj, dict): 

81 return { 

82 k: _decode_param_value(v, literal_mode=True) for k, v in obj.items() 

83 } 

84 if isinstance(obj, list): 

85 return [_decode_param_value(item, literal_mode=True) for item in obj] 

86 return obj 

87 

88 # Single-key dict with reserved key -> marker or escape 

89 if isinstance(obj, dict): 

90 if len(obj) == 1: 

91 (key, val) = next(iter(obj.items())) 

92 if key == "$ref": 

93 return ref(val) 

94 if key == "$cel": 

95 return cel(val) 

96 if key == "$decimal": 

97 return Decimal(val) 

98 if key == "$tuple": 

99 return tuple(_decode_param_value(item) for item in val) 

100 if key == "$literal": 

101 return _decode_param_value(val, literal_mode=True) 

102 if key == "$icacheable": 

103 return _decode_icacheable(val) 

104 # Multi-key or non-reserved: recursive decode 

105 return {k: _decode_param_value(v) for k, v in obj.items()} 

106 

107 if isinstance(obj, list): 

108 return [_decode_param_value(item) for item in obj] 

109 

110 return obj 

111 

112 

113def _decode_icacheable(obj: dict) -> Any: 

114 """Decode $icacheable object to ICacheable instance.""" 

115 if not isinstance(obj, dict): 

116 raise ValueError("$icacheable value must be an object") 

117 type_name = obj.get("type") 

118 if not type_name or not isinstance(type_name, str): 

119 raise ValueError("$icacheable must have non-empty string 'type'") 

120 if "payload_b64" in obj and "value" in obj: 

121 raise ValueError( 

122 "$icacheable must have exactly one of 'payload_b64' or 'value'" 

123 ) 

124 if "payload_b64" not in obj and "value" not in obj: 

125 raise ValueError("$icacheable must have 'payload_b64' or 'value'") 

126 

127 module_path, class_name = type_name.rsplit(".", 1) 

128 try: 

129 module = importlib.import_module(module_path) 

130 cls = getattr(module, class_name) 

131 except (ImportError, AttributeError) as e: 

132 raise ValueError( 

133 f"$icacheable type '{type_name}' could not be imported: {e}" 

134 ) from e 

135 

136 if "value" in obj: 

137 if not hasattr(cls, "from_json_value"): 

138 raise ValueError( 

139 f"$icacheable type '{type_name}' has 'value' but no from_json_value method" 

140 ) 

141 return cls.from_json_value(obj["value"]) 

142 

143 # payload_b64 

144 try: 

145 payload = base64.b64decode(obj["payload_b64"]) 

146 except Exception as e: 

147 raise ValueError(f"$icacheable payload_b64 is invalid base64: {e}") from e 

148 stream = BytesIO(payload) 

149 try: 

150 return cls.from_stream(stream) 

151 except Exception as e: 

152 raise ValueError( 

153 f"$icacheable from_stream failed for '{type_name}': {e}" 

154 ) from e 

155 

156 

157def _encode_params(params: dict[str, Any]) -> dict[str, Any]: 

158 """Encode params dict with sorted keys for determinism.""" 

159 return dict(sorted((k, _encode_param_value(v)) for k, v in params.items())) 

160 

161 

162def _decode_params(obj: dict) -> dict[str, Any]: 

163 """Decode params dict.""" 

164 return {k: _decode_param_value(v) for k, v in obj.items()} 

165 

166 

def _encode_vertex(vertex: Node | SubGraphNode) -> dict:
    """Encode one vertex as a JSON object tagged with its ``kind``.

    A ``Node`` becomes a ``"node"`` object with ``op_name``, ``params`` and
    sorted ``deps``; ``cache`` is written only when it is falsy (as ``False``).
    Anything else is encoded as a ``"subgraph"`` object with ``params``,
    sorted ``deps``, the encoded inner ``graph`` and its ``output``.
    """
    encoded_params = _encode_params(vertex.params)
    sorted_deps = sorted(vertex.deps)

    if not isinstance(vertex, Node):
        # SubGraphNode
        return {
            "kind": "subgraph",
            "params": encoded_params,
            "deps": sorted_deps,
            "graph": _encode_graph(vertex.graph),
            "output": vertex.output,
        }

    encoded: dict = {
        "kind": "node",
        "op_name": vertex.op_name,
        "params": encoded_params,
        "deps": sorted_deps,
    }
    # The decoder defaults a missing "cache" to True, so only the
    # non-default is emitted.
    if not vertex.cache:
        encoded["cache"] = False
    return encoded

187 

188 

def _decode_vertex(
    obj: dict, legacy_kind_inference: bool = False
) -> Node | SubGraphNode:
    """Build a ``Node`` or ``SubGraphNode`` from its JSON object.

    The object is validated before anything is constructed. When ``kind`` is
    absent and ``legacy_kind_inference`` is true, the kind is inferred:
    ``op_name`` without ``graph`` means a node, ``graph`` together with
    ``output`` means a subgraph.

    Raises:
        ValueError: If ``obj`` is not a dict, the kind is missing, cannot be
            inferred or is unsupported, or the kind-specific validation fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("Vertex must be an object")

    kind = obj.get("kind")
    if kind is None:
        if not legacy_kind_inference:
            raise ValueError("Vertex must have 'kind'")
        has_graph = "graph" in obj
        if "op_name" in obj and not has_graph:
            kind = "node"
        elif has_graph and "output" in obj:
            kind = "subgraph"
        else:
            raise ValueError(
                "Vertex has no 'kind' and cannot infer from op_name/graph/output"
            )
    elif kind not in ("node", "subgraph"):
        raise ValueError(f"Vertex has unsupported kind: {kind!r}")

    if kind == "node":
        _validate_node(obj, expected_kind=kind)
        op_name = obj["op_name"].strip()
        node_params = _decode_params(obj["params"])
        return Node(
            op_name=op_name,
            params=node_params,
            deps=list(obj["deps"]),
            cache=obj.get("cache", True),
        )

    # Only "subgraph" can reach this point.
    _validate_subgraph(obj, legacy_kind_inference)
    subgraph_params = _decode_params(obj["params"])
    subgraph_deps = list(obj["deps"])
    inner_graph = _decode_graph(obj["graph"], legacy_kind_inference)
    return SubGraphNode(
        params=subgraph_params,
        deps=subgraph_deps,
        graph=inner_graph,
        output=obj["output"],
    )

228 

229 

230def _validate_node(obj: dict, expected_kind: str | None = None) -> None: 

231 """Validate node object before construction.""" 

232 kind = expected_kind if expected_kind is not None else obj.get("kind") 

233 if kind != "node": 

234 raise ValueError("Node must have kind 'node'") 

235 op_name = obj.get("op_name") 

236 if not isinstance(op_name, str): 

237 raise ValueError("Node must have string 'op_name'") 

238 if not op_name.strip(): 

239 raise ValueError("Node op_name cannot be empty") 

240 if "params" not in obj or not isinstance(obj["params"], dict): 

241 raise ValueError("Node must have 'params' object") 

242 if "deps" not in obj or not isinstance(obj["deps"], list): 

243 raise ValueError("Node must have 'deps' array") 

244 for i, dep in enumerate(obj["deps"]): 

245 if not isinstance(dep, str): 

246 raise ValueError(f"Node deps[{i}] must be string, got {type(dep).__name__}") 

247 cache_val = obj.get("cache") 

248 if cache_val is not None and not isinstance(cache_val, bool): 

249 raise ValueError("Node 'cache' must be boolean when present") 

250 

251 

252def _validate_subgraph(obj: dict, legacy_kind_inference: bool = False) -> None: 

253 """Validate subgraph object before construction.""" 

254 kind = obj.get("kind") 

255 if not legacy_kind_inference and kind != "subgraph": 

256 raise ValueError("SubGraphNode must have kind 'subgraph'") 

257 if "params" not in obj or not isinstance(obj["params"], dict): 

258 raise ValueError("SubGraphNode must have 'params' object") 

259 if "deps" not in obj or not isinstance(obj["deps"], list): 

260 raise ValueError("SubGraphNode must have 'deps' array") 

261 for i, dep in enumerate(obj["deps"]): 

262 if not isinstance(dep, str): 

263 raise ValueError( 

264 f"SubGraphNode deps[{i}] must be string, got {type(dep).__name__}" 

265 ) 

266 if "graph" not in obj or not isinstance(obj["graph"], dict): 

267 raise ValueError("SubGraphNode must have 'graph' object") 

268 output = obj.get("output") 

269 if not isinstance(output, str): 

270 raise ValueError("SubGraphNode must have string 'output'") 

271 if output not in obj["graph"]: 

272 raise ValueError( 

273 f"SubGraphNode output '{output}' must be key in graph. " 

274 f"Graph keys: {list(obj['graph'].keys())}" 

275 ) 

276 for node_id, vertex_obj in obj["graph"].items(): 

277 _validate_vertex_for_kind(vertex_obj, node_id, legacy_kind_inference) 

278 

279 

280def _validate_vertex_for_kind( 

281 vertex_obj: Any, node_id: str, legacy_kind_inference: bool = False 

282) -> None: 

283 """Validate a vertex object has valid kind and structure.""" 

284 if not isinstance(vertex_obj, dict): 

285 raise ValueError(f"Vertex '{node_id}' must be an object") 

286 kind = vertex_obj.get("kind") 

287 if kind is None and legacy_kind_inference: 

288 if "op_name" in vertex_obj and "graph" not in vertex_obj: 

289 kind = "node" 

290 elif "graph" in vertex_obj and "output" in vertex_obj: 

291 kind = "subgraph" 

292 else: 

293 raise ValueError( 

294 f"Vertex '{node_id}' has no 'kind' and cannot infer from op_name/graph/output" 

295 ) 

296 if kind == "node": 

297 _validate_node(vertex_obj, expected_kind="node") 

298 elif kind == "subgraph": 

299 _validate_subgraph(vertex_obj, legacy_kind_inference) 

300 else: 

301 raise ValueError(f"Vertex '{node_id}' has unsupported kind: {kind!r}") 

302 

303 

def _encode_graph(graph: Graph) -> dict:
    """Encode every vertex of a graph into a JSON object ordered by node id."""
    entries = [(node_id, _encode_vertex(vertex)) for node_id, vertex in graph.items()]
    entries.sort()
    return dict(entries)

307 

308 

def _decode_graph(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Decode a JSON object mapping node ids to vertex objects into a graph.

    Raises:
        ValueError: If ``obj`` is not a dict, or any vertex fails to decode.
    """
    if not isinstance(obj, dict):
        raise ValueError("Graph must be an object")
    return {
        node_id: _decode_vertex(vertex_obj, legacy_kind_inference)
        for node_id, vertex_obj in obj.items()
    }

317 

318 

def _validate_envelope(obj: dict) -> None:
    """Validate the top-level document envelope.

    The envelope must be a dict whose ``format`` equals ``FORMAT_ID``, whose
    ``version`` is an integer listed in ``SUPPORTED_VERSIONS``, and which has
    a ``graph`` entry that is a dict.

    Raises:
        ValueError: On the first requirement that is not met.
    """
    if not isinstance(obj, dict):
        raise ValueError("Document must be a JSON object")

    fmt = obj.get("format")
    if fmt != FORMAT_ID:
        raise ValueError(f"Document format must be '{FORMAT_ID}', got {fmt!r}")

    version = obj.get("version")
    # Check the type before the set lookup: an unhashable version (JSON array
    # or object) would make `in` raise TypeError instead of ValueError, and
    # `True` / `1.0` compare equal to 1 and would slip through.
    version_supported = (
        isinstance(version, int)
        and not isinstance(version, bool)
        and version in SUPPORTED_VERSIONS
    )
    if not version_supported:
        raise ValueError(
            f"Document version {version} is not supported. Supported: {sorted(SUPPORTED_VERSIONS)}"
        )

    if "graph" not in obj:
        raise ValueError("Document must have 'graph'")
    if not isinstance(obj["graph"], dict):
        raise ValueError("Document 'graph' must be an object")

335 

336 

def dump_graph_to_dict(graph: Graph) -> dict:
    """Wrap the encoded graph in the format/version envelope.

    The result has the keys ``format`` (``FORMAT_ID``), ``version`` (``1``)
    and ``graph`` (the graph encoded with sorted node ids).
    """
    envelope: dict = {"format": FORMAT_ID, "version": 1}
    envelope["graph"] = _encode_graph(graph)
    return envelope

344 

345 

def dump_graph(graph: Graph) -> str:
    """Serialize a graph to a JSON string with keys sorted at every level."""
    envelope = dump_graph_to_dict(graph)
    return json.dumps(envelope, sort_keys=True)

349 

350 

def load_graph_from_dict(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Validate an envelope dict and decode the graph it carries.

    Raises:
        ValueError: If the envelope or any vertex inside it is invalid.
    """
    _validate_envelope(obj)
    graph_obj = obj["graph"]
    return _decode_graph(graph_obj, legacy_kind_inference)

355 

356 

def load_graph(data: str | bytes, legacy_kind_inference: bool = False) -> Graph:
    """Parse a JSON document (text, or UTF-8 encoded bytes) into a graph."""
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    document = json.loads(text)
    return load_graph_from_dict(document, legacy_kind_inference)