Coverage for src/ramses_rf/database.py: 16%

220 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2026-01-05 21:46 +0100

#!/usr/bin/env python3
"""
RAMSES RF - Message database and index.

.. table:: Database Query Methods [#fn1]_
   :widths: auto

   ===== ============ =========== ========== ==== ========================
   ix    method name  args        returns    uses used by
   ===== ============ =========== ========== ==== ========================
   i1    get          Msg, kwargs tuple(Msg) i3
   i2    contains     kwargs      bool       i4
   i3    _select_from kwargs      tuple(Msg) i4
   i4    qry_dtms     kwargs      list(dtm)
   i5    qry          sql, params tuple(Msg)      _msgs()
   i6    qry_field    sql, params tuple(fld)      e4, e5
   i7    get_rp_codes src, dst    list(Code)      Discovery-supported_cmds
   ===== ============ =========== ========== ==== ========================

.. [#fn1] In the "used by" column, ``eN`` refers to a query method in
   entity_base.py.
"""

23 

24from __future__ import annotations 

25 

26import asyncio 

27import logging 

28import sqlite3 

29from collections import OrderedDict 

30from datetime import datetime as dt, timedelta as td 

31from typing import TYPE_CHECKING, Any, NewType 

32 

33from ramses_tx import CODES_SCHEMA, RQ, Code, Message, Packet 

34 

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
    # ISO 8601 timestamp string, the key type of MessageIndex._msgs
    DtmStrT = NewType("DtmStrT", str)
    # Insertion-ordered mapping of timestamp string -> Message
    MsgDdT = OrderedDict[DtmStrT, Message]

40 

41 

42def _setup_db_adapters() -> None: 

43 """Set up the database adapters and converters.""" 

44 

45 def adapt_datetime_iso(val: dt) -> str: 

46 """Adapt datetime.datetime to timezone-naive ISO 8601 datetime to match _msgs dtm keys.""" 

47 return val.isoformat(timespec="microseconds") 

48 

49 sqlite3.register_adapter(dt, adapt_datetime_iso) 

50 

51 def convert_datetime(val: bytes) -> dt: 

52 """Convert ISO 8601 datetime to datetime.datetime object to import dtm in msg_db.""" 

53 return dt.fromisoformat(val.decode()) 

54 

55 sqlite3.register_converter("DTM", convert_datetime) 

56 

57 

def payload_keys(parsed_payload: list[dict] | dict) -> str:  # type: ignore[type-arg]
    """
    Copy payload keys for fast query check.

    Keys whose value is None are ignored, and each key is listed only once,
    in order of first appearance. For a list payload, the keys of every dict
    in the list are combined; list items that are not dicts are skipped.

    Example: ``{"mode": "auto", "until": None, "setpoint": 21.0}`` gives
    ``"|mode|setpoint|"``.

    :param parsed_payload: pre-parsed message payload dict, or list of dicts
    :return: string of payload keys with a leading | and each key followed by
        the | char (just ``"|"`` when there are no keys)
    """
    # Track seen keys in an insertion-ordered dict used as a set. A substring
    # test against the joined string (as before) wrongly dropped any key that
    # is contained in an earlier one, e.g. "mode" after "system_mode".
    seen: dict[str, None] = {}
    payloads = parsed_payload if isinstance(parsed_payload, list) else [parsed_payload]
    for ppl in payloads:
        if not isinstance(ppl, dict):
            continue
        for k, v in ppl.items():
            if v is not None:  # ignore keys with None value
                seen.setdefault(str(k), None)
    return "|" + "".join(f"{k}|" for k in seen)

82 

83 

class MessageIndex:
    """A central in-memory SQLite3 database for indexing RF messages.

    The index holds the latest messages to & from all devices, indexed by
    `dtm` (timestamp) and `hdr` header (example of a hdr:
    ``000C|RP|01:223036|0208``).

    The SQLite ``messages`` table is used for lookups; the Message objects are
    kept in ``self._msgs``, keyed by ``dtm.isoformat(timespec="microseconds")``.
    """

    _housekeeping_task: asyncio.Task[None]

    def __init__(self, maintain: bool = True) -> None:
        """Instantiate a message database/index.

        :param maintain: if True, start the housekeeping task that removes
            stale messages (asyncio.create_task needs a running event loop)
        """

        self.maintain = maintain
        # stores all messages for retrieval, keyed by ISO dtm string;
        # filled by add()/add_record() & cleaned up by the housekeeping loop
        self._msgs: MsgDdT = OrderedDict()

        # Connect to a SQLite DB in memory; detect_types lets the DTM
        # converter return datetime objects for the dtm column
        self._cx = sqlite3.connect(
            ":memory:", detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
        )
        self._cu = self._cx.cursor()  # Create a cursor

        _setup_db_adapters()  # DTM adapter/converter
        self._setup_db_schema()

        if self.maintain:
            self._lock = asyncio.Lock()
            self._last_housekeeping: dt = None  # type: ignore[assignment]
            self._housekeeping_task = None  # type: ignore[assignment]

        self.start()

    def __repr__(self) -> str:
        return f"MessageIndex({len(self._msgs)} messages)"  # or msg_db.count()

    def start(self) -> None:
        """Start the housekeeper loop (only if `self.maintain` is True)."""

        if self.maintain:
            if self._housekeeping_task and (not self._housekeeping_task.done()):
                return  # already running

            self._housekeeping_task = asyncio.create_task(
                self._housekeeping_loop(), name=f"{self.__class__.__name__}.housekeeper"
            )

    def stop(self) -> None:
        """Stop the housekeeper loop, then commit and close the database."""

        if (
            self.maintain
            and self._housekeeping_task
            and (not self._housekeeping_task.done())
        ):
            self._housekeeping_task.cancel()  # stop the housekeeper

        self._cx.commit()  # just in case
        self._cx.close()  # may still need to do queries after engine has stopped?

    @property
    def msgs(self) -> MsgDdT:
        """Return the (live, not copied) dict of indexed messages, keyed by ISO dtm string."""
        return self._msgs

    def _setup_db_schema(self) -> None:
        """Set up the message database schema.

        .. note::
            messages TABLE Fields:

            - dtm   message timestamp
            - verb  " I", "RQ" etc.
            - src   message origin address
            - dst   message destination address
            - code  packet code aka command class e.g. 0005, 31DA
            - ctx   message context, created from payload as index + extra markers (Heat)
            - hdr   packet header e.g. 000C|RP|01:223036|0208 (see: src/ramses_tx/frame.py)
            - plk   the keys stored in the parsed payload, separated by the | char
        """

        self._cu.execute(
            """
            CREATE TABLE messages (
                dtm  DTM      NOT NULL PRIMARY KEY,
                verb TEXT(2)  NOT NULL,
                src  TEXT(12) NOT NULL,
                dst  TEXT(12) NOT NULL,
                code TEXT(4)  NOT NULL,
                ctx  TEXT,
                hdr  TEXT     NOT NULL UNIQUE,
                plk  TEXT     NOT NULL
            )
            """
        )

        self._cu.execute("CREATE INDEX idx_verb ON messages (verb)")
        self._cu.execute("CREATE INDEX idx_src ON messages (src)")
        self._cu.execute("CREATE INDEX idx_dst ON messages (dst)")
        self._cu.execute("CREATE INDEX idx_code ON messages (code)")
        self._cu.execute("CREATE INDEX idx_ctx ON messages (ctx)")
        self._cu.execute("CREATE INDEX idx_hdr ON messages (hdr)")

        self._cx.commit()

    async def _housekeeping_loop(self) -> None:
        """Every hour, remove stale messages from the index.

        Only started when `self.maintain` is True (it is False in most tests).
        """

        while True:
            self._last_housekeeping = dt.now()
            await asyncio.sleep(3600)
            _LOGGER.info("Starting next MessageIndex housekeeping")
            await self._housekeeping(self._last_housekeeping)

    async def _housekeeping(self, dt_now: dt, _cutoff: td = td(days=1)) -> None:
        """Delete all messages older than ``dt_now - _cutoff`` from the index.

        The table and ``self._msgs`` are updated while holding ``self._lock``;
        ``self._msgs`` is only replaced if the SQL statements succeed.

        :param dt_now: current timestamp
        :param _cutoff: age of the oldest message to retain, default is 24 hours
        """
        dtm = dt_now - _cutoff

        async with self._lock:  # never releases a lock it did not acquire
            try:
                self._cu.execute("SELECT dtm FROM messages WHERE dtm >= ?", (dtm,))
                rows = self._cu.fetchall()  # dtm of current messages to retain
                self._cu.execute("DELETE FROM messages WHERE dtm < ?", (dtm,))
                self._cx.commit()
            except sqlite3.Error:  # need to tighten?
                self._cx.rollback()
                return

            # The DTM converter returns datetimes, but _msgs is keyed by ISO
            # strings: convert before looking up (using row[0] raised KeyError)
            keep = {row[0].isoformat(timespec="microseconds") for row in rows}
            old_size = len(self._msgs)
            self._msgs = OrderedDict(
                (key, msg) for key, msg in self._msgs.items() if key in keep
            )
            new_size = len(self._msgs)

        _LOGGER.debug("MessageIndex size was: %d, now: %d", old_size, new_size)

    def add(self, msg: Message) -> Message | None:
        """
        Add a single message to the MessageIndex.

        Any record with the same dtm, or with the same hdr, is replaced.
        Logs a debug message if a record with the same dtm was overwritten.

        :param msg: the message to add
        :returns: any message that was removed because it had the same header
        """
        # TODO: eventually, may be better to use SqlAlchemy

        dup: tuple[Message, ...] = tuple()  # avoid UnboundLocalError
        old: Message | None = None  # avoid UnboundLocalError

        try:  # TODO: remove this, or apply only when source is a real packet log?
            dup = self._delete_from(  # HACK: because of contrived pkt logs
                dtm=msg.dtm  # stored as such with DTM formatter
            )
            old = self._insert_into(msg)  # will delete old msg by hdr (not dtm!)

        except sqlite3.Error:  # UNIQUE constraint failed: ? messages.dtm or .hdr (so: HACK)
            self._cx.rollback()

        else:
            # commit now, so that a later rollback cannot undo this add
            self._cx.commit()
            # _msgs dict requires a timestamp reformat
            dtm: DtmStrT = msg.dtm.isoformat(timespec="microseconds")  # type: ignore[assignment]
            self._msgs[dtm] = msg

        if (
            dup
            and (msg.src is not msg.dst)
            and not msg.dst.id.startswith("18:")  # HGI
            and msg.verb != RQ  # these may come very quickly
        ):  # when src==dst, expect to add duplicate, don't warn
            _LOGGER.debug(
                "Overwrote dtm (%s) for %s: %s (contrived log?)",
                msg.dtm,
                msg._pkt._hdr,
                dup[0]._pkt,
            )

        return old

    def add_record(self, src: str, code: str = "", verb: str = "") -> None:
        """
        Add a single record to the MessageIndex with timestamp `now()` and no Message contents.

        A dummy Message, built from a synthetic packet, is also stored in
        ``self._msgs`` so that the record is handled by housekeeping.

        :param src: device id to use as both source and destination address
        :param code: packet code of the record
        :param verb: two letter verb str to use
        """
        # Used by OtbGateway init, via entity_base.py
        _now: dt = dt.now()
        dtm: DtmStrT = _now.isoformat(timespec="microseconds")  # type: ignore[assignment]
        hdr = f"{code}|{verb}|{src}|00"  # dummy record has no contents

        sql = """
                INSERT INTO messages (dtm, verb, src, dst, code, ctx, hdr, plk)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
              """
        dup: tuple[Message, ...] = tuple()
        try:  # the delete and insert are committed (or rolled back) together
            dup = self._delete_from(hdr=hdr)
            self._cu.execute(sql, (_now, verb, src, src, code, None, hdr, "|"))
        except sqlite3.Error:
            self._cx.rollback()
        else:
            self._cx.commit()
            # also add a dummy msg to self._msgs dict to allow maintenance loop
            msg: Message = Message._from_pkt(
                Packet(
                    _now, f"... {verb} --- {src} --:------ {src} {code} 005 0000000000"
                )
            )
            self._msgs[dtm] = msg

        if dup:  # expected when more than one heat system in schema
            _LOGGER.debug("Replaced record with same hdr: %s", hdr)

    def _insert_into(self, msg: Message) -> Message | None:
        """
        Insert a message into the index (does not commit).

        Any existing record with the same hdr is deleted first.

        :returns: any message replaced (by same hdr)
        """
        assert msg._pkt._hdr is not None, f"Skipping: Packet has no hdr: {msg._pkt}"

        # store boolean contexts as text; qry_dtms() applies the inverse mapping
        if msg._pkt._ctx is True:
            msg_pkt_ctx = "True"
        elif msg._pkt._ctx is False:
            msg_pkt_ctx = "False"
        else:
            msg_pkt_ctx = msg._pkt._ctx  # can be None

        _old_msgs = self._delete_from(hdr=msg._pkt._hdr)

        sql = """
                INSERT INTO messages (dtm, verb, src, dst, code, ctx, hdr, plk)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
              """

        self._cu.execute(
            sql,
            (
                msg.dtm,
                str(msg.verb),
                msg.src.id,
                msg.dst.id,
                str(msg.code),
                msg_pkt_ctx,
                msg._pkt._hdr,
                payload_keys(msg.payload),
            ),
        )

        return _old_msgs[0] if _old_msgs else None

    def rem(
        self, msg: Message | None = None, **kwargs: str | dt
    ) -> tuple[Message, ...] | None:
        """Remove a set of message(s) from the index.

        :param msg: the Message to remove (matched by its dtm)
        :param kwargs: alternatively, table field names and values to match
        :returns: any messages that were removed, or None on a database error
        :raises ValueError: unless exactly one of msg or kwargs is provided
        """

        if not (bool(msg) ^ bool(kwargs)):
            raise ValueError("Either a Message or kwargs should be provided, not both")
        if msg:
            kwargs["dtm"] = msg.dtm

        try:  # make this operation atomic, i.e. update self._msgs only on success
            msgs = self._delete_from(**kwargs)
        except sqlite3.Error:  # need to tighten?
            self._cx.rollback()
            return None

        self._cx.commit()
        for removed in msgs:
            dtm: DtmStrT = removed.dtm.isoformat(timespec="microseconds")  # type: ignore[assignment]
            self._msgs.pop(dtm, None)  # tolerate keys already gone from _msgs

        return msgs

    def _delete_from(self, **kwargs: bool | dt | str) -> tuple[Message, ...]:
        """Remove message(s) from the index table (does not commit).

        :param kwargs: (exact) table field_name: value pairs, ANDed together
        :returns: any messages that were removed (those also present in _msgs)
        """

        msgs = self._select_from(**kwargs)

        sql = "DELETE FROM messages WHERE "
        sql += " AND ".join(f"{k} = ?" for k in kwargs)

        self._cu.execute(sql, tuple(kwargs.values()))

        return msgs

    # MessageIndex msg_db query methods

    def get(
        self, msg: Message | None = None, **kwargs: bool | dt | str
    ) -> tuple[Message, ...]:
        """
        Public method to get a set of message(s) from the index.

        :param msg: Message to return, by dtm (expect a single result as dtm is unique key)
        :param kwargs: data table field names and criteria, e.g. (hdr=...)
        :return: tuple of matching Messages
        :raises ValueError: unless exactly one of msg or kwargs is provided
        """

        if not (bool(msg) ^ bool(kwargs)):
            raise ValueError("Either a Message or kwargs should be provided, not both")

        if msg:
            kwargs["dtm"] = msg.dtm

        return self._select_from(**kwargs)

    def contains(self, **kwargs: bool | dt | str) -> bool:
        """
        Check if the MessageIndex contains at least 1 record that matches the provided fields.

        :param kwargs: (exact) SQLite table field_name: required_value pairs
        :return: True if at least one message fitting the given conditions is present, False when qry returned empty
        """

        return len(self.qry_dtms(**kwargs)) > 0

    def _select_from(self, **kwargs: bool | dt | str) -> tuple[Message, ...]:
        """
        Select message(s) using the MessageIndex.

        Records whose dtm has no entry in _msgs are skipped (and logged).

        :param kwargs: (exact) SQLite table field_name: required_value pairs
        :returns: a tuple of qualifying messages
        """

        res: list[Message] = []
        for row in self.qry_dtms(**kwargs):
            ts: DtmStrT = row[0].isoformat(timespec="microseconds")
            if ts in self._msgs:
                res.append(self._msgs[ts])
            else:
                _LOGGER.debug("MessageIndex timestamp %s not in device messages", ts)
        return tuple(res)

    def qry_dtms(self, **kwargs: bool | dt | str) -> list[Any]:
        """
        Select from the MessageIndex a list of dtms that match the provided arguments.

        A ``ctx`` criterion is converted as in _insert_into(): str values are
        used as-is, any other value becomes "True" or "False" by truthiness.

        :param kwargs: data table field names and criteria
        :return: list of rows (1-tuples holding the dtm as a datetime), useful
            for msg lookup, or an empty list if 0 matches
        """
        # tweak kwargs as stored in SQLite, inverse from _insert_into():
        kw = {key: value for key, value in kwargs.items() if key != "ctx"}
        if "ctx" in kwargs:
            if isinstance(kwargs["ctx"], str):
                kw["ctx"] = kwargs["ctx"]
            elif kwargs["ctx"]:
                kw["ctx"] = "True"
            else:
                kw["ctx"] = "False"

        sql = "SELECT dtm FROM messages WHERE "
        sql += " AND ".join(f"{k} = ?" for k in kw)

        self._cu.execute(sql, tuple(kw.values()))
        return self._cu.fetchall()

    def qry(self, sql: str, parameters: tuple[str, ...]) -> tuple[Message, ...]:
        """
        Get a tuple of messages from _msgs using the index, given sql and parameters.

        :param sql: a bespoke SQL query SELECT string that should return dtm as first field
        :param parameters: tuple of values for the sql placeholders
        :return: a tuple of qualifying messages
        :raises ValueError: if the sql string does not contain "SELECT"
        """

        if "SELECT" not in sql:
            raise ValueError(f"{self}: Only SELECT queries are allowed")

        self._cu.execute(sql, parameters)

        lst: list[Message] = []
        for row in self._cu.fetchall():
            # must reformat from DTM (datetime) to the _msgs key format
            ts: DtmStrT = row[0].isoformat(timespec="microseconds")
            if ts in self._msgs:
                lst.append(self._msgs[ts])
            else:  # happens in tests with artificial msg from heat
                _LOGGER.debug("MessageIndex timestamp %s not in device messages", ts)
        return tuple(lst)

    def get_rp_codes(self, parameters: tuple[str, ...]) -> list[Code]:
        """
        Get a list of Codes of RP messages sent from/to a device.

        :param parameters: (src, dst) tuple: a record matches if its src equals
            the first item or its dst equals the second
        :return: list of matching Codes (one entry per matching record); codes
            not found in CODES_SCHEMA are skipped
        """

        def get_code(code: str) -> Code:
            for Cd in CODES_SCHEMA:
                if code == Cd:
                    return Cd
            raise LookupError(f"Failed to find matching code for {code}")

        sql = """
            SELECT code from messages WHERE verb = 'RP' AND (src = ? OR dst = ?)
        """
        self._cu.execute(sql, parameters)

        codes: list[Code] = []
        # fetch the result set only once: a second fetchall() returns nothing
        for (code,) in self._cu.fetchall():
            try:
                codes.append(get_code(code))
            except LookupError:
                _LOGGER.debug("MessageIndex: ignoring unknown RP code %s", code)
        return codes

    def qry_field(
        self, sql: str, parameters: tuple[str, ...]
    ) -> list[tuple[dt | str, str]]:
        """
        Get a list of fields from the index, given select sql and parameters.

        :param sql: a bespoke SQL query SELECT string
        :param parameters: tuple of values for the sql placeholders
        :return: list of key: value pairs as defined in sql
        :raises ValueError: if the sql string does not contain "SELECT"
        """

        if "SELECT" not in sql:
            raise ValueError(f"{self}: Only SELECT queries are allowed")

        self._cu.execute(sql, parameters)
        return self._cu.fetchall()

    def all(self, include_expired: bool = False) -> tuple[Message, ...]:
        """Get all messages from the index.

        :param include_expired: currently unused (expiry filtering is not implemented)
        :return: tuple of all indexed messages that are present in _msgs
        """

        self._cu.execute("SELECT * FROM messages")

        lst: list[Message] = []
        for row in self._cu.fetchall():
            # must reformat from DTM (datetime) to the _msgs key format
            ts: DtmStrT = row[0].isoformat(timespec="microseconds")
            if ts in self._msgs:
                # if include_expired or not self._msgs[ts].HAS_EXPIRED:  # not working
                lst.append(self._msgs[ts])
                _LOGGER.debug("MessageIndex ts %s added to all.lst", ts)
            else:  # happens in tests and real evohome setups with dummy msg from heat init
                _LOGGER.debug("MessageIndex ts %s not in device messages", ts)
        return tuple(lst)

    def clr(self) -> None:
        """Clear the message index (remove indexes of all messages)."""

        self._cu.execute("DELETE FROM messages")
        self._cx.commit()

        self._msgs.clear()

587 

588 # def _msgs(self, device_id: DeviceIdT) -> tuple[Message, ...]: 

589 # msgs = [msg for msg in self._msgs.values() if msg.src.id == device_id] 

590 # return msgs