Coverage for src/ramses_rf/database.py: 16%
220 statements
« prev ^ index » next coverage.py v7.11.3, created at 2026-01-05 21:46 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2026-01-05 21:46 +0100
#!/usr/bin/env python3
"""
RAMSES RF - Message database and index.

.. table:: Database Query Methods [#fn1]_
   :widths: auto

   ===== ============ =========== ========== ==== ========================
   ix    method name  args        returns    uses used by
   ===== ============ =========== ========== ==== ========================
   i1    get          msg, kwargs tuple(Msg) i3
   i2    contains     kwargs      bool       i4
   i3    _select_from kwargs      tuple(Msg) i4
   i4    qry_dtms     kwargs      list(dtm)
   i5    qry          sql, params tuple(Msg)      _msgs()
   i6    qry_field    sql, params list(fld)       e4, e5
   i7    get_rp_codes src, dst    list(Code)      Discovery-supported_cmds
   ===== ============ =========== ========== ==== ========================

.. [#fn1] The *uses* column refers to other rows of this table by their *ix*.
   In the *used by* column, ``ex`` entries (e.g. e4, e5) refer to the query
   methods of entity_base.py.
"""
24from __future__ import annotations
26import asyncio
27import logging
28import sqlite3
29from collections import OrderedDict
30from datetime import datetime as dt, timedelta as td
31from typing import TYPE_CHECKING, Any, NewType
33from ramses_tx import CODES_SCHEMA, RQ, Code, Message, Packet
_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:  # aliases for static type checkers only; annotations are lazy strings
    # ISO 8601 timestamp string, as used for the keys of MessageIndex._msgs
    DtmStrT = NewType("DtmStrT", str)
    # the MessageIndex._msgs mapping: timestamp string -> Message
    MsgDdT = OrderedDict[DtmStrT, Message]
def _setup_db_adapters() -> None:
    """Register the sqlite3 adapter/converter pair used for ``DTM`` columns.

    datetime objects are written as ISO 8601 strings with microsecond
    precision (the same format as the _msgs dict keys), and values read from
    a column declared as ``DTM`` are parsed back into datetime objects.
    """

    def convert_datetime(val: bytes) -> dt:
        """Parse a stored ISO 8601 byte string back into a datetime.datetime."""
        return dt.fromisoformat(val.decode())

    def adapt_datetime_iso(val: dt) -> str:
        """Render a datetime.datetime as an ISO 8601 string (microsecond precision)."""
        return val.isoformat(timespec="microseconds")

    sqlite3.register_adapter(dt, adapt_datetime_iso)
    sqlite3.register_converter("DTM", convert_datetime)
def payload_keys(parsed_payload: list[dict] | dict) -> str:  # type: ignore[type-arg]
    """
    Copy payload keys for fast query check.

    Keys whose value is None are skipped. Each key is listed only once, in
    order of first appearance, even when the payload is a list of dicts.

    :param parsed_payload: pre-parsed message payload dict, or list of dicts
    :return: string of payload keys, separated (and enclosed) by the | char,
        e.g. ``"|zone_idx|setpoint|"``; just ``"|"`` when there are no keys
    """
    if isinstance(parsed_payload, dict):
        payloads: list[dict] = [parsed_payload]  # type: ignore[type-arg]
    elif isinstance(parsed_payload, list):
        payloads = parsed_payload
    else:
        payloads = []

    # Deduplicate by exact key (an insertion-ordered dict used as a set), not
    # by a substring test on the joined string - that test wrongly dropped
    # keys such as "idx" once "zone_idx" had been seen.
    seen: dict[str, None] = {}
    for ppl in payloads:
        for k, v in ppl.items():
            if v is not None:  # ignore keys with None value
                seen.setdefault(k, None)

    return "|" + "".join(f"{k}|" for k in seen)
class MessageIndex:
    """A central in-memory SQLite3 database for indexing RF messages.

    Index holds all the latest messages to & from all devices by `dtm`
    (timestamp) and `hdr` header (example of a hdr: ``000C|RP|01:223036|0208``).
    The SQLite table holds only the indexed fields; the Message objects are
    kept in ``self._msgs``, keyed by the ISO 8601 string of their ``dtm``.

    Each public method that changes the table commits its own changes, so the
    rollback after a failed operation only undoes that one operation.
    """

    _housekeeping_task: asyncio.Task[None]

    def __init__(self, maintain: bool = True) -> None:
        """Instantiate a message database/index.

        :param maintain: if True, also start the hourly housekeeping task that
            removes stale messages (asyncio.create_task needs a running loop)
        """

        self.maintain = maintain
        self._msgs: MsgDdT = OrderedDict()  # stores all messages for retrieval.
        # Filled & cleaned up in housekeeping_loop.

        # Connect to a SQLite DB in memory;
        # detect_types should retain dt type on store/retrieve
        self._cx = sqlite3.connect(
            ":memory:", detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
        )
        self._cu = self._cx.cursor()  # Create a cursor

        _setup_db_adapters()  # DTM adapter/converter
        self._setup_db_schema()

        if self.maintain:
            self._lock = asyncio.Lock()
            self._last_housekeeping: dt = None  # type: ignore[assignment]
            self._housekeeping_task = None  # type: ignore[assignment]

        self.start()

    def __repr__(self) -> str:
        return f"MessageIndex({len(self._msgs)} messages)"  # or msg_db.count()

    def start(self) -> None:
        """Start the housekeeper loop, if maintained and not already running."""

        if self.maintain:
            if self._housekeeping_task and (not self._housekeeping_task.done()):
                return

            self._housekeeping_task = asyncio.create_task(
                self._housekeeping_loop(), name=f"{self.__class__.__name__}.housekeeper"
            )

    def stop(self) -> None:
        """Stop the housekeeper loop, then commit and close the database."""

        if (
            self.maintain
            and self._housekeeping_task
            and (not self._housekeeping_task.done())
        ):
            self._housekeeping_task.cancel()  # stop the housekeeper

        self._cx.commit()  # just in case
        self._cx.close()  # may still need to do queries after engine has stopped?

    @property
    def msgs(self) -> MsgDdT:
        """Return the messages of the index (ISO dtm string -> Message).

        NOTE: this is the live internal dict, not a copy.
        """
        return self._msgs

    def _setup_db_schema(self) -> None:
        """Set up the message database schema.

        .. note::
            messages TABLE Fields:

            - dtm   message timestamp
            - verb  " I", "RQ" etc.
            - src   message origin address
            - dst   message destination address
            - code  packet code aka command class e.g. 0005, 31DA
            - ctx   message context, created from payload as index + extra markers (Heat)
            - hdr   packet header e.g. 000C|RP|01:223036|0208 (see: src/ramses_tx/frame.py)
            - plk   the keys stored in the parsed payload, separated by the | char
        """

        self._cu.execute(
            """
            CREATE TABLE messages (
                dtm    DTM      NOT NULL PRIMARY KEY,
                verb   TEXT(2)  NOT NULL,
                src    TEXT(12) NOT NULL,
                dst    TEXT(12) NOT NULL,
                code   TEXT(4)  NOT NULL,
                ctx    TEXT,
                hdr    TEXT     NOT NULL UNIQUE,
                plk    TEXT     NOT NULL
            )
            """
        )

        self._cu.execute("CREATE INDEX idx_verb ON messages (verb)")
        self._cu.execute("CREATE INDEX idx_src ON messages (src)")
        self._cu.execute("CREATE INDEX idx_dst ON messages (dst)")
        self._cu.execute("CREATE INDEX idx_code ON messages (code)")
        self._cu.execute("CREATE INDEX idx_ctx ON messages (ctx)")
        self._cu.execute("CREATE INDEX idx_hdr ON messages (hdr)")

        self._cx.commit()

    async def _housekeeping_loop(self) -> None:
        """Periodically (hourly) remove stale messages from the index.

        Only started when `self.maintain` is True - it is False in (most) tests.
        """

        while True:
            self._last_housekeeping = dt.now()
            await asyncio.sleep(3600)
            _LOGGER.info("Starting next MessageIndex housekeeping")
            await self._housekeeping(self._last_housekeeping)

    async def _housekeeping(self, dt_now: dt, _cutoff: td = td(days=1)) -> None:
        """
        Delete all messages older than ``dt_now - _cutoff`` from the index.

        The table rows and the _msgs dict are pruned while holding the lock, and
        _msgs is only replaced when the SQL DELETE succeeded.

        :param dt_now: reference timestamp
        :param _cutoff: how far back from dt_now to retain messages, default 24 hours
        """
        dtm = dt_now - _cutoff

        async with self._lock:  # make this operation atomic
            size_before = len(self._msgs)
            try:
                self._cu.execute("SELECT dtm FROM messages WHERE dtm >= ?", (dtm,))
                rows = self._cu.fetchall()  # dtm of current messages to retain
                self._cu.execute("DELETE FROM messages WHERE dtm < ?", (dtm,))
                self._cx.commit()
            except sqlite3.Error:  # need to tighten?
                self._cx.rollback()
                return

            # The DTM converter returns datetimes, but _msgs is keyed by ISO
            # strings; entries without a retained row (e.g. ones replaced by
            # hdr) are dropped, preserving the insertion order of the rest.
            keep = {row[0].isoformat(timespec="microseconds") for row in rows}
            self._msgs = OrderedDict(
                (key, msg) for key, msg in self._msgs.items() if key in keep
            )

        _LOGGER.debug(
            "MessageIndex size was: %d, now: %d", size_before, len(self._msgs)
        )

    def add(self, msg: Message) -> Message | None:
        """
        Add a single message to the MessageIndex.

        Any indexed message with the same dtm is removed first (HACK: because
        of contrived packet logs); a debug message is logged when that happens.
        Any message with the same hdr is replaced.

        :param msg: the message to add
        :returns: any message that was removed because it had the same header
        """
        # TODO: eventually, may be better to use SqlAlchemy

        dup: tuple[Message, ...] = ()  # messages removed because of the same dtm
        old: Message | None = None  # message replaced because of the same hdr

        try:  # TODO: remove this, or apply only when source is a real packet log?
            dup = self._delete_from(dtm=msg.dtm)  # HACK: because of contrived pkt logs
            old = self._insert_into(msg)  # will delete old msg by hdr (not dtm!)
        except sqlite3.Error as err:  # e.g. UNIQUE constraint failed: dtm or hdr
            self._cx.rollback()  # also undoes the deletes above (nothing else pending)
            dup, old = (), None
            _LOGGER.debug("MessageIndex failed to add %s: %s", msg, err)
        else:
            self._cx.commit()
            # the _msgs dict is keyed by the ISO 8601 string of the timestamp
            dtm: DtmStrT = msg.dtm.isoformat(timespec="microseconds")  # type: ignore[assignment]
            self._msgs[dtm] = msg

        if (
            dup
            and (msg.src is not msg.dst)
            and not msg.dst.id.startswith("18:")  # HGI
            and msg.verb != RQ  # these may come very quickly
        ):  # when src==dst, expect to add duplicate, don't warn
            _LOGGER.debug(
                "Overwrote dtm (%s) for %s: %s (contrived log?)",
                msg.dtm,
                msg._pkt._hdr,
                dup[0]._pkt,
            )

        return old

    def add_record(self, src: str, code: str = "", verb: str = "") -> None:
        """
        Add a single record to the MessageIndex with timestamp `now()` and no Message contents.

        Any record with the same (dummy) header is replaced. A matching dummy
        Message is also stored in the _msgs dict so that queries can resolve it.

        :param src: device id, used as both source and destination address
        :param code: packet code to record
        :param verb: two letter verb str to record
        """
        # Used by OtbGateway init, via entity_base.py
        _now: dt = dt.now()
        dtm: DtmStrT = _now.isoformat(timespec="microseconds")  # type: ignore[assignment]
        hdr = f"{code}|{verb}|{src}|00"  # dummy record has no contents

        dup = self._delete_from(hdr=hdr)

        sql = """
            INSERT INTO messages (dtm, verb, src, dst, code, ctx, hdr, plk)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        try:
            self._cu.execute(sql, (_now, verb, src, src, code, None, hdr, "|"))
        except sqlite3.Error:
            self._cx.rollback()  # also undoes the delete above
        else:
            self._cx.commit()
            # also add a dummy msg to self._msgs dict to allow maintenance loop
            msg: Message = Message._from_pkt(
                Packet(
                    _now, f"... {verb} --- {src} --:------ {src} {code} 005 0000000000"
                )
            )
            self._msgs[dtm] = msg

        if dup:  # expected when more than one heat system in schema
            _LOGGER.debug("Replaced record with same hdr: %s", hdr)

    def _insert_into(self, msg: Message) -> Message | None:
        """
        Insert a message into the index table (the caller commits).

        Any message with the same header is deleted first.

        :param msg: the message to insert; its packet must have a header
        :returns: any message replaced (by same hdr)
        """
        assert msg._pkt._hdr is not None, f"Skipping: Packet has no hdr: {msg._pkt}"

        # bool contexts are stored as the strings "True"/"False" (see _where_clause)
        if msg._pkt._ctx is True:
            msg_pkt_ctx = "True"
        elif msg._pkt._ctx is False:
            msg_pkt_ctx = "False"
        else:
            msg_pkt_ctx = msg._pkt._ctx  # can be None

        _old_msgs = self._delete_from(hdr=msg._pkt._hdr)

        sql = """
            INSERT INTO messages (dtm, verb, src, dst, code, ctx, hdr, plk)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """

        self._cu.execute(
            sql,
            (
                msg.dtm,
                str(msg.verb),
                msg.src.id,
                msg.dst.id,
                str(msg.code),
                msg_pkt_ctx,
                msg._pkt._hdr,
                payload_keys(msg.payload),
            ),
        )

        return _old_msgs[0] if _old_msgs else None

    def rem(
        self, msg: Message | None = None, **kwargs: str | dt
    ) -> tuple[Message, ...] | None:
        """Remove a set of message(s) from the index.

        :param msg: the message to remove (matched by its dtm)
        :param kwargs: data table field names and criteria, e.g. (hdr=...)
        :returns: any messages that were removed, or None if the deletion failed
        :raises ValueError: unless exactly one of msg or kwargs is provided
        """

        if not (bool(msg) ^ bool(kwargs)):
            raise ValueError("Either a Message or kwargs should be provided, not both")
        if msg:
            kwargs["dtm"] = msg.dtm

        msgs: tuple[Message, ...] | None = None
        try:
            msgs = self._delete_from(**kwargs)
        except sqlite3.Error:  # need to tighten?
            self._cx.rollback()
        else:
            self._cx.commit()
            for removed in msgs:
                dtm: DtmStrT = removed.dtm.isoformat(timespec="microseconds")  # type: ignore[assignment]
                self._msgs.pop(dtm, None)

        return msgs

    @staticmethod
    def _where_clause(
        kwargs: dict[str, bool | dt | str],
    ) -> tuple[str, tuple[Any, ...]]:
        """Build an SQL WHERE condition and its parameters from field criteria.

        Values are matched exactly. A non-str ``ctx`` is mapped to "True" or
        "False" by truthiness, the inverse of how _insert_into stores bools.

        :param kwargs: (exact) SQLite table field_name: required_value pairs
        :return: the condition (e.g. ``"hdr = ? AND ctx = ?"``) and its parameters
        """
        criteria = {key: value for key, value in kwargs.items() if key != "ctx"}
        if "ctx" in kwargs:
            ctx = kwargs["ctx"]
            if isinstance(ctx, str):
                criteria["ctx"] = ctx
            else:
                criteria["ctx"] = "True" if ctx else "False"

        condition = " AND ".join(f"{key} = ?" for key in criteria)
        return condition, tuple(criteria.values())

    def _delete_from(self, **kwargs: bool | dt | str) -> tuple[Message, ...]:
        """Remove message(s) from the index table (the caller commits/rolls back).

        Selection and deletion use the same normalised criteria, so exactly
        the returned messages' rows are deleted.

        :param kwargs: (exact) SQLite table field_name: required_value pairs
        :returns: any messages that were removed (those present in _msgs)
        """

        msgs = self._select_from(**kwargs)

        condition, params = self._where_clause(kwargs)
        self._cu.execute(f"DELETE FROM messages WHERE {condition}", params)

        return msgs

    # MessageIndex msg_db query methods

    def get(
        self, msg: Message | None = None, **kwargs: bool | dt | str
    ) -> tuple[Message, ...]:
        """
        Public method to get a set of message(s) from the index.

        :param msg: Message to return, by dtm (expect a single result as dtm is unique key)
        :param kwargs: data table field names and criteria, e.g. (hdr=...)
        :return: tuple of matching Messages
        :raises ValueError: unless exactly one of msg or kwargs is provided
        """

        if not (bool(msg) ^ bool(kwargs)):
            raise ValueError("Either a Message or kwargs should be provided, not both")

        if msg:
            kwargs["dtm"] = msg.dtm

        return self._select_from(**kwargs)

    def contains(self, **kwargs: bool | dt | str) -> bool:
        """
        Check if the MessageIndex contains at least 1 record that matches the provided fields.

        :param kwargs: (exact) SQLite table field_name: required_value pairs
        :return: True if at least one message fitting the given conditions is present, False when qry returned empty
        """

        return len(self.qry_dtms(**kwargs)) > 0

    def _select_from(self, **kwargs: bool | dt | str) -> tuple[Message, ...]:
        """
        Select message(s) using the MessageIndex.

        Rows whose dtm has no entry in _msgs are skipped (and logged).

        :param kwargs: (exact) SQLite table field_name: required_value pairs
        :returns: a tuple of qualifying messages
        """

        res: list[Message] = []
        for row in self.qry_dtms(**kwargs):
            ts: DtmStrT = row[0].isoformat(timespec="microseconds")
            if ts in self._msgs:
                res.append(self._msgs[ts])
            else:
                _LOGGER.debug("MessageIndex timestamp %s not in device messages", ts)
        return tuple(res)

    def qry_dtms(self, **kwargs: bool | dt | str) -> list[Any]:
        """
        Select from the MessageIndex a list of dtms that match the provided arguments.

        :param kwargs: data table field names and criteria
        :return: list of unformatted dtms that match, useful for msg lookup, or an empty list if 0 matches
        """

        condition, params = self._where_clause(kwargs)
        self._cu.execute(f"SELECT dtm FROM messages WHERE {condition}", params)
        return self._cu.fetchall()

    def qry(self, sql: str, parameters: tuple[str, ...]) -> tuple[Message, ...]:
        """
        Get a tuple of messages from _msgs using the index, given sql and parameters.

        :param sql: a bespoke SQL query SELECT string that should return dtm as first field
        :param parameters: tuple of kwargs with the selection filter
        :return: a tuple of qualifying messages
        :raises ValueError: if sql does not contain "SELECT" (a substring check only)
        """

        if "SELECT" not in sql:
            raise ValueError(f"{self}: Only SELECT queries are allowed")

        self._cu.execute(sql, parameters)

        lst: list[Message] = []
        for row in self._cu.fetchall():
            # the DTM converter returns a datetime; _msgs keys are ISO strings
            ts: DtmStrT = row[0].isoformat(timespec="microseconds")
            if ts in self._msgs:
                lst.append(self._msgs[ts])
            else:  # happens in tests with artificial msg from heat
                _LOGGER.debug("MessageIndex timestamp %s not in device messages", ts)
        return tuple(lst)

    def get_rp_codes(self, parameters: tuple[str, ...]) -> list[Code]:
        """
        Get the Codes of all RP messages in the index for a src/dst pair.

        :param parameters: 2-tuple bound to the ``src = ?`` and ``dst = ?``
            placeholders (a row matches if either is equal)
        :return: list of Codes, one per matching row
        :raises LookupError: if a stored code is not in CODES_SCHEMA
        """

        def get_code(code: str) -> Code:
            for cd in CODES_SCHEMA:
                if code == cd:
                    return cd
            raise LookupError(f"Failed to find matching code for {code}")

        self._cu.execute(
            "SELECT code FROM messages WHERE verb = 'RP' AND (src = ? OR dst = ?)",
            parameters,
        )
        return [get_code(row[0]) for row in self._cu.fetchall()]

    def qry_field(
        self, sql: str, parameters: tuple[str, ...]
    ) -> list[tuple[dt | str, str]]:
        """
        Get a list of fields from the index, given select sql and parameters.

        :param sql: a bespoke SQL query SELECT string
        :param parameters: tuple of additional kwargs
        :return: list of key: value pairs as defined in sql
        :raises ValueError: if sql does not contain "SELECT"
        """

        if "SELECT" not in sql:
            raise ValueError(f"{self}: Only SELECT queries are allowed")

        self._cu.execute(sql, parameters)
        return self._cu.fetchall()

    def all(self, include_expired: bool = False) -> tuple[Message, ...]:
        """Get all messages from the index.

        :param include_expired: currently unused
        """

        self._cu.execute("SELECT * FROM messages")

        lst: list[Message] = []
        for row in self._cu.fetchall():
            ts: DtmStrT = row[0].isoformat(timespec="microseconds")
            if ts in self._msgs:
                # if include_expired or not self._msgs[ts].HAS_EXPIRED: # not working
                lst.append(self._msgs[ts])
                _LOGGER.debug("MessageIndex ts %s added to all.lst", ts)
            else:  # happens in tests and real evohome setups with dummy msg from heat init
                _LOGGER.debug("MessageIndex ts %s not in device messages", ts)
        return tuple(lst)

    def clr(self) -> None:
        """Clear the message index (remove indexes of all messages)."""

        self._cu.execute("DELETE FROM messages")
        self._cx.commit()

        self._msgs.clear()
588 # def _msgs(self, device_id: DeviceIdT) -> tuple[Message, ...]:
589 # msgs = [msg for msg in self._msgs.values() if msg.src.id == device_id]
590 # return msgs