Coverage for src / invariant / store / codec.py: 100.00%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 10:21 +0100

1"""Serialization codec for the full is_cacheable universe. 

2 

3Handles native types (int, str, Decimal, bool, None, dict, list, tuple) 

4and ICacheable domain types (e.g. Polynomial) uniformly. 

5""" 

6 

7import importlib 

8from decimal import Decimal 

9from io import BytesIO 

10from typing import Any 

11 

12from invariant.cacheable import is_cacheable 

13from invariant.protocol import ICacheable 

14 

15 

def serialize(value: Any) -> bytes:
    """Encode a cacheable value as tagged bytes.

    Every value admitted by ``is_cacheable`` is supported: the natives
    ``None``, ``bool``, ``int``, ``str`` and ``Decimal``; the containers
    ``dict``, ``list`` and ``tuple`` (encoded recursively); and ICacheable
    domain objects, which contribute their own payload via ``to_stream()``.

    Args:
        value: The object to encode. ``is_cacheable(value)`` must be true.

    Returns:
        The encoded bytes; each encoded value starts with a type tag so it
        can be decoded by ``deserialize()``.

    Raises:
        TypeError: If ``value`` is not cacheable.
    """
    if is_cacheable(value):
        buffer = BytesIO()
        _serialize_value(value, buffer)
        return buffer.getvalue()
    raise TypeError(f"Value is not cacheable: {type(value)}")

39 

40 

def deserialize(data: bytes) -> Any:
    """Deserialize bytes produced by serialize() back into a value.

    Args:
        data: Bytes returned by serialize(); must contain exactly one
            encoded top-level value and nothing after it.

    Returns:
        The reconstructed value.

    Raises:
        ValueError: If data format is invalid (e.g. truncated data or an
            unknown type tag), or if extra bytes follow the encoded value.
    """
    stream = BytesIO(data)
    value = _deserialize_value(stream)
    # serialize() writes exactly one top-level value, so leftover bytes mean
    # the input was corrupted or concatenated; reject it instead of silently
    # returning a value decoded from only a prefix of the input.
    if stream.read(1):
        raise ValueError("Invalid serialization format: trailing data after value")
    return value

55 

56 

# The fixed-width "int_" encoding covers exactly the signed 64-bit range; ints
# outside it are written with the length-prefixed "bint" tag instead.
_INT64_MIN = -(1 << 63)
_INT64_MAX = (1 << 63) - 1


def _write_u64(stream: BytesIO, number: int) -> None:
    """Write ``number`` as an 8-byte unsigned big-endian integer."""
    stream.write(number.to_bytes(8, byteorder="big", signed=False))


def _write_blob(stream: BytesIO, data: bytes) -> None:
    """Write ``data`` preceded by its length as an 8-byte unsigned prefix."""
    _write_u64(stream, len(data))
    stream.write(data)


def _read_exact(stream: BytesIO, size: int, what: str) -> bytes:
    """Read exactly ``size`` bytes, raising ValueError naming ``what`` if short."""
    data = stream.read(size)
    if len(data) != size:
        raise ValueError(f"Invalid serialization format: truncated {what}")
    return data


def _read_u64(stream: BytesIO, what: str) -> int:
    """Read an 8-byte unsigned big-endian integer (a length or item count)."""
    return int.from_bytes(_read_exact(stream, 8, what), byteorder="big", signed=False)


def _read_blob(stream: BytesIO, length_what: str, data_what: str) -> bytes:
    """Read an 8-byte length prefix followed by that many bytes."""
    return _read_exact(stream, _read_u64(stream, length_what), data_what)


def _serialize_value(value: Any, stream: BytesIO) -> None:
    """Write one tagged value (recursively for containers) to ``stream``.

    Wire layout per 4-byte tag (all integers big-endian):
      none        tag only
      bool        tag + 1 byte (0x00 / 0x01)
      int_        tag + 8-byte signed two's complement, for ints in
                  [-2**63, 2**63)
      bint        tag + 8-byte length + signed two's complement bytes,
                  for ints outside that range
      str_, decm  tag + 8-byte length + UTF-8 text (Decimal uses str())
      dict        tag + 8-byte count + (8-byte key length, UTF-8 key,
                  value) per entry, sorted by key
      list, tupl  tag + 8-byte count + each element
      icac        tag + 4-byte length + "module.ClassName" + the
                  object's to_stream() payload

    Raises:
        TypeError: If ``value`` (or a nested value) has an unsupported type,
            or a dict contains a non-str key.
    """
    # NOTE(review): native isinstance checks run before the ICacheable check,
    # so an ICacheable that subclasses int/str/dict/list/tuple is encoded as
    # the plain native type and decodes as such - confirm this is intended.
    if value is None:
        stream.write(b"none")
        return

    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        stream.write(b"bool")
        stream.write(b"\x01" if value else b"\x00")
        return

    if isinstance(value, int):
        if _INT64_MIN <= value <= _INT64_MAX:
            # Fixed 8-byte form, kept for in-range ints so their encoding
            # (and anything derived from it) is unchanged.
            stream.write(b"int_")
            stream.write(value.to_bytes(8, byteorder="big", signed=True))
        else:
            # Python ints are unbounded and an 8-byte to_bytes() would raise
            # OverflowError, so wider values use a length-prefixed form.
            # bit_length() plus one sign bit, rounded up to whole bytes.
            width = value.bit_length() // 8 + 1
            stream.write(b"bint")
            _write_blob(stream, value.to_bytes(width, byteorder="big", signed=True))
        return

    if isinstance(value, str):
        stream.write(b"str_")
        _write_blob(stream, value.encode("utf-8"))
        return

    if isinstance(value, Decimal):
        stream.write(b"decm")
        # str() gives the scientific-string form, which Decimal() parses back
        # to the same coefficient, exponent and sign.
        _write_blob(stream, str(value).encode("utf-8"))
        return

    if isinstance(value, dict):
        # Validate keys up front so nothing is written for an invalid dict.
        for key in value:
            if not isinstance(key, str):
                raise TypeError(
                    f"Unsupported dict key type for serialization: {type(key)}"
                )
        stream.write(b"dict")
        _write_u64(stream, len(value))
        # Sorted keys make equal dicts encode identically regardless of
        # insertion order.
        for key in sorted(value):
            _write_blob(stream, key.encode("utf-8"))
            _serialize_value(value[key], stream)
        return

    if isinstance(value, (list, tuple)):
        stream.write(b"list" if isinstance(value, list) else b"tupl")
        _write_u64(stream, len(value))
        for item in value:
            _serialize_value(item, stream)
        return

    if isinstance(value, ICacheable):
        stream.write(b"icac")
        # NOTE(review): __name__ (not __qualname__) is recorded, so a class
        # nested in another class is written under a name that the decoder's
        # getattr(module, name) lookup cannot resolve.
        type_name = f"{value.__class__.__module__}.{value.__class__.__name__}"
        type_name_bytes = type_name.encode("utf-8")
        # This prefix is 4 bytes wide, unlike the 8-byte prefixes elsewhere.
        stream.write(len(type_name_bytes).to_bytes(4, byteorder="big", signed=False))
        stream.write(type_name_bytes)
        value.to_stream(stream)
        return

    # Should never reach here if is_cacheable() is correct.
    raise TypeError(f"Unsupported type for serialization: {type(value)}")


def _load_cacheable_class(type_name: str) -> Any:
    """Resolve "module.ClassName" to an object with from_stream(), or raise ValueError."""
    # NOTE(security): the module path comes from the serialized data and is
    # imported, which executes that module's top-level code. Only deserialize
    # data from a trusted source.
    module_path, sep, class_name = type_name.rpartition(".")
    if not sep or not module_path or not class_name or module_path.startswith("."):
        raise ValueError(
            f"Invalid serialization format: bad ICacheable type name {type_name!r}"
        )
    try:
        module = importlib.import_module(module_path)
    except ImportError as err:
        raise ValueError(
            f"Cannot import module for ICacheable type {type_name!r}"
        ) from err
    cls = getattr(module, class_name, None)
    if cls is None or not callable(getattr(cls, "from_stream", None)):
        raise ValueError(
            f"ICacheable type {type_name!r} not found or has no from_stream()"
        )
    return cls


def _deserialize_value(stream: BytesIO) -> Any:
    """Read one tagged value, as written by _serialize_value, from ``stream``.

    Raises:
        ValueError: On truncated data, an unknown tag, an invalid bool byte,
            malformed UTF-8 or decimal text, or an ICacheable type name that
            cannot be resolved to a class with from_stream().
    """
    tag = _read_exact(stream, 4, "type tag")

    if tag == b"none":
        return None

    if tag == b"bool":
        byte = _read_exact(stream, 1, "bool")
        if byte == b"\x01":
            return True
        if byte == b"\x00":
            return False
        raise ValueError(f"Invalid serialization format: bad bool byte {byte!r}")

    if tag == b"int_":
        return int.from_bytes(_read_exact(stream, 8, "int"), byteorder="big", signed=True)

    if tag == b"bint":
        data = _read_blob(stream, "big int length", "big int data")
        return int.from_bytes(data, byteorder="big", signed=True)

    # .decode("utf-8") failures raise UnicodeDecodeError, a ValueError subclass.
    if tag == b"str_":
        return _read_blob(stream, "str length", "str data").decode("utf-8")

    if tag == b"decm":
        text = _read_blob(stream, "decimal length", "decimal data").decode("utf-8")
        try:
            return Decimal(text)
        except ArithmeticError as err:
            # Decimal() signals malformed text via decimal.InvalidOperation,
            # which is an ArithmeticError rather than a ValueError.
            raise ValueError(
                f"Invalid serialization format: bad decimal {text!r}"
            ) from err

    if tag == b"dict":
        count = _read_u64(stream, "dict length")
        result = {}
        for _ in range(count):
            key = _read_blob(stream, "dict key length", "dict key").decode("utf-8")
            result[key] = _deserialize_value(stream)
        return result

    if tag in (b"list", b"tupl"):
        is_list = tag == b"list"
        count = _read_u64(stream, "list length" if is_list else "tuple length")
        items = [_deserialize_value(stream) for _ in range(count)]
        return items if is_list else tuple(items)

    if tag == b"icac":
        name_len = int.from_bytes(
            _read_exact(stream, 4, "ICacheable type name length"),
            byteorder="big",
            signed=False,
        )
        type_name = _read_exact(stream, name_len, "ICacheable type name").decode("utf-8")
        return _load_cacheable_class(type_name).from_stream(stream)

    raise ValueError(f"Unknown type tag: {tag!r}")