Coverage for src/ramses_tx/protocol.py: 22%

322 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2026-01-05 21:46 +0100

1#!/usr/bin/env python3 

2"""RAMSES RF - RAMSES-II compatible packet protocol.""" 

3 

4from __future__ import annotations 

5 

6import asyncio 

7import logging 

8from collections.abc import Awaitable, Callable 

9from datetime import datetime as dt 

10from typing import TYPE_CHECKING, Any, Final, TypeAlias 

11 

12from . import exceptions as exc 

13from .address import ALL_DEV_ADDR, HGI_DEV_ADDR, NON_DEV_ADDR 

14from .command import Command 

15from .const import ( 

16 DEFAULT_DISABLE_QOS, 

17 DEFAULT_GAP_DURATION, 

18 DEFAULT_NUM_REPEATS, 

19 DEV_TYPE_MAP, 

20 SZ_ACTIVE_HGI, 

21 SZ_IS_EVOFW3, 

22 DevType, 

23 Priority, 

24) 

25from .logger import set_logger_timesource 

26from .message import Message 

27from .packet import Packet 

28from .protocol_fsm import ProtocolContext 

29from .schemas import SZ_BLOCK_LIST, SZ_CLASS, SZ_KNOWN_LIST, SZ_PORT_NAME 

30from .transport import transport_factory 

31from .typing import ExceptionT, MsgFilterT, MsgHandlerT, QosParams 

32 

33from .const import ( # noqa: F401, isort: skip, pylint: disable=unused-import 

34 I_, 

35 RP, 

36 RQ, 

37 W_, 

38 Code, 

39) 

40 

41if TYPE_CHECKING: 

42 from .schemas import DeviceIdT, DeviceListT 

43 from .transport import RamsesTransportT 

44 

45 

# Suffix appended to device-filter log messages, telling the user what to adjust
TIP = f", configure the {SZ_KNOWN_LIST}/{SZ_BLOCK_LIST} as required"

#
# NOTE: All debug flags should be False for deployment to end-users
_DBG_DISABLE_IMPERSONATION_ALERTS: Final[bool] = False  # skip the puzzle-packet alert
_DBG_DISABLE_QOS: Final[bool] = False  # send without the QoS FSM (no Packet returned)
_DBG_FORCE_LOG_PACKETS: Final[bool] = False  # log every received packet at WARNING

_LOGGER = logging.getLogger(__name__)


# The QosParams used by every send that does not supply its own: a single shared
# instance, so senders should treat it as read-only
DEFAULT_QOS = QosParams()

58 

59 

class _BaseProtocol(asyncio.Protocol):
    """Base class for RAMSES II protocols.

    Binds to a (RAMSES) Transport, turns each received Packet into a Message, and
    dispatches it to the registered handler(s). Subclasses provide the actual
    sending machinery by overriding _send_cmd() (which raises here).
    """

    WRITER_TASK = "writer_task"

    def __init__(self, msg_handler: MsgHandlerT) -> None:
        self._msg_handler = msg_handler  # the primary callback, for every Message
        self._msg_handlers: list[MsgHandlerT] = []  # extra callbacks, see add_handler()

        self._transport: RamsesTransportT | None = None
        self._loop = asyncio.get_running_loop()

        self._pause_writing = False  # FIXME: Start in R/O mode as no connection yet?
        # created by connection_made(), resolved by connection_lost()
        self._wait_connection_lost: asyncio.Future[None] | None = None
        # resolved by connection_made(), re-armed by connection_lost()
        self._wait_connection_made: asyncio.Future[RamsesTransportT] = (
            self._loop.create_future()
        )

        self._this_msg: Message | None = None
        self._prev_msg: Message | None = None

        self._is_evofw3: bool | None = None

        self._active_hgi: DeviceIdT | None = None
        self._context: ProtocolContext | None = None

    @property
    def hgi_id(self) -> DeviceIdT:
        """Return the gateway's device id (here, always the default HGI_DEV_ADDR.id)."""
        return HGI_DEV_ADDR.id

    def add_handler(
        self,
        msg_handler: MsgHandlerT,
        /,
        *,
        msg_filter: MsgFilterT | None = None,
    ) -> Callable[[], None]:
        """Add a Message handler to the list of such callbacks.

        Returns a callback that can be used to subsequently remove the Message handler.
        Adding the same handler twice has no additional effect.

        NOTE: msg_filter is currently accepted but not applied.
        """

        def del_handler() -> None:
            if msg_handler in self._msg_handlers:
                self._msg_handlers.remove(msg_handler)

        if msg_handler not in self._msg_handlers:
            self._msg_handlers.append(msg_handler)

        return del_handler

    def connection_made(self, transport: RamsesTransportT) -> None:  # type: ignore[override]
        """Called when the connection to the Transport is established.

        The argument is the transport representing the pipe connection. To receive data,
        wait for pkt_received() calls. When the connection is closed, connection_lost()
        is called.

        A duplicate invocation (while already connected) is ignored.
        """

        if self._wait_connection_made.cancelled():
            # a waiter cancelled the future: re-arm it, so that this (late)
            # connection is still bound, rather than silently dropped
            self._wait_connection_made = self._loop.create_future()
        elif self._wait_connection_made.done():
            return  # already connected (a duplicate invocation)

        self._wait_connection_lost = self._loop.create_future()
        self._wait_connection_made.set_result(transport)
        self._transport = transport

    async def wait_for_connection_made(self, timeout: float = 1) -> RamsesTransportT:
        """A courtesy function to wait until connection_made() has been invoked.

        Will raise TransportError if isn't connected within timeout seconds.
        """

        try:
            # shield(): on timeout, wait_for() cancels whatever it awaits; without the
            # shield that would cancel _wait_connection_made itself, and a later
            # connection_made() would then find it 'done' and be ignored
            return await asyncio.wait_for(
                asyncio.shield(self._wait_connection_made), timeout
            )
        except TimeoutError as err:
            raise exc.TransportError(
                f"Transport did not bind to Protocol within {timeout} secs"
            ) from err

    def connection_lost(self, err: ExceptionT | None) -> None:  # type: ignore[override]
        """Called when the connection to the Transport is lost or closed.

        The argument is an exception object or None (the latter meaning a regular EOF is
        received or the connection was aborted or closed).
        """

        # The connection may never have been fully established (e.g. a timeout)
        if not self._wait_connection_lost:
            _LOGGER.debug(
                "connection_lost called but no connection was established (ignoring)"
            )
            # Reset the connection made future for next attempt
            if self._wait_connection_made.done():
                self._wait_connection_made = self._loop.create_future()
            return

        if self._wait_connection_lost.cancelled():
            # a waiter cancelled the future: re-arm it so this loss is still recorded
            self._wait_connection_lost = self._loop.create_future()
        elif self._wait_connection_lost.done():  # BUG: why is callback invoked twice?
            return

        self._wait_connection_made = self._loop.create_future()  # allow a reconnect
        if err:
            self._wait_connection_lost.set_exception(err)
        else:
            self._wait_connection_lost.set_result(None)

    async def wait_for_connection_lost(self, timeout: float = 1) -> ExceptionT | None:
        """A courtesy function to wait until connection_lost() has been invoked.

        Includes scenarios where neither connection_made() nor connection_lost() were
        invoked (returns None immediately if there was never a connection).

        If connection_lost() was given an exception, awaiting this re-raises it.

        Will raise TransportError if isn't disconnect within timeout seconds.
        """

        if not self._wait_connection_lost:
            return None

        try:
            # shield(): see wait_for_connection_made()
            return await asyncio.wait_for(
                asyncio.shield(self._wait_connection_lost), timeout
            )
        except TimeoutError as err:
            raise exc.TransportError(
                f"Transport did not unbind from Protocol within {timeout} secs"
            ) from err

    def pause_writing(self) -> None:
        """Called when the transport's buffer goes over the high-water mark.

        Pause and resume calls are paired -- pause_writing() is called once when the
        buffer goes strictly over the high-water mark (even if subsequent writes
        increases the buffer size even more), and eventually resume_writing() is called
        once when the buffer size reaches the low-water mark.

        Note that if the buffer size equals the high-water mark, pause_writing() is not
        called -- it must go strictly over. Conversely, resume_writing() is called when
        the buffer size is equal or lower than the low-water mark. These end conditions
        are important to ensure that things go as expected when either mark is zero.

        NOTE: This is the only Protocol callback that is not called through
        EventLoop.call_soon() -- if it were, it would have no effect when it's most
        needed (when the app keeps writing without yielding until pause_writing() is
        called).
        """

        self._pause_writing = True

    def resume_writing(self) -> None:
        """Called when the transport's buffer drains below the low-water mark.

        See pause_writing() for details.
        """

        self._pause_writing = False

    async def _send_impersonation_alert(self, cmd: Command) -> None:
        """Allow the Protocol to send an impersonation alert (stub)."""
        return

    async def send_cmd(
        self,
        cmd: Command,
        /,
        *,
        gap_duration: float = DEFAULT_GAP_DURATION,
        num_repeats: int = DEFAULT_NUM_REPEATS,
        priority: Priority = Priority.DEFAULT,
        qos: QosParams = DEFAULT_QOS,  # max_retries, timeout, wait_for_reply
    ) -> Packet:
        """Send a Command with Qos (with retries, until success or ProtocolError).

        Returns the Command's response Packet or the Command echo if a response is not
        expected (e.g. sending an RP).

        If wait_for_reply is True, return the RQ's RP (or W's I), or raise an exception
        if one doesn't arrive. If it is False, return the echo of the Command only. If
        it is None (the default), act as True for RQs, and False for all other Commands.

        num_repeats is # of times to send the Command, in addition to the first transmit,
        with gap_duration seconds between each transmission. If wait_for_reply is True,
        then num_repeats is ignored.

        Commands are queued and sent FIFO, except higher-priority Commands are always
        sent first.

        Will raise:
            ProtocolSendFailed: tried to Tx Command, but didn't get echo/reply
            ProtocolError: didn't attempt to Tx Command for some reason
        """

        assert gap_duration == DEFAULT_GAP_DURATION
        assert 0 <= num_repeats <= 3  # if QoS, only Tx x1, with no repeats

        # If the Command uses the default placeholder source address, but the real
        # gateway ID is known, patch it in. Only done for evofw3, as HGI80s (TI 3410)
        # require the default ID (18:000730), or they will silent-fail.
        if (
            self.hgi_id
            and self._is_evofw3  # Only patch if using evofw3 (not HGI80)
            and cmd._addrs[0].id == HGI_DEV_ADDR.id
            and self.hgi_id != HGI_DEV_ADDR.id
        ):
            _LOGGER.debug(
                f"Patching command with active HGI ID: swapped {HGI_DEV_ADDR.id} -> {self.hgi_id} for {cmd._hdr}"
            )

            # ONLY patch the Source Address (Index 0): leave the others (Index 1/2)
            # alone, e.g. tests expect 18:000730 there
            new_addrs = [a.id for a in cmd._addrs]
            new_addrs[0] = self.hgi_id

            new_frame = (
                f"{cmd.verb} {cmd.seqn} {new_addrs[0]} {new_addrs[1]} {new_addrs[2]} "
                f"{cmd.code} {int(cmd.len_):03d} {cmd.payload}"
            )
            cmd = Command(new_frame)

        if qos and not self._context:
            _LOGGER.warning(f"{cmd} < QoS is currently disabled by this Protocol")

        if cmd.src.id != self.hgi_id:  # Was HGI_DEV_ADDR.id
            await self._send_impersonation_alert(cmd)

        if qos.wait_for_reply and num_repeats:
            _LOGGER.warning(f"{cmd} < num_repeats set to 0, as wait_for_reply is True")
            num_repeats = 0  # the lesser crime over wait_for_reply=False

        pkt = await self._send_cmd(  # may: raise ProtocolError/ProtocolSendFailed
            cmd,
            gap_duration=gap_duration,
            num_repeats=num_repeats,
            priority=priority,
            qos=qos,
        )

        if not pkt:  # HACK: temporary workaround for returning None
            raise exc.ProtocolSendFailed(f"Failed to send command: {cmd} (REPORT THIS)")

        return pkt

    async def _send_cmd(
        self,
        cmd: Command,
        /,
        *,
        gap_duration: float = DEFAULT_GAP_DURATION,
        num_repeats: int = DEFAULT_NUM_REPEATS,
        priority: Priority = Priority.DEFAULT,
        qos: QosParams = DEFAULT_QOS,
    ) -> Packet:  # only cmd, no args, kwargs
        """Send the Command; must be overridden by a subclass that can send."""
        # await self._send_frame(
        #     str(cmd), num_repeats=num_repeats, gap_duration=gap_duration
        # )
        raise NotImplementedError(f"{self}: Unexpected error")

    async def _send_frame(
        self, frame: str, num_repeats: int = 0, gap_duration: float = 0.0
    ) -> None:  # _send_frame() -> transport
        """Write to the transport, sleeping gap_duration secs between repeats.

        Raises ProtocolSendFailed if there is no Transport.

        NOTE(review): the frame is written max(num_repeats, 1) times in total, whereas
        send_cmd() describes num_repeats as repeats *in addition to* the first transmit
        - confirm which is intended.
        """

        if self._transport is None:
            raise exc.ProtocolSendFailed("Transport is not connected")

        await self._transport.write_frame(frame)
        for _ in range(num_repeats - 1):
            await asyncio.sleep(gap_duration)
            await self._transport.write_frame(frame)

    def pkt_received(self, pkt: Packet) -> None:
        """A wrapper for self._pkt_received(pkt), that first logs the Packet.

        It is logged at WARNING if _DBG_FORCE_LOG_PACKETS, else at INFO if this logger's
        effective level is above DEBUG, else at DEBUG.
        """
        if _DBG_FORCE_LOG_PACKETS:
            _LOGGER.warning(f"Recv'd: {pkt._rssi} {pkt}")
        elif _LOGGER.getEffectiveLevel() > logging.DEBUG:
            _LOGGER.info(f"Recv'd: {pkt._rssi} {pkt}")
        else:
            _LOGGER.debug(f"Recv'd: {pkt._rssi} {pkt}")

        self._pkt_received(pkt)

    def _pkt_received(self, pkt: Packet) -> None:
        """Called by the Transport when a Packet is received.

        Converts it to a Message (silently dropping it if that raises PacketInvalid),
        maintains the _prev_msg/_this_msg attrs, then dispatches it.
        """
        try:
            msg = Message(pkt)  # should log all invalid msgs appropriately
        except exc.PacketInvalid:  # TODO: InvalidMessageError (packet is valid)
            return

        self._this_msg, self._prev_msg = msg, self._this_msg
        self._msg_received(msg)

    def _msg_received(self, msg: Message) -> None:
        """Pass any valid/wanted Messages to the client's callback(s).

        Each handler is scheduled via call_soon_threadsafe(), rather than called
        directly: the primary handler first, then any added via add_handler().
        """

        if self._msg_handler:  # type: ignore[truthy-function]
            _LOGGER.debug(f"Dispatching valid message to handler: {msg}")
            self._loop.call_soon_threadsafe(self._msg_handler, msg)
        for callback in self._msg_handlers:
            # TODO: if handler's filter returns True:
            self._loop.call_soon_threadsafe(callback, msg)

366 

367 

class _DeviceIdFilterMixin(_BaseProtocol):
    """Drop unwanted (but otherwise valid) packets, and refuse such commands, by id.

    An id in the exclude (block) list always wins. The include (known) list only
    acts as a hard filter when enforce_include_list is True.
    """

    def __init__(
        self,
        msg_handler: MsgHandlerT,
        enforce_include_list: bool = False,
        exclude_list: DeviceListT | None = None,
        include_list: DeviceListT | None = None,
    ) -> None:
        super().__init__(msg_handler)

        blocked = exclude_list or {}
        known = include_list or {}

        self.enforce_include = enforce_include_list
        self._exclude = list(blocked)
        # the broadcast & null addresses are always wanted
        self._include = [*known, ALL_DEV_ADDR.id, NON_DEV_ADDR.id]

        self._active_hgi: DeviceIdT | None = None
        # HACK: a static pkt source (e.g. a file/dict, via ReadProtocol) is quiet,
        # HACK: whilst a dynamic source (e.g. a port/MQTT) warns as needed
        self._known_hgi = self._extract_known_hgi_id(
            known, disable_warnings=isinstance(self, ReadProtocol)
        )

        # foreign gateways already warned about, since _foreign_last_run (a date)
        self._foreign_gwys_lst: list[DeviceIdT] = []
        self._foreign_last_run = dt.now().date()

    @property
    def hgi_id(self) -> DeviceIdT:
        """Return the Active gateway id, else the Known gateway id, else the default.

        get_extra_info() may return None even when the key exists, hence the fallbacks.
        """
        active_hgi = (
            self._transport.get_extra_info(SZ_ACTIVE_HGI) if self._transport else None
        )
        return active_hgi or self._known_hgi or HGI_DEV_ADDR.id

    @staticmethod
    def _extract_known_hgi_id(
        include_list: DeviceListT,
        /,
        *,
        disable_warnings: bool = False,
        strict_checking: bool = False,
    ) -> DeviceIdT | None:
        """Return the device_id of the gateway specified in the include_list, if any.

        The 'Known' gateway is the predicted Active gateway, given the known_list.
        The 'Active' gateway is the USB device that is actually Tx/Rx-ing frames.
        The two ids should match, but need not.

        Gateways are 'explicit' (class: HGI) or 'implicit' (no class, but an 18: id);
        an explicit one is preferred. With strict_checking, only a sole explicit
        gateway is returned. Logs (at WARNING, or DEBUG if disable_warnings) when
        the include_list is configured incorrectly.
        """

        log = _LOGGER.debug if disable_warnings else _LOGGER.warning

        hgi_classes = (DevType.HGI, DEV_TYPE_MAP[DevType.HGI])
        hgi_prefix = DEV_TYPE_MAP._hex(DevType.HGI)

        explicit_hgis: list[DeviceIdT] = []
        implicit_hgis: list[DeviceIdT] = []
        for dev_id, traits in include_list.items():
            dev_class = traits.get(SZ_CLASS)
            if dev_class in hgi_classes:
                explicit_hgis.append(dev_id)
            if not dev_class and dev_id[:2] == hgi_prefix:
                implicit_hgis.append(dev_id)

        if not (explicit_hgis or implicit_hgis):
            log(
                f"The {SZ_KNOWN_LIST} SHOULD include exactly one gateway (HGI), "
                f"but does not (it should specify 'class: HGI')"
            )
            return None

        known_hgi = (explicit_hgis or implicit_hgis)[0]

        if include_list[known_hgi].get(SZ_CLASS) not in hgi_classes:
            log(
                f"The {SZ_KNOWN_LIST} SHOULD include exactly one gateway (HGI): "
                f"{known_hgi} should specify 'class: HGI', as 18: is also used for HVAC"
            )
        elif len(explicit_hgis) > 1:
            log(
                f"The {SZ_KNOWN_LIST} SHOULD include exactly one gateway (HGI): "
                f"{known_hgi} is the chosen device id (why is there >1 HGI?)"
            )
        else:
            _LOGGER.debug(
                f"The {SZ_KNOWN_LIST} includes exactly one gateway (HGI): {known_hgi}"
            )

        if not strict_checking:
            return known_hgi
        return known_hgi if explicit_hgis == [known_hgi] else None

    def _set_active_hgi(self, dev_id: DeviceIdT, by_signature: bool = False) -> None:
        """Set the Active Gateway (HGI) device_id (unless it is in the block list).

        Logs (error/warning/info/debug) how well the lists agree with that id.
        """

        assert self._active_hgi is None  # should only be called once

        how = "(by signature)" if by_signature else "(by filter)"
        msg = f"The active gateway '{dev_id}: {{ class: HGI }}' {how}"

        if dev_id in self._exclude:
            # leave self._active_hgi unset: doing otherwise would not help
            _LOGGER.error(f"{msg} MUST NOT be in the {SZ_BLOCK_LIST}{TIP}")
            return

        self._active_hgi = dev_id

        if dev_id not in self._include:
            _LOGGER.warning(f"{msg} SHOULD be in the (enforced) {SZ_KNOWN_LIST}")
            # self._include.append(dev_id)  # a good idea?
        elif not self.enforce_include:
            _LOGGER.info(f"{msg} is in the {SZ_KNOWN_LIST}, which SHOULD be enforced")
        else:
            _LOGGER.debug(f"{msg} is in the {SZ_KNOWN_LIST}")

    def _is_wanted_addrs(
        self, src_id: DeviceIdT, dst_id: DeviceIdT, sending: bool = False
    ) -> bool:
        """Return True if the packet is not to be filtered out.

        In any one packet, an excluded device_id 'trumps' an included device_id.

        There are two ways to set the Active Gateway (HGI80/evofw3):
        - by signature (evofw3 only), when frame -> packet
        - by known_list (HGI80/evofw3), when filtering packets
        """

        def warn_foreign_hgi(dev_id: DeviceIdT) -> None:
            """Warn about a possible Foreign gateway, at most once per calendar day."""
            today = dt.now().date()
            if today != self._foreign_last_run:
                self._foreign_last_run = today
                self._foreign_gwys_lst = []  # forget prior warnings on a new day

            if dev_id not in self._foreign_gwys_lst:
                _LOGGER.warning(
                    f"Device {dev_id} is potentially a Foreign gateway, "
                    f"the Active gateway is {self._active_hgi}, "
                    f"alternatively, is it a HVAC device?{TIP}"
                )
                self._foreign_gwys_lst.append(dev_id)

        addrs = (src_id,) if src_id == dst_id else (src_id, dst_id)

        for dev_id in addrs:
            if dev_id in self._exclude:  # problems if incl. active gateway
                return False

            if (
                dev_id == self._active_hgi  # consider: return True (corrupted dst.id?)
                or dev_id in self._include  # incl. 63:262142 & --:------
                or (sending and dev_id == HGI_DEV_ADDR.id)
            ):
                continue

            if self.enforce_include:
                return False

            # an 18: that is not in the known_list, whilst an Active gateway is set
            if dev_id[:2] == DEV_TYPE_MAP.HGI and self._active_hgi:
                warn_foreign_hgi(dev_id)

        return True

    def pkt_received(self, pkt: Packet) -> None:
        """Drop the Packet if its addresses are unwanted, else process it as usual."""
        if not self._is_wanted_addrs(pkt.src.id, pkt.dst.id):
            _LOGGER.debug("%s < Packet excluded by device_id filter", pkt)
            return
        super().pkt_received(pkt)

    async def send_cmd(self, cmd: Command, *args: Any, **kwargs: Any) -> Packet:
        """Raise ProtocolError if the Command's addresses are unwanted, else send it."""
        if not self._is_wanted_addrs(cmd.src.id, cmd.dst.id, sending=True):
            raise exc.ProtocolError(f"Command excluded by device_id filter: {cmd}")
        return await super().send_cmd(cmd, *args, **kwargs)

563 

564 

class ReadProtocol(_DeviceIdFilterMixin, _BaseProtocol):
    """A receive-only protocol: Packets are processed, but Commands are refused."""

    def __init__(self, msg_handler: MsgHandlerT, **kwargs: Any) -> None:
        super().__init__(msg_handler, **kwargs)

        self._pause_writing = True  # writing is never permitted

    def connection_made(  # type: ignore[override]
        self, transport: RamsesTransportT, /, *, ramses: bool = False
    ) -> None:
        """Bind to the Transport, regardless of the value of ramses.

        Unlike PortProtocol, this does not wait for the connection_made(ramses=True)
        call; the flag is accepted only so that both protocols share a signature.
        """
        super().connection_made(transport)

    def resume_writing(self) -> None:
        """Refuse: writing cannot be resumed on a read-only Protocol."""
        raise NotImplementedError(f"{self}: The chosen Protocol is Read-Only")

    async def send_cmd(
        self,
        cmd: Command,
        /,
        *,
        gap_duration: float = DEFAULT_GAP_DURATION,
        num_repeats: int = DEFAULT_NUM_REPEATS,
        priority: Priority = Priority.DEFAULT,
        qos: QosParams | None = None,
    ) -> Packet:
        """Refuse: this Protocol cannot send Commands (always raises)."""
        raise NotImplementedError(f"{cmd._hdr}: < this Protocol is Read-Only")

598 

599 

class PortProtocol(_DeviceIdFilterMixin, _BaseProtocol):
    """A protocol that can receive Packets and send Commands +/- QoS (using a FSM)."""

    def __init__(
        self,
        msg_handler: MsgHandlerT,
        disable_qos: bool | None = DEFAULT_DISABLE_QOS,
        **kwargs: Any,
    ) -> None:
        """Add a FSM to the Protocol, to provide QoS.

        disable_qos: True - never wait_for_reply; None - wait_for_reply is allowed only
        for a few Codes (see _send_cmd); False - the caller's QosParams are honoured.
        """
        super().__init__(msg_handler, **kwargs)

        self._context = ProtocolContext(self)
        self._disable_qos = disable_qos  # no wait_for_reply

    def __repr__(self) -> str:
        if not self._context:
            return super().__repr__()
        cls = self._context.state.__class__.__name__
        return f"QosProtocol({cls}, len(queue)={self._context._que.qsize()})"

    def connection_made(  # type: ignore[override]
        self, transport: RamsesTransportT, /, *, ramses: bool = False
    ) -> None:
        """Consume the callback if invoked by SerialTransport rather than PortTransport.

        Our PortTransport wraps SerialTransport and will wait for the signature echo
        to be received (c.f. FileTransport) before calling connection_made(ramses=True).
        """

        if not ramses:
            return None

        # if isinstance(transport, MqttTransport):  # HACK
        #     self._context.echo_timeout = 0.5  # HACK: need to move FSM to transport?

        super().connection_made(transport)
        # TODO: needed? self.resume_writing()

        # ROBUSTNESS FIX: Ensure self._transport is set even if the wait future was cancelled
        if self._transport is None:
            _LOGGER.warning(
                f"{self}: Transport bound after wait cancelled (late connection)"
            )
            self._transport = transport

        if self._transport:
            self._update_active_hgi(self._transport.get_extra_info(SZ_ACTIVE_HGI))
            self._is_evofw3 = self._transport.get_extra_info(SZ_IS_EVOFW3)

        if not self._context:
            return

        self._context.connection_made(transport)

        if self._pause_writing:
            self._context.pause_writing()
        else:
            self._context.resume_writing()

    def _update_active_hgi(self, dev_id: DeviceIdT | None) -> None:
        """Record the Transport's Active gateway id, tolerating None and repeats.

        _set_active_hgi() asserts that it is called only once, but this Protocol's
        connection_made(ramses=True) may run more than once (e.g. after a
        connection_lost()), and the Transport may report no Active gateway at all.
        """
        if not dev_id or dev_id == self._active_hgi:
            return  # nothing (new) to record

        if self._active_hgi is not None:
            _LOGGER.warning(
                f"{self}: The active gateway has changed: {self._active_hgi} -> {dev_id}"
            )
            self._active_hgi = None  # so _set_active_hgi() may be called again

        self._set_active_hgi(dev_id)

    def connection_lost(self, err: ExceptionT | None) -> None:  # type: ignore[override]
        """Inform the FSM that the connection with the Transport has been lost."""

        super().connection_lost(err)
        if self._context:
            self._context.connection_lost(err)  # is this safe, when KeyboardInterrupt?

    def pause_writing(self) -> None:
        """Inform the FSM that the Protocol has been paused."""

        super().pause_writing()
        if self._context:
            self._context.pause_writing()

    def resume_writing(self) -> None:
        """Inform the FSM that the Protocol has been resumed."""

        super().resume_writing()
        if self._context:
            self._context.resume_writing()

    def pkt_received(self, pkt: Packet) -> None:
        """Pass any valid/wanted packets to the callback, and then to the FSM."""

        super().pkt_received(pkt)
        if self._context:
            self._context.pkt_received(pkt)

    async def _send_impersonation_alert(self, cmd: Command) -> None:
        """Send a puzzle packet warning that impersonation is occurring."""

        if _DBG_DISABLE_IMPERSONATION_ALERTS:
            return

        msg = f"{self}: Impersonating device: {cmd.src}, for pkt: {cmd.tx_header}"
        if self._is_evofw3 is False:
            _LOGGER.error(f"{msg}, NB: non-evofw3 gateways can't impersonate!")
        else:
            _LOGGER.info(msg)

        await self._send_cmd(Command._puzzle(msg_type="11", message=cmd.tx_header))

    async def _send_cmd(  # NOTE: QoS wrapped here...
        self,
        cmd: Command,
        /,
        *,
        gap_duration: float = DEFAULT_GAP_DURATION,
        num_repeats: int = DEFAULT_NUM_REPEATS,
        priority: Priority = Priority.DEFAULT,
        qos: QosParams = DEFAULT_QOS,
    ) -> Packet:
        """Wrapper to send a Command with QoS (retries, until success or exception).

        The caller's qos object (which is often the shared DEFAULT_QOS) is never
        modified: if wait_for_reply must be disabled, a copy is used instead.
        """
        import copy  # local import: only needed here

        # TODO: use a sync function, so we don't have a stack of awaits before the write
        async def send_cmd(kmd: Command) -> None:
            """Wrapper to for self._send_frame(cmd)."""

            await self._send_frame(
                str(kmd), gap_duration=gap_duration, num_repeats=num_repeats
            )

        qos = qos or DEFAULT_QOS

        if _DBG_DISABLE_QOS:  # TODO: should allow echo Packet?
            await send_cmd(cmd)
            return None  # type: ignore[return-value]  # used for test/dev

        # if cmd.code == Code._PUZZ:  # NOTE: not as simple as this
        #     priority = Priority.HIGHEST  # FIXME: hack for _7FFF

        _CODES = (Code._0006, Code._0404, Code._0418, Code._1FC9)  # must have QoS
        # 0006|RQ must have wait_for_reply: (TODO: explain why)
        # 0404|RQ must have wait_for_reply: (TODO: explain why)
        # 0418|RQ must have wait_for_reply: if null log entry, reply has no idx
        # 1FC9|xx must have wait_for_reply and priority (timing critical)

        if self._disable_qos is True or (
            self._disable_qos is None and cmd.code not in _CODES
        ):
            # Copy before modifying: mutating qos in place would leak the change to
            # the caller and (via DEFAULT_QOS) to every later default-qos send
            qos = copy.copy(qos)
            qos._wait_for_reply = False

        # Should do this check before, or after previous block (of non-QoS sends)?
        # if not self._transport._is_wanted_addrs(cmd.src.id, cmd.dst.id, sending=True):
        #     raise exc.ProtocolError(
        #         f"{self}: Failed to send {cmd._hdr}: excluded by list"
        #     )

        assert self._context

        try:
            return await self._context.send_cmd(send_cmd, cmd, priority, qos)
        # except InvalidStateError as err:  # TODO: handle InvalidStateError separately
        #     # reset protocol stack
        except exc.ProtocolError as err:
            _LOGGER.info(f"{self}: Failed to send {cmd._hdr}: {err}")
            raise

    async def send_cmd(
        self,
        cmd: Command,
        /,
        *,
        gap_duration: float = DEFAULT_GAP_DURATION,
        num_repeats: int = DEFAULT_NUM_REPEATS,
        priority: Priority = Priority.DEFAULT,
        qos: QosParams = DEFAULT_QOS,  # max_retries, timeout, wait_for_reply
    ) -> Packet:
        """Send a Command with Qos (with retries, until success or ProtocolError).

        Returns the Command's response Packet or the Command echo if a response is not
        expected (e.g. sending an RP).

        If wait_for_reply is True, return the RQ's RP (or W's I), or raise an exception
        if one doesn't arrive. If it is False, return the echo of the Command only. If
        it is None (the default), act as True for RQs, and False for all other Commands.

        num_repeats is # of times to send the Command, in addition to the first transmit,
        with gap_duration seconds between each transmission. If wait_for_reply is True,
        then num_repeats is ignored.

        Commands are queued and sent FIFO, except higher-priority Commands are always
        sent first.

        Will raise:
            ProtocolSendFailed: tried to Tx Command, but didn't get echo/reply
            ProtocolError: didn't attempt to Tx Command for some reason
        """

        assert gap_duration == DEFAULT_GAP_DURATION
        assert 0 <= num_repeats <= 3  # if QoS, only Tx x1, with no repeats

        if qos and not self._context:
            _LOGGER.warning(f"{cmd} < QoS is currently disabled by this Protocol")

        if qos.wait_for_reply and num_repeats:
            _LOGGER.warning(f"{cmd} < num_repeats set to 0, as wait_for_reply is True")
            num_repeats = 0  # the lesser crime over wait_for_reply=False

        pkt = await super().send_cmd(  # may: raise ProtocolError/ProtocolSendFailed
            cmd,
            gap_duration=gap_duration,
            num_repeats=num_repeats,
            priority=priority,
            qos=qos,
        )

        if not pkt:  # HACK: temporary workaround for returning None
            raise exc.ProtocolSendFailed(f"Failed to send command: {cmd} (REPORT THIS)")

        return pkt

812 

813 

# Either of the concrete Protocol classes that protocol_factory() may return
RamsesProtocolT: TypeAlias = PortProtocol | ReadProtocol

815 

816 

def protocol_factory(
    msg_handler: MsgHandlerT,
    /,
    *,
    disable_qos: bool | None = DEFAULT_DISABLE_QOS,
    disable_sending: bool | None = False,
    enforce_include_list: bool = False,  # True, None, False
    exclude_list: DeviceListT | None = None,
    include_list: DeviceListT | None = None,
) -> RamsesProtocolT:
    """Create and return a Ramses-specific async packet Protocol.

    A truthy disable_sending gives a ReadProtocol; otherwise a PortProtocol is
    returned, configured with disable_qos. Both get the same device-id filters.
    """

    filters: dict[str, Any] = {
        "enforce_include_list": enforce_include_list,
        "exclude_list": exclude_list,
        "include_list": include_list,
    }

    if disable_sending:
        _LOGGER.debug("ReadProtocol: Sending has been disabled")
        return ReadProtocol(msg_handler, **filters)

    if disable_qos:
        _LOGGER.debug("PortProtocol: QoS has been disabled (will wait_for echos)")

    return PortProtocol(msg_handler, disable_qos=disable_qos, **filters)

848 

849 

async def create_stack(
    msg_handler: MsgHandlerT,
    /,
    *,
    protocol_factory_: Callable[..., RamsesProtocolT] | None = None,
    transport_factory_: Callable[..., Awaitable[RamsesTransportT]] | None = None,
    disable_qos: bool | None = DEFAULT_DISABLE_QOS,  # True, None, False
    disable_sending: bool | None = False,
    enforce_include_list: bool = False,
    exclude_list: DeviceListT | None = None,
    include_list: DeviceListT | None = None,
    **kwargs: Any,  # TODO: these are for the transport_factory
) -> tuple[RamsesProtocolT, RamsesTransportT]:
    """Utility function to provide a Protocol / Transport pair.

    Architecture: gwy (client) -> msg (Protocol) -> pkt (Transport) -> HGI/log (or dict)
    - send Commands via awaitable Protocol.send_cmd(cmd)
    - receive Messages via Gateway._handle_msg(msg) callback

    A packet_dict or packet_log source (in kwargs) implies disable_sending. The
    factories default to this module's protocol_factory() and to transport_factory();
    transport_factory_ is a callable returning an awaitable (e.g. an async function).
    When no port_name is given, the logger's timesource is switched to the
    transport's _dt_now.
    """

    read_only = kwargs.get("packet_dict") or kwargs.get("packet_log")
    disable_sending = disable_sending or read_only

    protocol: RamsesProtocolT = (protocol_factory_ or protocol_factory)(
        msg_handler,
        disable_qos=disable_qos,
        disable_sending=disable_sending,
        enforce_include_list=enforce_include_list,
        exclude_list=exclude_list,
        include_list=include_list,
    )

    transport: RamsesTransportT = await (transport_factory_ or transport_factory)(
        protocol, disable_sending=bool(disable_sending), **kwargs
    )

    if not kwargs.get(SZ_PORT_NAME):
        set_logger_timesource(transport._dt_now)
        _LOGGER.warning("Logger datetimes maintained as most recent packet timestamp")

    return protocol, transport