Coverage for arrakis/channel.py: 90.6%
64 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-16 15:43 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-16 15:43 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE
8"""Channel information."""
10from __future__ import annotations
12import json
13from dataclasses import asdict, dataclass
14from functools import cached_property
15from typing import TYPE_CHECKING
17import numpy
19if TYPE_CHECKING:
20 import pyarrow
@dataclass(frozen=True)
class Channel:
    """Metadata associated with a channel.

    Channels have the form {domain}:*.

    Parameters
    ----------
    name : str
        The name associated with this channel.
    data_type : numpy.dtype
        The data type associated with this channel. Anything accepted by
        ``numpy.dtype`` (e.g. ``numpy.float64`` or ``"float64"``) is
        converted to a ``numpy.dtype`` instance on construction.
    sample_rate : float
        The sampling rate associated with this channel.
    time : int, optional
        The timestamp when this metadata became active.
    publisher : str, optional
        The publisher associated with this channel.
    partition_id : str, optional
        The partition ID associated with this channel.

    Raises
    ------
    ValueError
        If ``name`` does not contain exactly one ``:`` separator.

    Notes
    -----
    Equality requires ``name``, ``data_type`` and ``sample_rate`` to match.
    ``time``, ``publisher`` and ``partition_id`` are only compared when both
    channels define them. Hashing therefore uses only the required fields,
    so channels that compare equal also hash equal.

    """

    name: str
    data_type: numpy.dtype
    sample_rate: float
    time: int | None = None
    publisher: str | None = None
    partition_id: str | None = None

    @property
    def dtype(self) -> numpy.dtype:
        """Alias for :attr:`data_type`."""
        return self.data_type

    def __post_init__(self) -> None:
        self.validate()
        # cast to a numpy dtype object, as raw scalar types like numpy.float64
        # are not dtype instances; object.__setattr__ is required because the
        # dataclass is frozen.
        object.__setattr__(self, "data_type", numpy.dtype(self.data_type))

    def validate(self) -> None:
        """Check that the channel name has the form {domain}:*.

        Raises
        ------
        ValueError
            If the name does not split into exactly two parts on ``:``.

        """
        components = self.name.split(":")
        if len(components) != 2:
            msg = "channel is malformed, needs to be in the form {domain}:*"
            raise ValueError(msg)

    def __repr__(self) -> str:
        return f"<{self.name}, {self.sample_rate} Hz, {self.data_type}>"

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other: object) -> bool:
        # Let Python try the reflected comparison (ultimately falling back to
        # identity) instead of raising AttributeError on non-Channel objects.
        if not isinstance(other, Channel):
            return NotImplemented

        # name, data type and sample rate are always required to match
        is_equal = (
            self.name == other.name
            and self.dtype == other.dtype
            and self.sample_rate == other.sample_rate
        )

        # optional fields match only if both are defined
        if self.time is not None and other.time is not None:
            is_equal &= self.time == other.time
        if self.publisher and other.publisher:
            is_equal &= self.publisher == other.publisher
        if self.partition_id and other.partition_id:
            is_equal &= self.partition_id == other.partition_id

        return is_equal

    def __hash__(self) -> int:
        # Only the fields that __eq__ always compares may contribute. The
        # optional fields act as wildcards in __eq__, so hashing them would let
        # equal channels hash differently. Defining this explicitly also stops
        # @dataclass(frozen=True) from generating an all-fields hash.
        return hash((self.name, self.dtype, self.sample_rate))

    @cached_property
    def domain(self) -> str:
        """The domain associated with this channel."""
        return self.name.split(":", 1)[0]

    def to_json(self, time: int | None = None) -> str:
        """Serialize channel metadata to JSON.

        Parameters
        ----------
        time : int, optional
            If specified, the timestamp when this metadata became active.
            It overrides the channel's own ``time`` in the output.

        Returns
        -------
        str
            The JSON-serialized channel. The data type is stored by its
            numpy dtype name.

        """
        # generate dict from dataclass and adjust fields
        # to be JSON compatible. In addition, store the
        # channel name, as well as updating the timestamp
        # if passed in.
        obj = asdict(self)
        obj["data_type"] = numpy.dtype(self.data_type).name
        if time is not None:
            obj["time"] = time
        return json.dumps(obj)

    @classmethod
    def from_json(cls, payload: str) -> Channel:
        """Create a Channel from its JSON representation.

        Parameters
        ----------
        payload : str
            The JSON-serialized channel.

        Returns
        -------
        Channel
            The newly created channel.

        """
        obj = json.loads(payload)
        obj["data_type"] = numpy.dtype(obj["data_type"])
        return cls(**obj)

    @classmethod
    def from_field(cls, field: pyarrow.Field) -> Channel:
        """Create a Channel from Arrow Flight field metadata.

        Parameters
        ----------
        field : pyarrow.Field
            The channel field containing relevant metadata. Its type is
            expected to be a list type, and its metadata must contain the
            sample rate under the ``b"rate"`` key.

        Returns
        -------
        Channel
            The newly created channel.

        """
        data_type = numpy.dtype(_list_dtype_to_str(field.type))
        sample_rate = float(field.metadata[b"rate"].decode())
        return cls(field.name, data_type, sample_rate)
# Arrow's names for floating-point types differ from the numpy-style names
# used across languages (e.g. str(pyarrow.float32()) is "float").
# Every other type name is passed through unchanged.
_ARROW_FLOAT_NAMES = {
    "halffloat": "float16",
    "float": "float32",
    "double": "float64",
}


def _list_dtype_to_str(dtype: pyarrow.ListType) -> str:
    """Return a string representation of the list's inner data type.

    Note that this does not always match the string representation
    of Arrow's internal data types, to match the behavior across
    different languages. Arrow's floating-point names ("halffloat",
    "float", "double") are translated to "float16", "float32" and
    "float64" so the result can be passed to ``numpy.dtype``.

    Parameters
    ----------
    dtype : pyarrow.ListType
        The list data type to inspect.

    Returns
    -------
    str
        A string representation of the list's inner data type.

    """
    inner_dtype = str(dtype.value_type)
    return _ARROW_FLOAT_NAMES.get(inner_dtype, inner_dtype)