Coverage for zanj/serializing.py: 93%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-28 16:46 -0600

1from __future__ import annotations 

2 

3import json 

4import sys 

5from dataclasses import dataclass 

6from typing import IO, Any, Callable, Iterable, Sequence 

7 

8import numpy as np 

9from muutils.json_serialize.array import arr_metadata 

10from muutils.json_serialize.json_serialize import ( # JsonSerializer, 

11 DEFAULT_HANDLERS, 

12 ObjectPath, 

13 SerializerHandler, 

14) 

15from muutils.json_serialize.util import ( 

16 JSONdict, 

17 JSONitem, 

18 MonoTuple, 

19 _FORMAT_KEY, 

20 _REF_KEY, 

21) 

22 

23from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre 

24 

# extra keyword arguments for `@dataclass`: the `kw_only` option only exists on
# Python 3.10+, so older interpreters get no extra options at all
KW_ONLY_KWARGS: dict = {"kw_only": True} if sys.version_info >= (3, 10) else {}

28 

29# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg 

30# for some reason pylint complains about kwargs to ZANJSerializerHandler 

31 

32 

def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    # Parameters:
     - `data: list[JSONdict]`
        the rows that will be written as JSON lines

    # Returns:
     - `dict` with keys:
        - `"data[0]"`: the first row, or `None` if `data` is empty
        - `"len(data)"`: the number of rows
        - `"columns"`: for every key that appears in any row (except `_FORMAT_KEY`),
          a dict with `"types"` (names of the value types seen for that key) and
          `"len"` (how many rows contain that key)

    Columns are listed in order of first appearance, and type names in order of
    first occurrence. This keeps the output deterministic across interpreter runs,
    which iterating a `set` of strings does not (string hashing is randomized).
    """
    # column -> {type name: None}; a dict is used as an insertion-ordered set
    col_types: dict[str, dict[str, None]] = {}
    # column -> number of rows containing that column
    col_counts: dict[str, int] = {}
    # single pass over every (row, key) pair instead of one full scan per column
    for item in data:
        for col, value in item.items():
            if col == _FORMAT_KEY:
                continue
            col_types.setdefault(col, {})[type(value).__name__] = None
            col_counts[col] = col_counts.get(col, 0) + 1

    return {
        # guard: indexing an empty list would raise IndexError
        "data[0]": data[0] if data else None,
        "len(data)": len(data),
        "columns": {
            col: {
                "types": list(types),
                "len": col_counts[col],
            }
            for col, types in col_types.items()
        },
    }

50 

51 

def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the binary file object `fp` in numpy's `.npy` format

    `data` is first passed through `np.asanyarray`. Pickling is disabled, so
    numpy refuses object arrays instead of embedding pickles in the file.
    `self` is unused; it keeps the `(zanj, fp, data)` signature shared by the
    entries of `EXTERNAL_STORE_FUNCS`.
    """
    # TODO: the `ty` checker reports `<module 'numpy.lib'>` has no attribute
    # `format` (rule `unresolved-attribute`), hence the inline ignore below
    as_array = np.asanyarray(data)
    np.lib.format.write_array(  # ty: ignore[unresolved-attribute]
        fp,
        as_array,
        allow_pickle=False,
    )

61 

62 

def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write `data` to the binary file object `fp` as JSON Lines

    Each element is encoded with `json.dumps` and written as one UTF-8 line
    terminated by a newline. `self` is unused; it keeps the `(zanj, fp, data)`
    signature shared with the other store functions.
    """
    for record in data:
        line: str = json.dumps(record) + "\n"
        fp.write(line.encode("utf-8"))

69 

70 

# external item type -> function writing an item of that type to an open binary
# file; every function takes `(zanj, fp, data)`
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = dict(
    npy=store_npy,
    jsonl=store_jsonl,
)

77 

78 

@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    `check` and `serialize_func` receive the ZANJ instance (`_ZANJ_pre`) as their
    first argument. This lets them read ZANJ settings such as
    `external_array_threshold`, as the default handlers in this module do.

    `uid` and `desc` are not redeclared here, but they are passed to the
    constructor by `DEFAULT_SERIALIZER_HANDLERS_ZANJ`, so they presumably come
    from the `SerializerHandler` base -- NOTE(review): confirm against muutils.
    """

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (zanj instance, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (zanj instance, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"

93 

94 

def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array, torch tensor, sequence, or pandas DataFrame externally in a ZANJ object

    Nothing is written to disk here. The (possibly converted) data is registered
    in `jser._externals` under an archive path derived from `path`, and a small
    JSON reference plus metadata is returned to stand in for it.

    # Parameters:
     - `jser: ZANJ`
        the ZANJ instance doing the serialization
     - `data: Any`
        the object to store:
        - for `item_type == "npy"`: a `numpy.ndarray` (subclasses such as
          `numpy.memmap` included, masked arrays excluded) or a `torch.Tensor`
        - for `item_type` starting with `"jsonl"`: a `pandas.DataFrame` or an
          iterable of items
     - `path: ObjectPath`
        location of `data` within the object being serialized; must be a tuple
     - `item_type: ExternalItemType`
        external format; also used as the archive file extension
     - `_format: str`
        value stored under `_FORMAT_KEY` in the returned reference

    # Returns:
     - `JSONitem`
        json data with reference: `_FORMAT_KEY`, `_REF_KEY` (the archive path),
        and metadata about the stored data

    # Raises:
     - `AssertionError` if `path` is not a tuple
     - `ValueError` if the archive path is already registered, or if `path`
       collides with a registered external (one is nested under the other, or
       the same object path was stored with another item type)
     - `TypeError` if `data` is not a supported type for `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure it's unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    if archive_path in jser._externals:
        raise ValueError(f"external path {archive_path} already exists!")
    # An existing entry collides with this path only if it is nested under it
    # ("a/b" vs "a/b/c.npy") or is the same object path with another item type
    # ("a/b" vs "a/b.jsonl"). Requiring the "/" or "." separator avoids false
    # positives between siblings such as "w1" and "w10.npy", which a bare
    # `startswith(joined_path)` wrongly rejected. The root path "" contains
    # every other path, so it collides with any existing entry.
    collision_prefixes: tuple = (
        (f"{joined_path}/", f"{joined_path}.") if joined_path else ("",)
    )
    if any(p.startswith(collision_prefixes) for p in jser._externals.keys()):
        raise ValueError(f"external path {joined_path} is a prefix of another path!")

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # compare the type name so torch never needs to be imported here
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach from the autograd graph and move to cpu before converting
            data_new = data.detach().cpu().numpy()
        elif not isinstance(data, np.ndarray) or isinstance(data, np.ma.MaskedArray):
            # ndarray subclasses (e.g. np.memmap) are accepted, matching the
            # `isinstance` check of the numpy handler that routes them here;
            # masked arrays keep raising since the .npy format has no mask
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # NOTE(review): metadata is computed from the original `data`, not
        # `data_new`; for torch tensors this relies on `arr_metadata` handling
        # tensors directly -- confirm this is intended
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            # serialize each element under its own sub-path (presumably so that
            # nested values receive distinct paths of their own)
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # `all` over an empty list is True, and `jsonl_metadata` reads `data[0]`,
        # so skip the row metadata for empty data instead of crashing
        if data_new and all(isinstance(item, dict) for item in data_new):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output

174 

175 

def _external_handler(
    uid: str,
    desc: str,
    item_type: ExternalItemType,
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool],
) -> ZANJSerializerHandler:
    """build a `ZANJSerializerHandler` that stores matching objects externally

    The handler's `uid` doubles as the `_format` string recorded in the
    serialized reference, and `item_type` selects the external file format.
    """
    return ZANJSerializerHandler(
        check=check,
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type=item_type, _format=uid
        ),
        uid=uid,
        source_pckg="zanj",
        desc=desc,
    )


# ZANJ-specific handlers, placed before the generic muutils handlers (presumably
# so they are tried first). Each moves objects whose size reaches a threshold
# read from the ZANJ instance into an external file.
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        _external_handler(
            uid="numpy.ndarray:external",
            desc="external numpy array",
            item_type="npy",
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
        ),
        _external_handler(
            uid="torch.Tensor:external",
            desc="external torch tensor",
            item_type="npy",
            # matched by type name so torch need not be imported
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
        ),
        _external_handler(
            uid="list:external",
            desc="external list",
            item_type="jsonl",
            check=lambda self, obj, path: (
                isinstance(obj, list) and len(obj) >= self.external_list_threshold
            ),
        ),
        _external_handler(
            uid="tuple:external",
            desc="external tuple",
            item_type="jsonl",
            check=lambda self, obj, path: (
                isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
            ),
        ),
        _external_handler(
            uid="pandas.DataFrame:external",
            desc="external pandas DataFrame",
            item_type="jsonl",
            # matched via the MRO so pandas need not be imported
            check=lambda self, obj, path: (
                any(
                    "pandas.core.frame.DataFrame" in str(t)
                    for t in obj.__class__.__mro__
                )
                and len(obj) >= self.external_list_threshold
            ),
        ),
        # disabled fallback serialization for torch modules:
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

# the ignore above silences this mypy complaint:
# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]"  [arg-type]