Coverage for arrakis/publish.py: 24.2%
99 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-16 15:43 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-16 15:43 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE
8"""Publisher API."""
10from __future__ import annotations
12import contextlib
13import logging
14from typing import TYPE_CHECKING, Literal
16import pyarrow
17from pyarrow.flight import connect
19from . import constants
20from .client import Client
21from .flight import (
22 MultiEndpointStream,
23 RequestType,
24 RequestValidator,
25 create_descriptor,
26 parse_url,
27)
29try:
30 from confluent_kafka import Producer
31except ImportError:
32 HAS_KAFKA = False
33else:
34 HAS_KAFKA = True
36if TYPE_CHECKING:
37 from collections.abc import Iterable
38 from datetime import timedelta
40 from .block import SeriesBlock
41 from .channel import Channel
44logger = logging.getLogger("arrakis")
def serialize_batch(batch: pyarrow.RecordBatch):
    """Encode a single record batch as an Arrow IPC stream.

    The batch is written, together with its schema, through an IPC
    stream writer into an in-memory output stream.

    Parameters
    ----------
    batch : pyarrow.RecordBatch
        The batch to serialize.

    Returns
    -------
    bytes-like
        The serialized buffer, i.e. whatever
        ``pyarrow.BufferOutputStream.getvalue()`` yields once the
        stream writer has been closed.

    """
    buffer_stream = pyarrow.BufferOutputStream()
    ipc_writer = pyarrow.ipc.new_stream(buffer_stream, batch.schema)
    # leaving the ``with`` block closes the writer, which finalizes the
    # stream before the buffer contents are collected
    with ipc_writer:
        ipc_writer.write_batch(batch)
    return buffer_stream.getvalue()
def channel_to_dtype_name(channel: Channel) -> str:
    """Return the name of a channel's data type.

    Parameters
    ----------
    channel : Channel
        The channel to inspect. Its ``data_type`` must be set.

    Returns
    -------
    str
        The ``name`` attribute of the channel's ``data_type``.

    """
    dtype = channel.data_type
    # a channel without a data type cannot be mapped to a name
    assert dtype is not None
    return dtype.name
class Publisher:
    """Publisher to publish timeseries to Arrakis service.

    A publisher must first be registered with :meth:`register` and is
    then used as a context manager around calls to :meth:`publish`;
    entering the context creates the Kafka producer and leaving it
    flushes the producer.

    Parameters
    ----------
    publisher_id : str
        Publisher ID string.
    url : str, optional
        Initial Flight URL to connect to. The value is passed through
        ``parse_url`` before being stored.

    Raises
    ------
    ImportError
        If confluent-kafka is not installed.

    """

    def __init__(self, publisher_id: str, url: str | None = None):
        if not HAS_KAFKA:
            msg = (
                "Publishing requires confluent-kafka to be installed. "
                "This is provided by the 'publish' extra or it can be "
                "installed manually through pip or conda."
            )
            raise ImportError(msg)

        self.publisher_id = publisher_id
        self.initial_url = parse_url(url)

        # channels owned by this publisher, keyed by name; filled in
        # by register()
        self.channels: dict[str, Channel] = {}

        # Both attributes get real initial values (rather than bare
        # annotations) so that state checks made before enter() or
        # register() see a defined value instead of raising
        # AttributeError.
        self._producer: Producer | None = None
        self._partitions: dict[str, str] = {}
        self._registered = False
        self._validator = RequestValidator()

    def register(self):
        """Look up this publisher's channels and their partitions.

        Queries the service for all channels associated with the
        publisher ID and records the partition ID of each one.

        Returns
        -------
        Publisher
            This instance, to allow call chaining.

        Raises
        ------
        ValueError
            If no channels are found for the publisher ID, or if a
            channel has no partition ID.

        """
        assert not self._registered, "has already registered"

        self.channels = {
            channel.name: channel
            for channel in Client(self.initial_url).find(publisher=self.publisher_id)
        }
        if not self.channels:
            msg = f"unknown publisher ID '{self.publisher_id}'."
            raise ValueError(msg)

        # extract the channel partition map
        self._partitions = {}
        for channel in self.channels.values():
            if not channel.partition_id:
                msg = f"could not determine partition_id for channel {channel}."
                raise ValueError(msg)
            self._partitions[channel.name] = channel.partition_id

        self._registered = True

        return self

    def enter(self):
        """Fetch the connection properties and create the Kafka producer.

        Raises
        ------
        RuntimeError
            If the publisher has not been registered yet.

        """
        if not self._registered:
            msg = "must register publisher interface before publishing."
            raise RuntimeError(msg)

        # get connection properties
        descriptor = create_descriptor(
            RequestType.Publish,
            publisher_id=self.publisher_id,
            validator=self._validator,
        )
        properties: dict[str, str] = {}
        with connect(self.initial_url) as client:
            flight_info = client.get_flight_info(descriptor)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                for data in stream.unpack():
                    kv_pairs = data["properties"]
                    properties.update(dict(kv_pairs))

        # set up producer; properties received from the service are
        # applied last and so override the defaults given here
        self._producer = Producer(
            {
                "message.max.bytes": 10_000_000,  # 10 MB
                "enable.idempotence": True,
                **properties,
            }
        )

    def __enter__(self) -> Publisher:
        self.enter()
        return self

    def publish(
        self,
        block: SeriesBlock,
        timeout: timedelta = constants.DEFAULT_TIMEOUT,
    ) -> None:
        """Publish timeseries data

        Parameters
        ----------
        block : SeriesBlock
            A data block with all channels to publish.
        timeout : timedelta, optional
            The maximum time to wait to publish before timing out.
            Default is ``constants.DEFAULT_TIMEOUT``.
            NOTE: this value is currently not consulted by this method.

        Raises
        ------
        RuntimeError
            If the producer has not been created, i.e. the publisher
            is not being used within its context manager.
        ValueError
            If the block contains a channel that is not registered for
            this publisher, or whose metadata differs from the
            registered channel.

        """
        # Compare against None explicitly: the truthiness of an
        # initialized producer is not a reliable "is it set up" signal.
        if self._producer is None:
            msg = (
                "publication interface not initialized, "
                "please use context manager when publishing."
            )
            raise RuntimeError(msg)

        for name, channel in block.channels.items():
            # channels unknown to this publisher are rejected the same
            # way as channels with mismatched metadata
            if name not in self.channels or channel != self.channels[name]:
                msg = f"invalid channel for this publisher: {channel}"
                raise ValueError(msg)

        # FIXME: updating partitions should only be allowed for
        # special blessed publishers, that are currently not using
        # this interface, so we're disabling this functionality for
        # the time being, until we have a better way to manage it.
        #
        # # check for new metadata changes
        # changed = set(block.channels.values()) - set(self.channels.values())
        # # exchange to transfer metadata and get new/updated partition IDs
        # if changed:
        #     self._update_partitions(changed)

        # split the block into one record batch per partition and
        # publish each batch to that partition's topic
        for partition_id, batch in block.to_row_batches(self._partitions):
            topic = f"arrakis-{partition_id}"
            logger.debug("publishing to topic %s: %s", topic, batch)
            self._producer.produce(topic=topic, value=serialize_batch(batch))
            # flush after every batch before producing the next one
            self._producer.flush()

    def _update_partitions(self, channels: Iterable[Channel]) -> None:
        """Request new/updated partition IDs for the given channels.

        Sends the channel metadata to the service over a Flight
        exchange and stores the partition IDs that come back.

        Parameters
        ----------
        channels : Iterable[Channel]
            The channels to request partition IDs for.

        """
        assert self._registered, "has not registered yet"

        # The channels are traversed several times below, so take a
        # snapshot first; a one-shot iterable would otherwise be
        # exhausted after the first pass.
        channels = list(channels)

        # build the list of channels to map new/updated partitions for
        dtypes = [channel_to_dtype_name(channel) for channel in channels]
        schema = pyarrow.schema(
            [
                pyarrow.field("channel", pyarrow.string(), nullable=False),
                pyarrow.field("data_type", pyarrow.string(), nullable=False),
                pyarrow.field("sample_rate", pyarrow.int32(), nullable=False),
                pyarrow.field("partition_id", pyarrow.string()),
            ]
        )
        batch = pyarrow.RecordBatch.from_arrays(
            [
                pyarrow.array(
                    [str(channel) for channel in channels],
                    type=schema.field("channel").type,
                ),
                pyarrow.array(dtypes, type=schema.field("data_type").type),
                pyarrow.array(
                    [channel.sample_rate for channel in channels],
                    type=schema.field("sample_rate").type,
                ),
                # partition IDs are left empty in the request
                pyarrow.array(
                    [None for _ in channels],
                    type=schema.field("partition_id").type,
                ),
            ],
            schema=schema,
        )

        # set up flight
        descriptor = create_descriptor(
            RequestType.Partition,
            publisher_id=self.publisher_id,
            validator=self._validator,
        )
        # FIXME: should we not get FlightInfo first?
        with connect(self.initial_url) as client:
            writer, reader = client.do_exchange(descriptor)

            # send over the partitions
            writer.begin(schema)
            writer.write_batch(batch)
            writer.done_writing()

            # get back the partition IDs and update
            partitions = reader.read_all().to_pydict()
            for channel, id_ in zip(partitions["channel"], partitions["partition_id"]):
                self._partitions[channel] = id_

    def close(self) -> None:
        """Flush the Kafka producer, if one has been created.

        Closing is best-effort: a failure to flush is logged and not
        propagated.

        """
        logger.info("closing kafka producer...")
        if self._producer is None:
            # nothing to flush, the producer was never set up
            return
        try:
            self._producer.flush()
        except Exception:
            logger.warning("failed to flush kafka producer on close", exc_info=True)

    def __exit__(self, *exc) -> Literal[False]:
        self.close()
        # never suppress an exception raised inside the context
        return False