Coverage for src/ramses_tx/gateway.py: 27%

130 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2026-01-05 21:46 +0100

1#!/usr/bin/env python3 

2 

# TODO:
# - self._tasks is not thread-safe

5 

6 

7"""RAMSES RF - The serial to RF gateway (HGI80, not RFG100).""" 

8 

9from __future__ import annotations 

10 

11import asyncio 

12import logging 

13from collections.abc import Callable 

14from datetime import datetime as dt 

15from threading import Lock 

16from typing import TYPE_CHECKING, Any, Never 

17 

18from .address import ALL_DEV_ADDR, HGI_DEV_ADDR, NON_DEV_ADDR 

19from .command import Command 

20from .const import ( 

21 DEFAULT_DISABLE_QOS, 

22 DEFAULT_GAP_DURATION, 

23 DEFAULT_MAX_RETRIES, 

24 DEFAULT_NUM_REPEATS, 

25 DEFAULT_SEND_TIMEOUT, 

26 DEFAULT_WAIT_FOR_REPLY, 

27 SZ_ACTIVE_HGI, 

28 Priority, 

29) 

30from .message import Message 

31from .packet import Packet 

32from .protocol import protocol_factory 

33from .schemas import ( 

34 SZ_DISABLE_QOS, 

35 SZ_DISABLE_SENDING, 

36 SZ_ENFORCE_KNOWN_LIST, 

37 SZ_LOG_ALL_MQTT, 

38 SZ_PACKET_LOG, 

39 SZ_PORT_CONFIG, 

40 SZ_PORT_NAME, 

41 SZ_SQLITE_INDEX, 

42 PktLogConfigT, 

43 PortConfigT, 

44 select_device_filter_mode, 

45) 

46from .transport import transport_factory 

47from .typing import QosParams 

48 

49from .const import ( # noqa: F401, isort: skip, pylint: disable=unused-import 

50 I_, 

51 RP, 

52 RQ, 

53 W_, 

54 Code, 

55) 

56 

57if TYPE_CHECKING: 

58 from .const import VerbT 

59 from .frame import PayloadT 

60 from .protocol import RamsesProtocolT 

61 from .schemas import DeviceIdT, DeviceListT 

62 from .transport import RamsesTransportT 

63 

_LOGGER = logging.getLogger(__name__)

# Module-level development-mode switch (off by default)
DEV_MODE = False

# Signature of a callback that is handed a Message and returns nothing
_MsgHandlerT = Callable[[Message], None]

70 

71 

72class Engine: 

73 """The engine class.""" 

74 

75 def __init__( 

76 self, 

77 port_name: str | None, 

78 input_file: str | None = None, 

79 port_config: PortConfigT | None = None, 

80 packet_log: PktLogConfigT | None = None, 

81 block_list: DeviceListT | None = None, 

82 known_list: DeviceListT | None = None, 

83 loop: asyncio.AbstractEventLoop | None = None, 

84 **kwargs: Any, 

85 ) -> None: 

86 if port_name and input_file: 

87 _LOGGER.warning( 

88 "Port (%s) specified, so file (%s) ignored", port_name, input_file 

89 ) 

90 input_file = None 

91 

92 self._disable_sending = kwargs.pop(SZ_DISABLE_SENDING, None) 

93 if input_file: 

94 self._disable_sending = True 

95 elif not port_name: 

96 raise TypeError("Either a port_name or an input_file must be specified") 

97 

98 self.ser_name = port_name 

99 self._input_file = input_file 

100 

101 self._port_config: PortConfigT | dict[Never, Never] = port_config or {} 

102 self._packet_log: PktLogConfigT | dict[Never, Never] = packet_log or {} 

103 self._loop = loop or asyncio.get_running_loop() 

104 

105 self._exclude: DeviceListT = block_list or {} 

106 self._include: DeviceListT = known_list or {} 

107 self._unwanted: list[DeviceIdT] = [ 

108 NON_DEV_ADDR.id, 

109 ALL_DEV_ADDR.id, 

110 "01:000001", # type: ignore[list-item] # why this one? 

111 ] 

112 self._enforce_known_list = select_device_filter_mode( 

113 kwargs.pop(SZ_ENFORCE_KNOWN_LIST, None), 

114 self._include, 

115 self._exclude, 

116 ) 

117 self._sqlite_index = kwargs.pop(SZ_SQLITE_INDEX, False) # default True? 

118 self._log_all_mqtt = kwargs.pop(SZ_LOG_ALL_MQTT, False) 

119 self._kwargs: dict[str, Any] = kwargs # HACK 

120 

121 self._engine_lock = Lock() # FIXME: threading lock, or asyncio lock? 

122 self._engine_state: ( 

123 tuple[_MsgHandlerT | None, bool | None, *tuple[Any, ...]] | None 

124 ) = None 

125 

126 self._protocol: RamsesProtocolT = None # type: ignore[assignment] 

127 self._transport: RamsesTransportT | None = None # None until self.start() 

128 

129 self._prev_msg: Message | None = None 

130 self._this_msg: Message | None = None 

131 

132 self._tasks: list[asyncio.Task] = [] # type: ignore[type-arg] 

133 

134 self._set_msg_handler(self._msg_handler) # sets self._protocol 

135 

136 def __str__(self) -> str: 

137 if not self._transport: 

138 return f"{HGI_DEV_ADDR.id} ({self.ser_name})" 

139 

140 device_id = self._transport.get_extra_info( 

141 SZ_ACTIVE_HGI, default=HGI_DEV_ADDR.id 

142 ) 

143 return f"{device_id} ({self.ser_name})" 

144 

145 def _dt_now(self) -> dt: 

146 return self._transport._dt_now() if self._transport else dt.now() 

147 

148 def _set_msg_handler(self, msg_handler: _MsgHandlerT) -> None: 

149 """Create an appropriate protocol for the packet source (transport). 

150 

151 The corresponding transport will be created later. 

152 """ 

153 

154 self._protocol = protocol_factory( 

155 msg_handler, 

156 disable_sending=self._disable_sending, 

157 disable_qos=self._kwargs.pop(SZ_DISABLE_QOS, DEFAULT_DISABLE_QOS), 

158 enforce_include_list=self._enforce_known_list, 

159 exclude_list=self._exclude, 

160 include_list=self._include, 

161 ) 

162 

163 def add_msg_handler( 

164 self, 

165 msg_handler: Callable[[Message], None], 

166 /, 

167 msg_filter: Callable[[Message], bool] | None = None, 

168 ) -> None: 

169 """Create a client protocol for the RAMSES-II message transport. 

170 

171 The optional filter will return True if the message is to be handled. 

172 """ 

173 

174 # if msg_filter is not None and not is_callback(msg_filter): 

175 # raise TypeError(f"Msg filter {msg_filter} is not a callback") 

176 

177 if not msg_filter: 

178 msg_filter = lambda _: True # noqa: E731 

179 else: 

180 raise NotImplementedError 

181 

182 self._protocol.add_handler(msg_handler, msg_filter=msg_filter) 

183 

184 async def start(self) -> None: 

185 """Create a suitable transport for the specified packet source. 

186 

187 Initiate receiving (Messages) and sending (Commands). 

188 """ 

189 

190 pkt_source: dict[str, Any] = {} # [str, dict | str | TextIO] 

191 if self.ser_name: 

192 pkt_source[SZ_PORT_NAME] = self.ser_name 

193 pkt_source[SZ_PORT_CONFIG] = self._port_config 

194 else: # if self._input_file: 

195 pkt_source[SZ_PACKET_LOG] = self._input_file # filename as string 

196 

197 # incl. await protocol.wait_for_connection_made(timeout=5) 

198 self._transport = await transport_factory( 

199 self._protocol, 

200 disable_sending=self._disable_sending, 

201 loop=self._loop, 

202 log_all=self._log_all_mqtt, 

203 **pkt_source, 

204 **self._kwargs, # HACK: odd/misc params, e.g. comms_params 

205 ) 

206 

207 self._kwargs = {} # HACK 

208 

209 await self._protocol.wait_for_connection_made() 

210 

211 # TODO: should this be removed (if so, pytest all before committing) 

212 if self._input_file: 

213 await self._protocol.wait_for_connection_lost() 

214 

215 async def stop(self) -> None: 

216 """Close the transport (will stop the protocol).""" 

217 

218 async def cancel_all_tasks() -> None: # TODO: needs a lock? 

219 _ = [t.cancel() for t in self._tasks if not t.done()] 

220 try: # FIXME: this is broken 

221 if tasks := (t for t in self._tasks if not t.done()): 

222 await asyncio.gather(*tasks) 

223 except asyncio.CancelledError: 

224 pass 

225 

226 await cancel_all_tasks() 

227 

228 if self._transport: 

229 self._transport.close() 

230 await self._protocol.wait_for_connection_lost() 

231 

232 return None 

233 

234 def _pause(self, *args: Any) -> None: 

235 """Pause the (active) engine or raise a RuntimeError.""" 

236 

237 if not self._engine_lock.acquire(blocking=False): 

238 raise RuntimeError("Unable to pause engine, failed to acquire lock") 

239 

240 if self._engine_state is not None: 

241 self._engine_lock.release() 

242 raise RuntimeError("Unable to pause engine, it is already paused") 

243 

244 self._engine_state = (None, None, tuple()) # aka not None 

245 self._engine_lock.release() # is ok to release now 

246 

247 self._protocol.pause_writing() # TODO: call_soon()? 

248 if self._transport: 

249 self._transport.pause_reading() # TODO: call_soon()? 

250 

251 self._protocol._msg_handler, handler = None, self._protocol._msg_handler # type: ignore[assignment] 

252 self._disable_sending, read_only = True, self._disable_sending 

253 

254 self._engine_state = (handler, read_only, *args) 

255 

256 def _resume(self) -> tuple[Any]: # FIXME: not atomic 

257 """Resume the (paused) engine or raise a RuntimeError.""" 

258 

259 args: tuple[Any] # mypy 

260 

261 if not self._engine_lock.acquire(timeout=0.1): 

262 raise RuntimeError("Unable to resume engine, failed to acquire lock") 

263 

264 if self._engine_state is None: 

265 self._engine_lock.release() 

266 raise RuntimeError("Unable to resume engine, it was not paused") 

267 

268 self._protocol._msg_handler, self._disable_sending, *args = self._engine_state # type: ignore[assignment] 

269 self._engine_lock.release() 

270 

271 if self._transport: 

272 self._transport.resume_reading() 

273 if not self._disable_sending: 

274 self._protocol.resume_writing() 

275 

276 self._engine_state = None 

277 

278 return args 

279 

280 def add_task(self, task: asyncio.Task[Any]) -> None: # TODO: needs a lock? 

281 # keep a track of tasks, so we can tidy-up 

282 self._tasks = [t for t in self._tasks if not t.done()] 

283 self._tasks.append(task) 

284 

285 @staticmethod 

286 def create_cmd( 

287 verb: VerbT, device_id: DeviceIdT, code: Code, payload: PayloadT, **kwargs: Any 

288 ) -> Command: 

289 """Make a command addressed to device_id.""" 

290 

291 if [ 

292 k for k in kwargs if k not in ("from_id", "seqn") 

293 ]: # FIXME: deprecate QoS in kwargs 

294 raise RuntimeError("Deprecated kwargs: %s", kwargs) 

295 

296 return Command.from_attrs(verb, device_id, code, payload, **kwargs) 

297 

298 async def async_send_cmd( 

299 self, 

300 cmd: Command, 

301 /, 

302 *, 

303 gap_duration: float = DEFAULT_GAP_DURATION, 

304 num_repeats: int = DEFAULT_NUM_REPEATS, 

305 priority: Priority = Priority.DEFAULT, 

306 max_retries: int = DEFAULT_MAX_RETRIES, 

307 timeout: float = DEFAULT_SEND_TIMEOUT, 

308 wait_for_reply: bool | None = DEFAULT_WAIT_FOR_REPLY, 

309 ) -> Packet: 

310 """Send a Command and return the corresponding Packet. 

311 

312 If wait_for_reply is True (*and* the Command has a rx_header), return the 

313 reply Packet. Otherwise, simply return the echo Packet. 

314 

315 If the expected Packet can't be returned, raise: 

316 ProtocolSendFailed: tried to Tx Command, but didn't get echo/reply 

317 ProtocolError: didn't attempt to Tx Command for some reason 

318 """ 

319 

320 qos = QosParams( 

321 max_retries=max_retries, 

322 timeout=timeout, 

323 wait_for_reply=wait_for_reply, 

324 ) 

325 

326 # adjust priority, WFR here? 

327 # if cmd.code in (Code._0005, Code._000C) and qos.wait_for_reply is None: 

328 # qos.wait_for_reply = True 

329 

330 return await self._protocol.send_cmd( 

331 cmd, 

332 gap_duration=gap_duration, 

333 num_repeats=num_repeats, 

334 priority=priority, 

335 qos=qos, 

336 ) # may: raise ProtocolError/ProtocolSendFailed 

337 

338 def _msg_handler(self, msg: Message) -> None: 

339 # HACK: This is one consequence of an unpleasant anachronism 

340 msg.__class__ = Message # HACK (next line too) 

341 msg._gwy = self # type: ignore[assignment] 

342 

343 self._this_msg, self._prev_msg = msg, self._this_msg