Coverage for zanj/serializing.py: 93%
58 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-28 16:46 -0600
1from __future__ import annotations
3import json
4import sys
5from dataclasses import dataclass
6from typing import IO, Any, Callable, Iterable, Sequence
8import numpy as np
9from muutils.json_serialize.array import arr_metadata
10from muutils.json_serialize.json_serialize import ( # JsonSerializer,
11 DEFAULT_HANDLERS,
12 ObjectPath,
13 SerializerHandler,
14)
15from muutils.json_serialize.util import (
16 JSONdict,
17 JSONitem,
18 MonoTuple,
19 _FORMAT_KEY,
20 _REF_KEY,
21)
23from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre
# keyword arguments forwarded to `@dataclass(...)` below: `kw_only=True` is only
# passed when the running interpreter is Python 3.10 or newer, otherwise empty
KW_ONLY_KWARGS: dict = (
    {"kw_only": True} if sys.version_info >= (3, 10) else {}
)
29# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
30# for some reason pylint complains about kwargs to ZANJSerializerHandler
def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    # Parameters:
     - `data : list[JSONdict]`
        the records that will be written as jsonl

    # Returns:
     - `dict` with keys:
        - `"data[0]"`: the first record, or `None` if `data` is empty
        - `"len(data)"`: the number of records
        - `"columns"`: for every key appearing in any record (except `_FORMAT_KEY`),
          a dict with `"types"` (names of the value types seen for that key) and
          `"len"` (number of records that contain the key)

    columns and their types are listed in order of first appearance, so the
    result is the same across runs for the same input
    """
    # `dict.fromkeys` keeps first-seen order; a `set` of strings iterates in a
    # hash-randomized order, which made the stored metadata vary between runs
    all_cols: dict[str, None] = dict.fromkeys(col for item in data for col in item)

    columns: dict[str, dict] = {}
    for col in all_cols:
        if col == _FORMAT_KEY:
            continue
        values: list = [item[col] for item in data if col in item]
        columns[col] = {
            "types": list(dict.fromkeys(type(v).__name__ for v in values)),
            "len": len(values),
        }

    return {
        # empty input previously raised IndexError here
        "data[0]": data[0] if data else None,
        "len(data)": len(data),
        "columns": columns,
    }
def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the binary stream `fp` in numpy's `.npy` format

    `data` is first passed through `np.asanyarray`, and pickling is disabled
    (`allow_pickle=False`). `self` is not used; it is accepted so that the
    signature matches the `(zanj, file, data)` store-function convention.
    """
    as_array: np.ndarray = np.asanyarray(data)
    # TODO: the `ty` checker reports that `<module 'numpy.lib'>` has no attribute
    # `format` (rule `unresolved-attribute` is enabled by default)
    np.lib.format.write_array(  # ty: ignore[unresolved-attribute]
        fp,
        as_array,
        allow_pickle=False,
    )
def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write every element of `data` to the binary stream `fp`, one json document per line

    each element is rendered with `json.dumps`, encoded as utf-8, and followed
    by a newline. `self` is not used; it is accepted so that the signature
    matches the `(zanj, file, data)` store-function convention.
    """
    line_end: bytes = "\n".encode("utf-8")
    for record in data:
        encoded: bytes = json.dumps(record).encode("utf-8")
        fp.write(encoded)
        fp.write(line_end)
# lookup from external item type to the function that writes such an item;
# every function is called as `func(zanj, binary_file, data)`
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = dict(
    npy=store_npy,
    jsonl=store_jsonl,
)
@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    `check` and `serialize_func` are declared here to take the ZANJ instance
    (`_ZANJ_pre`) as their first argument, so they can read its configuration
    (e.g. `external_array_threshold` / `external_list_threshold`).
    """

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str  (not redeclared here; presumably inherited from SerializerHandler)
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (zanj, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (zanj, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"  (not redeclared here; presumably inherited)
def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    the (possibly converted) data is registered in `jser._externals` under
    `"<path parts joined by '/'>.<item_type>"`, and a json reference carrying
    that path plus metadata is returned in its place

    # Parameters:
     - `jser: ZANJ`
     - `data: Any`
        a `numpy.ndarray` or `torch.Tensor` when `item_type == "npy"`; a list,
        tuple, other iterable, or `pandas.DataFrame` when `item_type` starts
        with `"jsonl"`
     - `path: ObjectPath`
        must be a tuple
     - `item_type: ExternalItemType`
     - `_format: str`
        stored under `_FORMAT_KEY` in the returned reference

    # Returns:
     - `JSONitem`
        json data with reference

    # Raises:
     - `AssertionError` if `path` is not a tuple
     - `ValueError` if the archive path is already taken, the same object path is
       already stored under another item type, or another external item is
       stored beneath this path
     - `TypeError` if `data` is not of a type supported by `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    if archive_path in jser._externals:
        raise ValueError(f"external path {archive_path} already exists!")

    # compare whole path components: a bare `startswith(joined_path)` wrongly
    # rejected siblings sharing a string prefix, e.g. `data/1` vs `data/10.jsonl`
    # or `model/weight` vs `model/weight_decay.npy`. the empty (root) path
    # contains every other path.
    nested_prefix: str = f"{joined_path}/" if path else ""
    for existing in jser._externals:
        if (
            existing.startswith(nested_prefix)
            # same object path already stored with a different item type
            or existing.rsplit(".", 1)[0] == joined_path
        ):
            raise ValueError(
                f"external path {joined_path} is a prefix of another path!"
            )

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # check type by its string form (presumably so torch need not be imported)
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif data_type_str != "<class 'numpy.ndarray'>":
            # if not a numpy array, except
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # NOTE(review): metadata is computed from the original `data` (for a
        # tensor, the tensor itself), not from the converted `data_new` -- confirm
        # this is intended before changing it, loaders may rely on it
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # per-column metadata only for non-empty lists of records: `all([])` is
        # True, and `jsonl_metadata` reads `data[0]`
        if data_new and all(isinstance(item, dict) for item in data_new):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output
def _external_handler(
    uid: str,
    desc: str,
    item_type: ExternalItemType,
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool],
) -> ZANJSerializerHandler:
    """build a `ZANJSerializerHandler` from package `zanj` that stores matches externally

    the handler's `uid` is also passed as the `_format` of the external reference
    """
    return ZANJSerializerHandler(
        check=check,
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type=item_type, _format=uid
        ),
        uid=uid,
        source_pckg="zanj",
        desc=desc,
    )


# ZANJ-specific handlers (tried first) followed by the generic muutils handlers
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = (
    _external_handler(
        uid="numpy.ndarray:external",
        desc="external numpy array",
        item_type="npy",
        check=lambda self, obj, path: (
            isinstance(obj, np.ndarray)
            and obj.size >= self.external_array_threshold
        ),
    ),
    _external_handler(
        uid="torch.Tensor:external",
        desc="external torch tensor",
        item_type="npy",
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'torch.Tensor'>"
            and int(obj.nelement()) >= self.external_array_threshold
        ),
    ),
    _external_handler(
        uid="list:external",
        desc="external list",
        item_type="jsonl",
        check=lambda self, obj, path: (
            isinstance(obj, list) and len(obj) >= self.external_list_threshold
        ),
    ),
    _external_handler(
        uid="tuple:external",
        desc="external tuple",
        item_type="jsonl",
        check=lambda self, obj, path: (
            isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
        ),
    ),
    _external_handler(
        uid="pandas.DataFrame:external",
        desc="external pandas DataFrame",
        item_type="jsonl",
        check=lambda self, obj, path: (
            any(
                "pandas.core.frame.DataFrame" in str(t)
                for t in obj.__class__.__mro__
            )
            and len(obj) >= self.external_list_threshold
        ),
    ),
    # disabled: fallback torch.nn.Module serialization
    # ZANJSerializerHandler(
    #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
    #     in [str(t) for t in obj.__class__.__mro__],
    #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
    #         self, obj, path,
    #     ),
    #     uid="torch.nn.Module",
    #     source_pckg="zanj",
    #     desc="fallback torch serialization",
    # ),
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)
252# the complaint above is:
253# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]