Coverage for arrakis/publish.py: 24.2%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-16 15:43 -0700

1# Copyright (c) 2022, California Institute of Technology and contributors 

2# 

3# You should have received a copy of the licensing terms for this 

4# software included in the file "LICENSE" located in the top-level 

5# directory of this package. If you did not, you can view a copy at 

6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE 

7 

8"""Publisher API.""" 

9 

10from __future__ import annotations 

11 

12import contextlib 

13import logging 

14from typing import TYPE_CHECKING, Literal 

15 

16import pyarrow 

17from pyarrow.flight import connect 

18 

19from . import constants 

20from .client import Client 

21from .flight import ( 

22 MultiEndpointStream, 

23 RequestType, 

24 RequestValidator, 

25 create_descriptor, 

26 parse_url, 

27) 

28 

29try: 

30 from confluent_kafka import Producer 

31except ImportError: 

32 HAS_KAFKA = False 

33else: 

34 HAS_KAFKA = True 

35 

36if TYPE_CHECKING: 

37 from collections.abc import Iterable 

38 from datetime import timedelta 

39 

40 from .block import SeriesBlock 

41 from .channel import Channel 

42 

43 

44logger = logging.getLogger("arrakis") 

45 

46 

def serialize_batch(batch: pyarrow.RecordBatch):
    """Encode a single record batch as an Arrow IPC stream.

    The batch is written, together with its schema, into an in-memory
    output stream using the Arrow IPC streaming format.

    Parameters
    ----------
    batch : pyarrow.RecordBatch
        The batch to serialize.

    Returns
    -------
    pyarrow.Buffer
        The serialized, bytes-like buffer holding the IPC stream, as
        returned by ``pyarrow.BufferOutputStream.getvalue()``.

    """
    buffer_stream = pyarrow.BufferOutputStream()
    stream_writer = pyarrow.ipc.new_stream(buffer_stream, batch.schema)
    # leaving the context closes the writer, which finalizes the stream
    with stream_writer:
        stream_writer.write_batch(batch)
    return buffer_stream.getvalue()

65 

66 

def channel_to_dtype_name(channel: Channel) -> str:
    """Return the ``name`` of the data type attached to ``channel``.

    The channel must carry a data type; this is enforced with an
    assertion, so an ``AssertionError`` is raised when
    ``channel.data_type`` is ``None``.
    """
    dtype = channel.data_type
    assert dtype is not None
    return dtype.name

71 

72 

class Publisher:
    """Publisher to publish timeseries to Arrakis service.

    Typical usage is to call :meth:`register` and then use the instance
    as a context manager, calling :meth:`publish` inside the ``with``
    block.

    Parameters
    ----------
    publisher_id : str
        Publisher ID string.
    url : str, optional
        Initial Flight URL to connect to. The value is passed through
        ``parse_url`` before being stored.

    Raises
    ------
    ImportError
        If confluent-kafka is not installed.

    """

    def __init__(self, publisher_id: str, url: str | None = None):
        if not HAS_KAFKA:
            msg = (
                "Publishing requires confluent-kafka to be installed. "
                "This is provided by the 'publish' extra or it can be "
                "installed manually through pip or conda."
            )
            raise ImportError(msg)

        self.publisher_id = publisher_id
        self.initial_url = parse_url(url)

        # channels known for this publisher, keyed by channel name;
        # populated by register()
        self.channels: dict[str, Channel] = {}

        # The producer is created in enter(); until then it is None so
        # that publish() and close() can detect the uninitialized state
        # instead of failing with AttributeError.
        self._producer: Producer | None = None
        # channel name -> partition ID; populated by register()
        self._partitions: dict[str, str] = {}
        self._registered = False
        self._validator = RequestValidator()

    def register(self) -> Publisher:
        """Look up this publisher's channels and their partitions.

        Queries the service at the initial URL for all channels
        associated with the publisher ID and records the partition ID
        of each one.

        Returns
        -------
        Publisher
            This instance, to allow call chaining.

        Raises
        ------
        ValueError
            If no channels are found for the publisher ID, or if a
            channel has no partition ID.

        """
        assert not self._registered, "has already registered"

        self.channels = {
            channel.name: channel
            for channel in Client(self.initial_url).find(publisher=self.publisher_id)
        }
        if not self.channels:
            msg = f"unknown publisher ID '{self.publisher_id}'."
            raise ValueError(msg)

        # extract the channel partition map
        self._partitions = {}
        for channel in self.channels.values():
            if not channel.partition_id:
                msg = f"could not determine partition_id for channel {channel}."
                raise ValueError(msg)
            self._partitions[channel.name] = channel.partition_id

        self._registered = True

        return self

    def enter(self):
        """Fetch connection properties and create the Kafka producer.

        Raises
        ------
        RuntimeError
            If :meth:`register` has not been called yet.

        """
        if not self._registered:
            msg = "must register publisher interface before publishing."
            raise RuntimeError(msg)

        # get connection properties
        descriptor = create_descriptor(
            RequestType.Publish,
            publisher_id=self.publisher_id,
            validator=self._validator,
        )
        properties: dict[str, str] = {}
        with connect(self.initial_url) as client:
            flight_info = client.get_flight_info(descriptor)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                for data in stream.unpack():
                    kv_pairs = data["properties"]
                    properties.update(dict(kv_pairs))

        # set up producer; properties received from the server are
        # unpacked last, so they override the defaults given here
        self._producer = Producer(
            {
                "message.max.bytes": 10_000_000,  # 10 MB
                "enable.idempotence": True,
                **properties,
            }
        )

    def __enter__(self) -> Publisher:
        self.enter()
        return self

    def publish(
        self,
        block: SeriesBlock,
        timeout: timedelta = constants.DEFAULT_TIMEOUT,
    ) -> None:
        """Publish timeseries data

        Parameters
        ----------
        block : SeriesBlock
            A data block with all channels to publish.
        timeout : timedelta, optional
            The maximum time to wait to publish before timing out.
            Defaults to ``constants.DEFAULT_TIMEOUT``.
            NOTE: this value is accepted for interface compatibility
            but is not currently applied to the producer calls below.

        Raises
        ------
        RuntimeError
            If the producer has not been set up via :meth:`enter` (or
            the context manager).
        ValueError
            If the block contains a channel that is not registered for
            this publisher, or whose metadata differs from the
            registered channel.

        """
        # Compare against None explicitly rather than relying on
        # truthiness: the Kafka producer defines __len__ (its outstanding
        # queue length), so an initialized producer with an empty queue
        # would otherwise be treated as "not initialized".
        if self._producer is None:
            msg = (
                "publication interface not initialized, "
                "please use context manager when publishing."
            )
            raise RuntimeError(msg)

        for name, channel in block.channels.items():
            # .get() so that an unregistered channel name is reported
            # with the same ValueError as mismatched channel metadata,
            # rather than surfacing as a bare KeyError
            if channel != self.channels.get(name):
                msg = f"invalid channel for this publisher: {channel}"
                raise ValueError(msg)

        # FIXME: updating partitions should only be allowed for
        # special blessed publishers, that are currently not using
        # this interface, so we're disabling this functionality for
        # the time being, until we have a better way to manage it.
        #
        # # check for new metadata changes
        # changed = set(block.channels.values()) - set(self.channels.values())
        # # exchange to transfer metadata and get new/updated partition IDs
        # if changed:
        #     self._update_partitions(changed)

        # publish one record batch per partition, as produced by
        # block.to_row_batches() from the channel -> partition map
        for partition_id, batch in block.to_row_batches(self._partitions):
            topic = f"arrakis-{partition_id}"
            logger.debug("publishing to topic %s: %s", topic, batch)
            self._producer.produce(topic=topic, value=serialize_batch(batch))
            # wait for delivery of this batch before producing the next
            self._producer.flush()

    def _update_partitions(self, channels: Iterable[Channel]) -> None:
        """Exchange channel metadata for new/updated partition IDs.

        Sends one row per channel (name, data type, sample rate and a
        null partition ID) to the service and stores the partition IDs
        that come back in the internal channel -> partition map.

        Parameters
        ----------
        channels : Iterable[Channel]
            The channels to request partition assignments for.

        """
        # set up flight
        assert self._registered, "has not registered yet"
        # The channels are traversed once per column below; materialize
        # them so that one-shot iterables (e.g. generators) are not
        # exhausted after the first column.
        channels = list(channels)
        descriptor = create_descriptor(
            RequestType.Partition,
            publisher_id=self.publisher_id,
            validator=self._validator,
        )
        # FIXME: should we not get FlightInfo first?
        with connect(self.initial_url) as client:
            writer, reader = client.do_exchange(descriptor)

            # send over list of channels to map new/updated partitions for
            dtypes = [channel_to_dtype_name(channel) for channel in channels]
            schema = pyarrow.schema(
                [
                    pyarrow.field("channel", pyarrow.string(), nullable=False),
                    pyarrow.field("data_type", pyarrow.string(), nullable=False),
                    pyarrow.field("sample_rate", pyarrow.int32(), nullable=False),
                    pyarrow.field("partition_id", pyarrow.string()),
                ]
            )
            batch = pyarrow.RecordBatch.from_arrays(
                [
                    pyarrow.array(
                        [str(channel) for channel in channels],
                        type=schema.field("channel").type,
                    ),
                    pyarrow.array(dtypes, type=schema.field("data_type").type),
                    pyarrow.array(
                        [channel.sample_rate for channel in channels],
                        type=schema.field("sample_rate").type,
                    ),
                    # partition IDs are sent as nulls; the response
                    # carries the assigned values
                    pyarrow.array(
                        [None for _ in channels],
                        type=schema.field("partition_id").type,
                    ),
                ],
                schema=schema,
            )

            # send over the partitions
            writer.begin(schema)
            writer.write_batch(batch)
            writer.done_writing()

            # get back the partition IDs and update
            partitions = reader.read_all().to_pydict()
            for channel, id_ in zip(partitions["channel"], partitions["partition_id"]):
                self._partitions[channel] = id_

    def close(self) -> None:
        """Flush any pending messages from the Kafka producer.

        This is best-effort: errors raised while flushing are
        suppressed, and nothing is done if the producer was never
        created.
        """
        logger.info("closing kafka producer...")
        if self._producer is None:
            return
        # deliberately best-effort so that close() never raises
        with contextlib.suppress(Exception):
            self._producer.flush()

    def __exit__(self, *exc) -> Literal[False]:
        self.close()
        # never suppress an exception raised inside the with block
        return False