Coverage for src / invariant / store / codec.py: 100.00%
155 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 10:21 +0100
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 10:21 +0100
1"""Serialization codec for the full is_cacheable universe.
3Handles native types (int, str, Decimal, bool, None, dict, list, tuple)
4and ICacheable domain types (e.g. Polynomial) uniformly.
5"""
7import importlib
8from decimal import Decimal
9from io import BytesIO
10from typing import Any
12from invariant.cacheable import is_cacheable
13from invariant.protocol import ICacheable
def serialize(value: Any) -> bytes:
    """Encode a cacheable value into its tagged binary form.

    The value is first checked with ``is_cacheable``; accepted values are
    written into an in-memory buffer by ``_serialize_value``, which handles
    None, bool, int, str, Decimal, dict, list, tuple and ICacheable
    instances (the latter through their ``to_stream()`` method).

    Args:
        value: The value to serialize. Must be cacheable.

    Returns:
        The serialized bytes, including type information.

    Raises:
        TypeError: If ``is_cacheable(value)`` is false.
    """
    if is_cacheable(value):
        buffer = BytesIO()
        _serialize_value(value, buffer)
        return buffer.getvalue()

    raise TypeError(f"Value is not cacheable: {type(value)}")
def deserialize(data: bytes) -> Any:
    """Decode bytes produced by ``serialize()`` back into a value.

    The bytes are wrapped in an in-memory stream and a single value is read
    from its start by ``_deserialize_value``.

    Args:
        data: Serialized bytes from serialize().

    Returns:
        The deserialized value.

    Raises:
        ValueError: If the data is truncated or carries an unknown type tag.
    """
    return _deserialize_value(BytesIO(data))
57def _serialize_value(value: Any, stream: BytesIO) -> None:
58 """Internal recursive serialization."""
59 # None
60 if value is None:
61 stream.write(b"none")
62 return
64 # bool (check before int since bool is subclass of int)
65 if isinstance(value, bool):
66 stream.write(b"bool")
67 stream.write(b"\x01" if value else b"\x00")
68 return
70 # int
71 if isinstance(value, int):
72 stream.write(b"int_")
73 # Use 8-byte signed big-endian
74 stream.write(value.to_bytes(8, byteorder="big", signed=True))
75 return
77 # str
78 if isinstance(value, str):
79 stream.write(b"str_")
80 data = value.encode("utf-8")
81 stream.write(len(data).to_bytes(8, byteorder="big", signed=False))
82 stream.write(data)
83 return
85 # Decimal
86 if isinstance(value, Decimal):
87 stream.write(b"decm")
88 # Store as canonical string
89 data = str(value).encode("utf-8")
90 stream.write(len(data).to_bytes(8, byteorder="big", signed=False))
91 stream.write(data)
92 return
94 # dict
95 if isinstance(value, dict):
96 stream.write(b"dict")
97 # Write length
98 stream.write(len(value).to_bytes(8, byteorder="big", signed=False))
99 # Write key-value pairs (sorted by key for determinism)
100 for key, val in sorted(value.items()):
101 # Key (must be str)
102 key_data = key.encode("utf-8")
103 stream.write(len(key_data).to_bytes(8, byteorder="big", signed=False))
104 stream.write(key_data)
105 # Value (recursive)
106 _serialize_value(val, stream)
107 return
109 # list
110 if isinstance(value, list):
111 stream.write(b"list")
112 # Write length
113 stream.write(len(value).to_bytes(8, byteorder="big", signed=False))
114 # Write elements (recursive)
115 for item in value:
116 _serialize_value(item, stream)
117 return
119 # tuple
120 if isinstance(value, tuple):
121 stream.write(b"tupl")
122 # Write length
123 stream.write(len(value).to_bytes(8, byteorder="big", signed=False))
124 # Write elements (recursive)
125 for item in value:
126 _serialize_value(item, stream)
127 return
129 # ICacheable domain types
130 if isinstance(value, ICacheable):
131 stream.write(b"icac")
132 # Store type information
133 type_name = f"{value.__class__.__module__}.{value.__class__.__name__}"
134 type_name_bytes = type_name.encode("utf-8")
135 stream.write(len(type_name_bytes).to_bytes(4, byteorder="big", signed=False))
136 stream.write(type_name_bytes)
137 # Use existing to_stream() method
138 value.to_stream(stream)
139 return
141 # Should never reach here if is_cacheable() is correct
142 raise TypeError(f"Unsupported type for serialization: {type(value)}")
145def _deserialize_value(stream: BytesIO) -> Any:
146 """Internal recursive deserialization."""
147 # Read type tag (4 bytes)
148 tag = stream.read(4)
149 if len(tag) != 4:
150 raise ValueError("Invalid serialization format: truncated type tag")
152 # None
153 if tag == b"none":
154 return None
156 # bool
157 if tag == b"bool":
158 byte = stream.read(1)
159 if len(byte) != 1:
160 raise ValueError("Invalid serialization format: truncated bool")
161 return byte == b"\x01"
163 # int
164 if tag == b"int_":
165 data = stream.read(8)
166 if len(data) != 8:
167 raise ValueError("Invalid serialization format: truncated int")
168 return int.from_bytes(data, byteorder="big", signed=True)
170 # str
171 if tag == b"str_":
172 length_data = stream.read(8)
173 if len(length_data) != 8:
174 raise ValueError("Invalid serialization format: truncated str length")
175 length = int.from_bytes(length_data, byteorder="big", signed=False)
176 data = stream.read(length)
177 if len(data) != length:
178 raise ValueError("Invalid serialization format: truncated str data")
179 return data.decode("utf-8")
181 # Decimal
182 if tag == b"decm":
183 length_data = stream.read(8)
184 if len(length_data) != 8:
185 raise ValueError("Invalid serialization format: truncated decimal length")
186 length = int.from_bytes(length_data, byteorder="big", signed=False)
187 data = stream.read(length)
188 if len(data) != length:
189 raise ValueError("Invalid serialization format: truncated decimal data")
190 return Decimal(data.decode("utf-8"))
192 # dict
193 if tag == b"dict":
194 length_data = stream.read(8)
195 if len(length_data) != 8:
196 raise ValueError("Invalid serialization format: truncated dict length")
197 length = int.from_bytes(length_data, byteorder="big", signed=False)
198 result = {}
199 for _ in range(length):
200 # Read key
201 key_length_data = stream.read(8)
202 if len(key_length_data) != 8:
203 raise ValueError(
204 "Invalid serialization format: truncated dict key length"
205 )
206 key_length = int.from_bytes(key_length_data, byteorder="big", signed=False)
207 key_data = stream.read(key_length)
208 if len(key_data) != key_length:
209 raise ValueError("Invalid serialization format: truncated dict key")
210 key = key_data.decode("utf-8")
211 # Read value (recursive)
212 value = _deserialize_value(stream)
213 result[key] = value
214 return result
216 # list
217 if tag == b"list":
218 length_data = stream.read(8)
219 if len(length_data) != 8:
220 raise ValueError("Invalid serialization format: truncated list length")
221 length = int.from_bytes(length_data, byteorder="big", signed=False)
222 result = []
223 for _ in range(length):
224 item = _deserialize_value(stream)
225 result.append(item)
226 return result
228 # tuple
229 if tag == b"tupl":
230 length_data = stream.read(8)
231 if len(length_data) != 8:
232 raise ValueError("Invalid serialization format: truncated tuple length")
233 length = int.from_bytes(length_data, byteorder="big", signed=False)
234 result = []
235 for _ in range(length):
236 item = _deserialize_value(stream)
237 result.append(item)
238 return tuple(result)
240 # ICacheable domain types
241 if tag == b"icac":
242 # Read type name length
243 type_name_len_data = stream.read(4)
244 if len(type_name_len_data) != 4:
245 raise ValueError(
246 "Invalid serialization format: truncated ICacheable type name length"
247 )
248 type_name_len = int.from_bytes(
249 type_name_len_data, byteorder="big", signed=False
250 )
251 # Read type name
252 type_name_bytes = stream.read(type_name_len)
253 if len(type_name_bytes) != type_name_len:
254 raise ValueError(
255 "Invalid serialization format: truncated ICacheable type name"
256 )
257 type_name = type_name_bytes.decode("utf-8")
258 # Import the class
259 module_path, class_name = type_name.rsplit(".", 1)
260 module = importlib.import_module(module_path)
261 cls = getattr(module, class_name)
262 # Deserialize using from_stream()
263 return cls.from_stream(stream)
265 raise ValueError(f"Unknown type tag: {tag!r}")