Coverage for src / invariant / graph_serialization.py: 84.65%
215 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 10:21 +0100
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 10:21 +0100
"""Graph serialization: JSON wire format for Invariant graphs.

Encodes graphs (Node, SubGraphNode) and params (ref, cel, Decimal, tuple,
ICacheable) for storage and transmission. Distinct from artifact serialization
in store/codec.py.
"""
8import base64
9import importlib
10import json
11from decimal import Decimal
12from io import BytesIO
13from typing import Any
15from invariant.graph import Graph
16from invariant.node import Node, SubGraphNode
17from invariant.params import cel, ref
18from invariant.protocol import ICacheable
# Envelope "version" values that the loader accepts.
SUPPORTED_VERSIONS = {1}
# Expected value of the envelope's "format" field.
FORMAT_ID = "invariant-graph"

# Keys that act as markers when they are the only key of an encoded dict.
# A plain dict whose single key is one of these is wrapped in "$literal" on
# encode so that it is not mistaken for a marker.
RESERVED_KEYS = frozenset(
    {"$ref", "$cel", "$decimal", "$tuple", "$literal", "$icacheable"}
)
def _encode_param_value(value: Any) -> Any:
    """Convert one parameter value into its JSON-compatible wire form.

    Marker objects and non-JSON types become single-key dicts:
    ``ref`` -> ``{"$ref": dep}``, ``cel`` -> ``{"$cel": expr}``,
    ``Decimal`` -> ``{"$decimal": str(value)}``, ``tuple`` ->
    ``{"$tuple": [...]}`` and ``ICacheable`` -> ``{"$icacheable": {...}}``.
    Dicts and lists are encoded element by element; every other value is
    returned unchanged.
    """
    if isinstance(value, ref):
        return {"$ref": value.dep}

    if isinstance(value, cel):
        return {"$cel": value.expr}

    if isinstance(value, Decimal):
        return {"$decimal": str(value)}

    if isinstance(value, tuple):
        return {"$tuple": list(map(_encode_param_value, value))}

    if isinstance(value, ICacheable):
        klass = value.__class__
        qualified = f"{klass.__module__}.{klass.__name__}"
        # Prefer the JSON-native form when the object offers one; otherwise
        # fall back to the binary stream, base64-encoded.
        if callable(getattr(value, "to_json_value", None)):
            body = {"type": qualified, "value": value.to_json_value()}
        else:
            buffer = BytesIO()
            value.to_stream(buffer)
            encoded_bytes = base64.b64encode(buffer.getvalue()).decode("ascii")
            body = {"type": qualified, "payload_b64": encoded_bytes}
        return {"$icacheable": body}

    if isinstance(value, dict):
        mapped = {}
        for key, item in value.items():
            mapped[key] = _encode_param_value(item)
        # A plain single-key dict using a reserved key would be read back as a
        # marker, so it is escaped with "$literal".
        # NOTE(review): the values inside `mapped` are already marker-encoded,
        # while _decode_param_value decodes "$literal" contents without
        # interpreting any markers, so e.g. {"$ref": Decimal("1")} would not
        # round-trip - confirm whether that is intended.
        if len(mapped) == 1 and next(iter(mapped)) in RESERVED_KEYS:
            return {"$literal": mapped}
        return mapped

    if isinstance(value, list):
        return list(map(_encode_param_value, value))

    # Everything else is passed through as-is.
    return value
def _decode_param_value(obj: Any, literal_mode: bool = False) -> Any:
    """Recursively decode a JSON value to a Python parameter value.

    Args:
        obj: A value produced by ``json.loads`` (dict, list or scalar).
        literal_mode: When true, dicts are never interpreted as markers;
            dicts and lists are only walked recursively and everything else
            is returned unchanged.

    Returns:
        The decoded value. Single-key dicts keyed by ``$ref``, ``$cel``,
        ``$decimal``, ``$tuple``, ``$literal`` or ``$icacheable`` are turned
        into the corresponding Python object; other dicts and lists are
        decoded element by element.

    Raises:
        ValueError: If a ``$tuple`` payload is not a JSON array or a
            ``$decimal`` payload cannot be converted to ``Decimal``.
    """
    if isinstance(obj, list):
        return [_decode_param_value(item, literal_mode=literal_mode) for item in obj]
    if not isinstance(obj, dict):
        return obj

    # In literal mode, never treat dicts as markers.
    if literal_mode:
        return {k: _decode_param_value(v, literal_mode=True) for k, v in obj.items()}

    # Single-key dict with a reserved key -> marker or escape.
    if len(obj) == 1:
        ((key, val),) = obj.items()
        if key == "$ref":
            return ref(val)
        if key == "$cel":
            return cel(val)
        if key == "$decimal":
            try:
                return Decimal(val)
            except (ArithmeticError, TypeError, ValueError) as e:
                # decimal.InvalidOperation derives from ArithmeticError;
                # report it as ValueError like other malformed-document errors.
                raise ValueError(
                    f"$decimal value {val!r} is not a valid decimal: {e}"
                ) from e
        if key == "$tuple":
            # Without this check a string payload would be split into
            # characters and a dict payload into its keys.
            if not isinstance(val, list):
                raise ValueError(
                    f"$tuple value must be an array, got {type(val).__name__}"
                )
            return tuple(_decode_param_value(item) for item in val)
        if key == "$literal":
            return _decode_param_value(val, literal_mode=True)
        if key == "$icacheable":
            return _decode_icacheable(val)

    # Multi-key or non-reserved: recursive decode.
    return {k: _decode_param_value(v) for k, v in obj.items()}
def _decode_icacheable(obj: dict) -> Any:
    """Decode a ``$icacheable`` object to an instance of the named class.

    Args:
        obj: Mapping with a module-qualified ``"type"`` and exactly one of
            ``"value"`` (passed to ``cls.from_json_value``) or
            ``"payload_b64"`` (base64-decoded and passed to
            ``cls.from_stream`` as a ``BytesIO``).

    Returns:
        Whatever ``cls.from_json_value`` / ``cls.from_stream`` returns.

    Raises:
        ValueError: If the object is malformed, the type name is not of the
            form ``package.module.ClassName``, the type cannot be imported,
            the payload is not valid base64, or ``from_stream`` fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("$icacheable value must be an object")
    type_name = obj.get("type")
    if not type_name or not isinstance(type_name, str):
        raise ValueError("$icacheable must have non-empty string 'type'")

    has_payload = "payload_b64" in obj
    has_value = "value" in obj
    if has_payload and has_value:
        raise ValueError(
            "$icacheable must have exactly one of 'payload_b64' or 'value'"
        )
    if not (has_payload or has_value):
        raise ValueError("$icacheable must have 'payload_b64' or 'value'")

    # rpartition never fails on a dot-less name, so the malformed case can be
    # reported clearly instead of surfacing as an unpacking error.
    module_path, dot, class_name = type_name.rpartition(".")
    if not dot or not module_path or not class_name:
        raise ValueError(
            f"$icacheable type '{type_name}' must be a module-qualified class name"
        )

    # SECURITY NOTE(review): the module to import is chosen by the document.
    # Importing runs that module's top-level code, so documents from untrusted
    # sources should not be loaded without restricting the allowed types.
    try:
        module = importlib.import_module(module_path)
        cls = getattr(module, class_name)
    except (ImportError, AttributeError, TypeError, ValueError) as e:
        # TypeError / ValueError cover relative or otherwise invalid names.
        raise ValueError(
            f"$icacheable type '{type_name}' could not be imported: {e}"
        ) from e

    if has_value:
        if not hasattr(cls, "from_json_value"):
            raise ValueError(
                f"$icacheable type '{type_name}' has 'value' but no from_json_value method"
            )
        return cls.from_json_value(obj["value"])

    # payload_b64
    try:
        payload = base64.b64decode(obj["payload_b64"])
    except Exception as e:
        raise ValueError(f"$icacheable payload_b64 is invalid base64: {e}") from e
    stream = BytesIO(payload)
    try:
        return cls.from_stream(stream)
    except Exception as e:
        raise ValueError(
            f"$icacheable from_stream failed for '{type_name}': {e}"
        ) from e
def _encode_params(params: dict[str, Any]) -> dict[str, Any]:
    """Encode every param value and return the result ordered by key."""
    # Encode in the dict's own order first, then order the pairs by key so the
    # output is deterministic.
    pairs = [(name, _encode_param_value(raw)) for name, raw in params.items()]
    pairs.sort(key=lambda pair: pair[0])
    return dict(pairs)
def _decode_params(obj: dict) -> dict[str, Any]:
    """Decode each value of an encoded params object, keeping its keys."""
    decoded: dict[str, Any] = {}
    for name, raw in obj.items():
        decoded[name] = _decode_param_value(raw)
    return decoded
def _encode_vertex(vertex: Node | SubGraphNode) -> dict:
    """Encode one vertex as a JSON object tagged with its ``kind``.

    A ``Node`` yields ``kind``/``op_name``/``params``/``deps`` and, only when
    its ``cache`` attribute is falsy, ``"cache": False``. Anything else is
    encoded as a subgraph with ``params``, ``deps``, the nested ``graph`` and
    the ``output`` id.
    """
    if not isinstance(vertex, Node):
        # SubGraphNode
        return {
            "kind": "subgraph",
            "params": _encode_params(vertex.params),
            "deps": sorted(vertex.deps),
            "graph": _encode_graph(vertex.graph),
            "output": vertex.output,
        }

    encoded: dict = {"kind": "node", "op_name": vertex.op_name}
    encoded["params"] = _encode_params(vertex.params)
    encoded["deps"] = sorted(vertex.deps)
    # Caching is the default on decode, so only the opt-out is written.
    if not vertex.cache:
        encoded["cache"] = False
    return encoded
def _decode_vertex(
    obj: dict, legacy_kind_inference: bool = False
) -> Node | SubGraphNode:
    """Build a Node or SubGraphNode from its JSON object.

    The object is validated before anything is constructed. When ``kind`` is
    absent and ``legacy_kind_inference`` is true, the kind is inferred from
    the presence of ``op_name`` / ``graph`` / ``output``.

    Raises:
        ValueError: If the object is not a dict, the kind is missing,
            unsupported or cannot be inferred, or validation fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("Vertex must be an object")

    kind = obj.get("kind")
    if kind is None:
        if not legacy_kind_inference:
            raise ValueError("Vertex must have 'kind'")
        has_graph = "graph" in obj
        if "op_name" in obj and not has_graph:
            kind = "node"
        elif has_graph and "output" in obj:
            kind = "subgraph"
        else:
            raise ValueError(
                "Vertex has no 'kind' and cannot infer from op_name/graph/output"
            )

    if kind not in ("node", "subgraph"):
        raise ValueError(f"Vertex has unsupported kind: {kind!r}")

    if kind == "node":
        _validate_node(obj, expected_kind=kind)
        return Node(
            op_name=obj["op_name"].strip(),
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            cache=obj.get("cache", True),
        )

    # Only "subgraph" can reach this point.
    _validate_subgraph(obj, legacy_kind_inference)
    return SubGraphNode(
        params=_decode_params(obj["params"]),
        deps=list(obj["deps"]),
        graph=_decode_graph(obj["graph"], legacy_kind_inference),
        output=obj["output"],
    )
230def _validate_node(obj: dict, expected_kind: str | None = None) -> None:
231 """Validate node object before construction."""
232 kind = expected_kind if expected_kind is not None else obj.get("kind")
233 if kind != "node":
234 raise ValueError("Node must have kind 'node'")
235 op_name = obj.get("op_name")
236 if not isinstance(op_name, str):
237 raise ValueError("Node must have string 'op_name'")
238 if not op_name.strip():
239 raise ValueError("Node op_name cannot be empty")
240 if "params" not in obj or not isinstance(obj["params"], dict):
241 raise ValueError("Node must have 'params' object")
242 if "deps" not in obj or not isinstance(obj["deps"], list):
243 raise ValueError("Node must have 'deps' array")
244 for i, dep in enumerate(obj["deps"]):
245 if not isinstance(dep, str):
246 raise ValueError(f"Node deps[{i}] must be string, got {type(dep).__name__}")
247 cache_val = obj.get("cache")
248 if cache_val is not None and not isinstance(cache_val, bool):
249 raise ValueError("Node 'cache' must be boolean when present")
def _validate_subgraph(obj: dict, legacy_kind_inference: bool = False) -> None:
    """Check a subgraph object, including every vertex of its nested graph.

    The ``kind`` field is only enforced when ``legacy_kind_inference`` is
    false. Raises ``ValueError`` on the first problem found.
    """
    strict_kind = not legacy_kind_inference
    if strict_kind and obj.get("kind") != "subgraph":
        raise ValueError("SubGraphNode must have kind 'subgraph'")

    if not isinstance(obj.get("params"), dict):
        raise ValueError("SubGraphNode must have 'params' object")

    deps = obj.get("deps")
    if not isinstance(deps, list):
        raise ValueError("SubGraphNode must have 'deps' array")
    for index, dep in enumerate(deps):
        if isinstance(dep, str):
            continue
        raise ValueError(
            f"SubGraphNode deps[{index}] must be string, got {type(dep).__name__}"
        )

    inner = obj.get("graph")
    if not isinstance(inner, dict):
        raise ValueError("SubGraphNode must have 'graph' object")

    output = obj.get("output")
    if not isinstance(output, str):
        raise ValueError("SubGraphNode must have string 'output'")
    # The declared output has to name one of the nested vertices.
    if output not in inner:
        raise ValueError(
            f"SubGraphNode output '{output}' must be key in graph. "
            f"Graph keys: {list(inner.keys())}"
        )

    for node_id, vertex_obj in inner.items():
        _validate_vertex_for_kind(vertex_obj, node_id, legacy_kind_inference)
def _validate_vertex_for_kind(
    vertex_obj: Any, node_id: str, legacy_kind_inference: bool = False
) -> None:
    """Validate one graph vertex by dispatching on its (possibly inferred) kind.

    ``node_id`` is only used to label error messages. Raises ``ValueError``
    when the vertex is not a dict, its kind is unsupported or cannot be
    inferred, or the kind-specific validation fails.
    """
    if not isinstance(vertex_obj, dict):
        raise ValueError(f"Vertex '{node_id}' must be an object")

    kind = vertex_obj.get("kind")
    if kind is None and legacy_kind_inference:
        # Legacy documents carry no "kind"; infer it from the fields present.
        has_graph = "graph" in vertex_obj
        if "op_name" in vertex_obj and not has_graph:
            kind = "node"
        elif has_graph and "output" in vertex_obj:
            kind = "subgraph"
        else:
            raise ValueError(
                f"Vertex '{node_id}' has no 'kind' and cannot infer from op_name/graph/output"
            )

    if kind == "node":
        _validate_node(vertex_obj, expected_kind="node")
        return
    if kind == "subgraph":
        _validate_subgraph(vertex_obj, legacy_kind_inference)
        return
    raise ValueError(f"Vertex '{node_id}' has unsupported kind: {kind!r}")
def _encode_graph(graph: Graph) -> dict:
    """Encode every vertex of the graph and return them ordered by node id."""
    # Encode in the graph's own order, then order the pairs by node id.
    entries = [(node_id, _encode_vertex(vertex)) for node_id, vertex in graph.items()]
    entries.sort(key=lambda entry: entry[0])
    return dict(entries)
def _decode_graph(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Decode a JSON object mapping node ids to vertex objects.

    Raises ``ValueError`` when ``obj`` is not a dict or any vertex fails to
    decode.
    """
    if not isinstance(obj, dict):
        raise ValueError("Graph must be an object")
    decoded: Graph = {
        node_id: _decode_vertex(vertex_obj, legacy_kind_inference)
        for node_id, vertex_obj in obj.items()
    }
    return decoded
def _validate_envelope(obj: dict) -> None:
    """Validate the top-level envelope of a serialized graph document.

    Checks that the document is an object whose ``format`` equals
    ``FORMAT_ID``, whose ``version`` is an integer listed in
    ``SUPPORTED_VERSIONS``, and which carries a ``graph`` object.

    Raises:
        ValueError: On the first check that fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("Document must be a JSON object")
    fmt = obj.get("format")
    if fmt != FORMAT_ID:
        raise ValueError(f"Document format must be '{FORMAT_ID}', got {fmt!r}")
    version = obj.get("version")
    # Check the type before the set lookup: an unhashable value (list/dict)
    # would raise TypeError from `in`, and True / 1.0 compare equal to 1 and
    # would otherwise be accepted as version 1.
    is_integer = isinstance(version, int) and not isinstance(version, bool)
    if not is_integer or version not in SUPPORTED_VERSIONS:
        raise ValueError(
            f"Document version {version!r} is not supported. Supported: {sorted(SUPPORTED_VERSIONS)}"
        )
    if "graph" not in obj:
        raise ValueError("Document must have 'graph'")
    if not isinstance(obj["graph"], dict):
        raise ValueError("Document 'graph' must be an object")
def dump_graph_to_dict(graph: Graph) -> dict:
    """Wrap the encoded graph in the format/version envelope dict."""
    envelope: dict = {"format": FORMAT_ID, "version": 1}
    envelope["graph"] = _encode_graph(graph)
    return envelope
def dump_graph(graph: Graph) -> str:
    """Return the graph envelope as a JSON string with sorted object keys."""
    envelope = dump_graph_to_dict(graph)
    return json.dumps(envelope, sort_keys=True)
def load_graph_from_dict(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Validate an envelope dict and decode the graph it carries.

    Raises ``ValueError`` if the envelope or any vertex is invalid.
    """
    _validate_envelope(obj)
    graph_obj = obj["graph"]
    return _decode_graph(graph_obj, legacy_kind_inference)
def load_graph(data: str | bytes, legacy_kind_inference: bool = False) -> Graph:
    """Parse a JSON document (text or UTF-8 bytes) and load the graph from it."""
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    return load_graph_from_dict(json.loads(text), legacy_kind_inference)