Coverage for arrakis/channel.py: 90.6%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-16 15:43 -0700

1# Copyright (c) 2022, California Institute of Technology and contributors 

2# 

3# You should have received a copy of the licensing terms for this 

4# software included in the file "LICENSE" located in the top-level 

5# directory of this package. If you did not, you can view a copy at 

6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE 

7 

8"""Channel information.""" 

9 

10from __future__ import annotations 

11 

12import json 

13from dataclasses import asdict, dataclass 

14from functools import cached_property 

15from typing import TYPE_CHECKING 

16 

17import numpy 

18 

19if TYPE_CHECKING: 

20 import pyarrow 

21 

22 

@dataclass(frozen=True)
class Channel:
    """Metadata associated with a channel.

    Channels have the form {domain}:*.

    Two channels compare equal when their name, data type and sample
    rate match; each optional field (time, publisher, partition_id) is
    only compared when it is set on both sides. Hashing uses the three
    required fields only, so channels that compare equal always hash
    equal.

    Parameters
    ----------
    name : str
        The name associated with this channel.
    data_type : numpy.dtype
        The data type associated with this channel. The value is passed
        through ``numpy.dtype()`` on construction.
    sample_rate : float
        The sampling rate associated with this channel.
    time : int, optional
        The timestamp when this metadata became active.
    publisher : str, optional
        The publisher associated with this channel.
    partition_id : str, optional
        The partition ID associated with this channel.

    Raises
    ------
    ValueError
        If the name is not of the form {domain}:*.

    """

    name: str
    data_type: numpy.dtype
    sample_rate: float
    time: int | None = None
    publisher: str | None = None
    partition_id: str | None = None

    @property
    def dtype(self) -> numpy.dtype:
        """Alias for ``data_type``."""
        return self.data_type

    def __post_init__(self) -> None:
        self.validate()
        # cast to numpy dtype object, as raw types like numpy.float64 are
        # not numpy.dtype instances. The dataclass is frozen, so the
        # attribute has to be set through object.__setattr__.
        object.__setattr__(self, "data_type", numpy.dtype(self.data_type))

    def validate(self) -> None:
        """Check that the channel name is of the form {domain}:*.

        Raises
        ------
        ValueError
            If the name does not contain exactly one ':' separator.

        """
        components = self.name.split(":")
        if len(components) != 2:
            msg = "channel is malformed, needs to be in the form {domain}:*"
            raise ValueError(msg)

    def __repr__(self) -> str:
        return f"<{self.name}, {self.sample_rate} Hz, {self.data_type}>"

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other: object) -> bool:
        # defer to the other operand for unrelated types instead of
        # failing on a missing attribute
        if not isinstance(other, Channel):
            return NotImplemented

        # name, data type and sample rate are always required to match
        if not (
            self.name == other.name
            and self.dtype == other.dtype
            and self.sample_rate == other.sample_rate
        ):
            return False

        # optional fields match only if both are defined; publisher and
        # partition_id count as undefined when falsy (None or "")
        if self.time is not None and other.time is not None:
            if self.time != other.time:
                return False
        if self.publisher and other.publisher:
            if self.publisher != other.publisher:
                return False
        if self.partition_id and other.partition_id:
            if self.partition_id != other.partition_id:
                return False

        return True

    def __hash__(self) -> int:
        # hash only the fields that __eq__ always compares, so that two
        # channels differing only in unset optional fields (which compare
        # equal) also hash equal
        return hash((self.name, self.data_type, self.sample_rate))

    @cached_property
    def domain(self) -> str:
        """The domain associated with this channel."""
        return self.name.split(":", 1)[0]

    def to_json(self, time: int | None = None) -> str:
        """Serialize channel metadata to JSON.

        Parameters
        ----------
        time : int, optional
            If specified, the timestamp when this metadata became active.
            It replaces the channel's own ``time`` in the output.

        Returns
        -------
        str
            The JSON-serialized channel.

        """
        # generate dict from dataclass and adjust fields
        # to be JSON compatible: the data type is stored
        # by name, and the timestamp is updated if passed in.
        obj = asdict(self)
        obj["data_type"] = numpy.dtype(self.data_type).name
        if time is not None:
            obj["time"] = time
        return json.dumps(obj)

    @classmethod
    def from_json(cls, payload: str) -> Channel:
        """Create a Channel from its JSON representation.

        Parameters
        ----------
        payload : str
            The JSON-serialized channel.

        Returns
        -------
        Channel
            The newly created channel.

        """
        obj = json.loads(payload)
        obj["data_type"] = numpy.dtype(obj["data_type"])
        return cls(**obj)

    @classmethod
    def from_field(cls, field: pyarrow.Field) -> Channel:
        """Create a Channel from Arrow Flight field metadata.

        The name is taken from the field name, the data type from the
        field's list value type and the sample rate from the ``b"rate"``
        metadata entry. The optional fields are left unset.

        Parameters
        ----------
        field : pyarrow.Field
            The channel field containing relevant metadata.

        Returns
        -------
        Channel
            The newly created channel.

        """
        data_type = numpy.dtype(_list_dtype_to_str(field.type))
        sample_rate = float(field.metadata[b"rate"].decode())
        return cls(field.name, data_type, sample_rate)

153 

154 

def _list_dtype_to_str(dtype: pyarrow.ListType) -> str:
    """Name the element type of an Arrow list type.

    The name is the string form of the list's value type, except that
    Arrow's "float" and "double" are reported as "float32" and
    "float64" so that names agree across different languages.

    Parameters
    ----------
    dtype : pyarrow.ListType
        The list data type to inspect.

    Returns
    -------
    str
        A string representation of the list's inner data type.

    """
    # Arrow spellings that are replaced by explicit-width names
    renamed = {"float": "float32", "double": "float64"}
    arrow_name = str(dtype.value_type)
    return renamed.get(arrow_name, arrow_name)

174 if inner_dtype == "float": 

175 return "float32" 

176 if inner_dtype == "double": 

177 return "float64" 

178 return inner_dtype