Coverage for src/ramses_tx/gateway.py: 27%
130 statements
« prev ^ index » next coverage.py v7.11.3, created at 2026-01-05 21:46 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2026-01-05 21:46 +0100
1#!/usr/bin/env python3
3# TODO:
4# - self._tasks is not ThreadSafe
7"""RAMSES RF - The serial to RF gateway (HGI80, not RFG100)."""
9from __future__ import annotations
11import asyncio
12import logging
13from collections.abc import Callable
14from datetime import datetime as dt
15from threading import Lock
16from typing import TYPE_CHECKING, Any, Never
18from .address import ALL_DEV_ADDR, HGI_DEV_ADDR, NON_DEV_ADDR
19from .command import Command
20from .const import (
21 DEFAULT_DISABLE_QOS,
22 DEFAULT_GAP_DURATION,
23 DEFAULT_MAX_RETRIES,
24 DEFAULT_NUM_REPEATS,
25 DEFAULT_SEND_TIMEOUT,
26 DEFAULT_WAIT_FOR_REPLY,
27 SZ_ACTIVE_HGI,
28 Priority,
29)
30from .message import Message
31from .packet import Packet
32from .protocol import protocol_factory
33from .schemas import (
34 SZ_DISABLE_QOS,
35 SZ_DISABLE_SENDING,
36 SZ_ENFORCE_KNOWN_LIST,
37 SZ_LOG_ALL_MQTT,
38 SZ_PACKET_LOG,
39 SZ_PORT_CONFIG,
40 SZ_PORT_NAME,
41 SZ_SQLITE_INDEX,
42 PktLogConfigT,
43 PortConfigT,
44 select_device_filter_mode,
45)
46from .transport import transport_factory
47from .typing import QosParams
49from .const import ( # noqa: F401, isort: skip, pylint: disable=unused-import
50 I_,
51 RP,
52 RQ,
53 W_,
54 Code,
55)
57if TYPE_CHECKING:
58 from .const import VerbT
59 from .frame import PayloadT
60 from .protocol import RamsesProtocolT
61 from .schemas import DeviceIdT, DeviceListT
62 from .transport import RamsesTransportT
# Signature of a callback that is handed each received Message
_MsgHandlerT = Callable[[Message], None]

# NOTE(review): not read anywhere in this module; presumably a developer
# switch consulted elsewhere - confirm before removing
DEV_MODE: bool = False

_LOGGER: logging.Logger = logging.getLogger(name=__name__)
72class Engine:
73 """The engine class."""
75 def __init__(
76 self,
77 port_name: str | None,
78 input_file: str | None = None,
79 port_config: PortConfigT | None = None,
80 packet_log: PktLogConfigT | None = None,
81 block_list: DeviceListT | None = None,
82 known_list: DeviceListT | None = None,
83 loop: asyncio.AbstractEventLoop | None = None,
84 **kwargs: Any,
85 ) -> None:
86 if port_name and input_file:
87 _LOGGER.warning(
88 "Port (%s) specified, so file (%s) ignored", port_name, input_file
89 )
90 input_file = None
92 self._disable_sending = kwargs.pop(SZ_DISABLE_SENDING, None)
93 if input_file:
94 self._disable_sending = True
95 elif not port_name:
96 raise TypeError("Either a port_name or an input_file must be specified")
98 self.ser_name = port_name
99 self._input_file = input_file
101 self._port_config: PortConfigT | dict[Never, Never] = port_config or {}
102 self._packet_log: PktLogConfigT | dict[Never, Never] = packet_log or {}
103 self._loop = loop or asyncio.get_running_loop()
105 self._exclude: DeviceListT = block_list or {}
106 self._include: DeviceListT = known_list or {}
107 self._unwanted: list[DeviceIdT] = [
108 NON_DEV_ADDR.id,
109 ALL_DEV_ADDR.id,
110 "01:000001", # type: ignore[list-item] # why this one?
111 ]
112 self._enforce_known_list = select_device_filter_mode(
113 kwargs.pop(SZ_ENFORCE_KNOWN_LIST, None),
114 self._include,
115 self._exclude,
116 )
117 self._sqlite_index = kwargs.pop(SZ_SQLITE_INDEX, False) # default True?
118 self._log_all_mqtt = kwargs.pop(SZ_LOG_ALL_MQTT, False)
119 self._kwargs: dict[str, Any] = kwargs # HACK
121 self._engine_lock = Lock() # FIXME: threading lock, or asyncio lock?
122 self._engine_state: (
123 tuple[_MsgHandlerT | None, bool | None, *tuple[Any, ...]] | None
124 ) = None
126 self._protocol: RamsesProtocolT = None # type: ignore[assignment]
127 self._transport: RamsesTransportT | None = None # None until self.start()
129 self._prev_msg: Message | None = None
130 self._this_msg: Message | None = None
132 self._tasks: list[asyncio.Task] = [] # type: ignore[type-arg]
134 self._set_msg_handler(self._msg_handler) # sets self._protocol
136 def __str__(self) -> str:
137 if not self._transport:
138 return f"{HGI_DEV_ADDR.id} ({self.ser_name})"
140 device_id = self._transport.get_extra_info(
141 SZ_ACTIVE_HGI, default=HGI_DEV_ADDR.id
142 )
143 return f"{device_id} ({self.ser_name})"
145 def _dt_now(self) -> dt:
146 return self._transport._dt_now() if self._transport else dt.now()
148 def _set_msg_handler(self, msg_handler: _MsgHandlerT) -> None:
149 """Create an appropriate protocol for the packet source (transport).
151 The corresponding transport will be created later.
152 """
154 self._protocol = protocol_factory(
155 msg_handler,
156 disable_sending=self._disable_sending,
157 disable_qos=self._kwargs.pop(SZ_DISABLE_QOS, DEFAULT_DISABLE_QOS),
158 enforce_include_list=self._enforce_known_list,
159 exclude_list=self._exclude,
160 include_list=self._include,
161 )
163 def add_msg_handler(
164 self,
165 msg_handler: Callable[[Message], None],
166 /,
167 msg_filter: Callable[[Message], bool] | None = None,
168 ) -> None:
169 """Create a client protocol for the RAMSES-II message transport.
171 The optional filter will return True if the message is to be handled.
172 """
174 # if msg_filter is not None and not is_callback(msg_filter):
175 # raise TypeError(f"Msg filter {msg_filter} is not a callback")
177 if not msg_filter:
178 msg_filter = lambda _: True # noqa: E731
179 else:
180 raise NotImplementedError
182 self._protocol.add_handler(msg_handler, msg_filter=msg_filter)
184 async def start(self) -> None:
185 """Create a suitable transport for the specified packet source.
187 Initiate receiving (Messages) and sending (Commands).
188 """
190 pkt_source: dict[str, Any] = {} # [str, dict | str | TextIO]
191 if self.ser_name:
192 pkt_source[SZ_PORT_NAME] = self.ser_name
193 pkt_source[SZ_PORT_CONFIG] = self._port_config
194 else: # if self._input_file:
195 pkt_source[SZ_PACKET_LOG] = self._input_file # filename as string
197 # incl. await protocol.wait_for_connection_made(timeout=5)
198 self._transport = await transport_factory(
199 self._protocol,
200 disable_sending=self._disable_sending,
201 loop=self._loop,
202 log_all=self._log_all_mqtt,
203 **pkt_source,
204 **self._kwargs, # HACK: odd/misc params, e.g. comms_params
205 )
207 self._kwargs = {} # HACK
209 await self._protocol.wait_for_connection_made()
211 # TODO: should this be removed (if so, pytest all before committing)
212 if self._input_file:
213 await self._protocol.wait_for_connection_lost()
215 async def stop(self) -> None:
216 """Close the transport (will stop the protocol)."""
218 async def cancel_all_tasks() -> None: # TODO: needs a lock?
219 _ = [t.cancel() for t in self._tasks if not t.done()]
220 try: # FIXME: this is broken
221 if tasks := (t for t in self._tasks if not t.done()):
222 await asyncio.gather(*tasks)
223 except asyncio.CancelledError:
224 pass
226 await cancel_all_tasks()
228 if self._transport:
229 self._transport.close()
230 await self._protocol.wait_for_connection_lost()
232 return None
234 def _pause(self, *args: Any) -> None:
235 """Pause the (active) engine or raise a RuntimeError."""
237 if not self._engine_lock.acquire(blocking=False):
238 raise RuntimeError("Unable to pause engine, failed to acquire lock")
240 if self._engine_state is not None:
241 self._engine_lock.release()
242 raise RuntimeError("Unable to pause engine, it is already paused")
244 self._engine_state = (None, None, tuple()) # aka not None
245 self._engine_lock.release() # is ok to release now
247 self._protocol.pause_writing() # TODO: call_soon()?
248 if self._transport:
249 self._transport.pause_reading() # TODO: call_soon()?
251 self._protocol._msg_handler, handler = None, self._protocol._msg_handler # type: ignore[assignment]
252 self._disable_sending, read_only = True, self._disable_sending
254 self._engine_state = (handler, read_only, *args)
256 def _resume(self) -> tuple[Any]: # FIXME: not atomic
257 """Resume the (paused) engine or raise a RuntimeError."""
259 args: tuple[Any] # mypy
261 if not self._engine_lock.acquire(timeout=0.1):
262 raise RuntimeError("Unable to resume engine, failed to acquire lock")
264 if self._engine_state is None:
265 self._engine_lock.release()
266 raise RuntimeError("Unable to resume engine, it was not paused")
268 self._protocol._msg_handler, self._disable_sending, *args = self._engine_state # type: ignore[assignment]
269 self._engine_lock.release()
271 if self._transport:
272 self._transport.resume_reading()
273 if not self._disable_sending:
274 self._protocol.resume_writing()
276 self._engine_state = None
278 return args
280 def add_task(self, task: asyncio.Task[Any]) -> None: # TODO: needs a lock?
281 # keep a track of tasks, so we can tidy-up
282 self._tasks = [t for t in self._tasks if not t.done()]
283 self._tasks.append(task)
285 @staticmethod
286 def create_cmd(
287 verb: VerbT, device_id: DeviceIdT, code: Code, payload: PayloadT, **kwargs: Any
288 ) -> Command:
289 """Make a command addressed to device_id."""
291 if [
292 k for k in kwargs if k not in ("from_id", "seqn")
293 ]: # FIXME: deprecate QoS in kwargs
294 raise RuntimeError("Deprecated kwargs: %s", kwargs)
296 return Command.from_attrs(verb, device_id, code, payload, **kwargs)
298 async def async_send_cmd(
299 self,
300 cmd: Command,
301 /,
302 *,
303 gap_duration: float = DEFAULT_GAP_DURATION,
304 num_repeats: int = DEFAULT_NUM_REPEATS,
305 priority: Priority = Priority.DEFAULT,
306 max_retries: int = DEFAULT_MAX_RETRIES,
307 timeout: float = DEFAULT_SEND_TIMEOUT,
308 wait_for_reply: bool | None = DEFAULT_WAIT_FOR_REPLY,
309 ) -> Packet:
310 """Send a Command and return the corresponding Packet.
312 If wait_for_reply is True (*and* the Command has a rx_header), return the
313 reply Packet. Otherwise, simply return the echo Packet.
315 If the expected Packet can't be returned, raise:
316 ProtocolSendFailed: tried to Tx Command, but didn't get echo/reply
317 ProtocolError: didn't attempt to Tx Command for some reason
318 """
320 qos = QosParams(
321 max_retries=max_retries,
322 timeout=timeout,
323 wait_for_reply=wait_for_reply,
324 )
326 # adjust priority, WFR here?
327 # if cmd.code in (Code._0005, Code._000C) and qos.wait_for_reply is None:
328 # qos.wait_for_reply = True
330 return await self._protocol.send_cmd(
331 cmd,
332 gap_duration=gap_duration,
333 num_repeats=num_repeats,
334 priority=priority,
335 qos=qos,
336 ) # may: raise ProtocolError/ProtocolSendFailed
338 def _msg_handler(self, msg: Message) -> None:
339 # HACK: This is one consequence of an unpleasant anachronism
340 msg.__class__ = Message # HACK (next line too)
341 msg._gwy = self # type: ignore[assignment]
343 self._this_msg, self._prev_msg = msg, self._this_msg