Coverage for nexios\websockets\channels.py: 31%
118 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1import sys
2import os
3import time
4import uuid
5from nexios import logging as nexios_logger
6from .utils import (
7 ChannelAddStatusEnum,
8 ChannelRemoveStatusEnum,
9 GroupSendStatusEnum,
10 PayloadTypeEnum,
11 ChannelMessageDC,
12)
13import typing
14from nexios.websockets import WebSocket
# Module-level logger obtained from ``nexios.logging``. NOTE: despite the name,
# this is a logger object, not the stdlib ``logging`` module; every log call in
# this file (``logging.debug`` / ``logging.error``) goes through it.
logging = nexios_logger.getLogger("nexios")
class Channel:
    """A single websocket connection wrapped for group broadcasting.

    A channel remembers how payloads are sent over its socket
    (``payload_type``) and an optional time-to-live (``expires``).
    ``created`` is refreshed after every send attempt, so the TTL measures
    time since the last send rather than total age.
    """

    def __init__(
        self,
        websocket: WebSocket,
        payload_type: str,
        expires: typing.Optional[int] = None,
    ) -> None:
        """Create a channel around *websocket*.

        Args:
            websocket (WebSocket): nexios websocket connection to send through.
            payload_type (str): How payloads are sent; must be one of the
                ``PayloadTypeEnum`` values (JSON, TEXT or BYTES).
            expires (int, optional): Channel TTL in seconds. ``None`` (the
                default) or ``0`` means the channel never expires.

        Sets:
            uuid (uuid.UUID): random identifier of this channel.
            created (float): epoch time of creation / of the last send.

        Raises:
            AssertionError: on a wrong argument type or an unsupported
                ``payload_type``. Validation uses ``assert``, so it is skipped
                when Python runs with ``-O``.
        """
        assert isinstance(websocket, WebSocket)
        # ``expires`` is Optional: the default ``None`` must be accepted.
        # The previous ``isinstance(expires, int)`` check rejected it.
        assert expires is None or isinstance(expires, int)
        assert isinstance(payload_type, str) and payload_type in (
            PayloadTypeEnum.JSON.value,
            PayloadTypeEnum.TEXT.value,
            PayloadTypeEnum.BYTES.value,
        )

        self.websocket = websocket
        self.expires = expires
        self.payload_type = payload_type
        self.uuid = uuid.uuid4()
        self.created = time.time()

    async def _send(self, payload: typing.Any) -> None:
        """Send *payload* with the websocket method matching ``payload_type``.

        A ``RuntimeError`` raised by the websocket is logged at debug level
        and swallowed, so the caller continues. ``created`` is then refreshed,
        which restarts the TTL window. Any other exception propagates and
        skips that refresh.
        """
        try:
            if self.payload_type == "json":
                await self.websocket.send_json(payload)
            elif self.payload_type == "text":
                await self.websocket.send_text(payload)
            elif self.payload_type == "bytes":
                await self.websocket.send_bytes(payload)
            else:
                # Fallback for any other payload_type. Presumably unreachable
                # while __init__'s assert is active and the PayloadTypeEnum
                # values are "json"/"text"/"bytes" -- TODO confirm in .utils.
                await self.websocket.send(payload)
        except RuntimeError as error:
            logging.debug(error)

        self.created = time.time()

    async def _is_expired(self) -> bool:
        """Return True once more than ``expires`` seconds have passed since ``created``.

        A falsy ``expires`` (``None`` or ``0``) disables expiry.
        """
        if not self.expires:
            return False
        # Use the exact timestamp. Truncating ``created`` with int() made
        # channels expire up to one second early.
        return (self.created + self.expires) < time.time()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} {self.uuid=} {self.payload_type=} {self.expires=}"
class ChannelBox:
    """Process-wide registry of channel groups, with optional message history.

    All state lives in class attributes, so every caller in the process
    shares the same groups and history. Each group is a dict used as an
    ordered set of ``Channel`` objects; the values are the ``...`` placeholder.
    """

    # groups of channels ~ key: group_name, val: dict of channels (used as a set)
    CHANNEL_GROUPS: typing.Dict[str, typing.Any] = {}
    # history messages ~ key: group_name, val: list of ChannelMessageDC
    CHANNEL_GROUPS_HISTORY: typing.Dict[str, typing.Any] = {}
    # Overflow threshold for one group's history list (see group_send for how
    # it is measured); overridable via the CHANNEL_BOX_HISTORY_SIZE env var.
    HISTORY_SIZE: int = int(os.getenv("CHANNEL_BOX_HISTORY_SIZE", 1_048_576))

    @classmethod
    async def add_channel_to_group(
        cls,
        channel: Channel,
        group_name: str = "default",
    ) -> ChannelAddStatusEnum:
        """Add *channel* to *group_name*, creating the group if needed.

        Args:
            channel (Channel): Instance of Channel class.
            group_name (str): Group name; must be non-empty.

        Returns:
            ChannelAddStatusEnum: ``CHANNEL_EXIST`` if the channel was already
            a member of the group, otherwise ``CHANNEL_ADDED``. Previously
            ``CHANNEL_EXIST`` was returned whenever the *group* existed, even
            for a newly added channel.
        """
        assert group_name, "Group name must to be set."

        group = cls.CHANNEL_GROUPS.setdefault(group_name, {})
        if channel in group:
            return ChannelAddStatusEnum.CHANNEL_EXIST
        group[channel] = ...
        return ChannelAddStatusEnum.CHANNEL_ADDED

    @classmethod
    async def remove_channel_from_group(
        cls,
        channel: Channel,
        group_name: str,
    ) -> ChannelRemoveStatusEnum:
        """Remove *channel* from *group_name*, deleting the group if it ends up empty.

        Afterwards, expired channels are purged from all groups.

        Args:
            channel (Channel): Instance of Channel class.
            group_name (str): Group name.

        Returns:
            ChannelRemoveStatusEnum:
                ``GROUP_REMOVED`` if the group became empty and was deleted;
                ``GROUP_DOES_NOT_EXIST`` if there was no such group;
                ``CHANNEL_REMOVED`` if the channel was removed and the group
                still has members;
                ``CHANNEL_DOES_NOT_EXIST`` if the group has members but not
                this channel. This case previously returned ``None``.
        """
        group = cls.CHANNEL_GROUPS.get(group_name, {})
        if channel in group:
            del group[channel]
            channel_remove_status = ChannelRemoveStatusEnum.CHANNEL_REMOVED
        else:
            channel_remove_status = ChannelRemoveStatusEnum.CHANNEL_DOES_NOT_EXIST

        if not cls.CHANNEL_GROUPS.get(group_name):
            if group_name in cls.CHANNEL_GROUPS:
                del cls.CHANNEL_GROUPS[group_name]
                channel_remove_status = ChannelRemoveStatusEnum.GROUP_REMOVED
            else:
                channel_remove_status = ChannelRemoveStatusEnum.GROUP_DOES_NOT_EXIST

        await cls._clean_expired()
        return channel_remove_status

    @classmethod
    async def group_send(
        cls,
        group_name: str = "default",
        payload: typing.Optional[
            typing.Union[typing.Dict[str, typing.Any], str, bytes]
        ] = None,
        save_history: bool = False,
    ) -> GroupSendStatusEnum:
        """Send *payload* to every channel in *group_name*.

        Args:
            group_name (str, optional): Group name; must be non-empty.
            payload (dict | str | bytes, optional): Payload to send. ``None``
                (the default) sends a fresh empty dict.
            save_history (bool, optional): Append the payload to the group's
                history. Defaults to False.

        Returns:
            GroupSendStatusEnum: ``GROUP_SEND`` if at least one channel was
            sent to, otherwise ``NO_SUCH_GROUP``.
        """
        assert group_name, "Group name must to be set."
        if payload is None:
            # A fresh dict per call. The former ``payload={}`` default was one
            # shared object, stored by reference in every history entry.
            payload = {}

        if save_history:
            history = cls.CHANNEL_GROUPS_HISTORY.setdefault(group_name, [])
            history.append(
                ChannelMessageDC(
                    payload=payload,  # type:ignore
                )
            )
            # NOTE: sys.getsizeof is shallow. It measures the list object
            # itself (its pointer array), not the stored payloads, so
            # HISTORY_SIZE effectively caps the entry count rather than bytes.
            # On overflow the group's whole history is discarded.
            if sys.getsizeof(history) > cls.HISTORY_SIZE:
                cls.CHANNEL_GROUPS_HISTORY[group_name] = []

        group_send_status = GroupSendStatusEnum.NO_SUCH_GROUP
        # Iterate a snapshot: each _send awaits, and other tasks may add or
        # remove channels meanwhile. Iterating the live dict could raise
        # "dictionary changed size during iteration".
        for channel in list(cls.CHANNEL_GROUPS.get(group_name, {})):
            await channel._send(payload)
            group_send_status = GroupSendStatusEnum.GROUP_SEND

        return group_send_status

    @classmethod
    async def show_groups(cls) -> typing.Dict[str, typing.Any]:
        """Return the live group registry (not a copy)."""
        return cls.CHANNEL_GROUPS

    @classmethod
    async def flush_groups(cls) -> None:
        """Forget all groups by rebinding the registry to a new empty dict."""
        cls.CHANNEL_GROUPS = {}

    @classmethod
    async def show_history(
        cls,
        group_name: str = "",
    ) -> typing.Union[typing.Dict[str, typing.Any], typing.List[typing.Any]]:
        """Return the history of *group_name*, or every group's history if it is empty.

        For a group with history this is its list of ``ChannelMessageDC``
        entries; for an unknown group it is ``{}``. The returned object is live
        state, not a copy.
        """
        return (
            cls.CHANNEL_GROUPS_HISTORY.get(group_name, {})
            if group_name
            else cls.CHANNEL_GROUPS_HISTORY
        )

    @classmethod
    async def flush_history(cls) -> None:
        """Forget all saved history by rebinding it to a new empty dict."""
        cls.CHANNEL_GROUPS_HISTORY = {}

    @classmethod
    async def _clean_expired(cls) -> None:
        """Drop expired channels from every group and delete groups left empty."""
        for group_name in list(cls.CHANNEL_GROUPS):
            # Iterate a snapshot of the group. Deleting from the dict being
            # iterated raised RuntimeError as soon as one channel had expired,
            # and each await lets other tasks mutate the group as well.
            for channel in list(cls.CHANNEL_GROUPS.get(group_name, {})):
                if await channel._is_expired():
                    try:
                        del cls.CHANNEL_GROUPS[group_name][channel]
                    except KeyError:
                        # Already removed (or its group deleted) by another task.
                        logging.debug("No such channel")

            if not cls.CHANNEL_GROUPS.get(group_name):
                try:
                    del cls.CHANNEL_GROUPS[group_name]
                except KeyError:
                    logging.debug("No such group")

    @classmethod
    async def close_all_connections(cls) -> None:
        """
        Close all WebSocket connections in all groups, then clear the registry.

        This is best-effort: a failure to close one channel is logged and does
        not stop the others. A channel that belongs to several groups is
        closed once per group.
        """
        # Snapshot both levels: each close() awaits, and other tasks (e.g.
        # disconnect handlers calling remove_channel_from_group) may delete
        # groups or channels meanwhile.
        for group_name, channels in list(cls.CHANNEL_GROUPS.items()):
            for channel in list(channels):
                try:
                    await channel.websocket.close()
                    logging.debug(
                        f"Closed connection for channel {channel.uuid} in group {group_name}"
                    )
                except Exception as e:
                    logging.error(
                        f"Failed to close connection for channel {channel.uuid}: {e}"
                    )

        cls.CHANNEL_GROUPS = {}
        logging.debug("All connections closed and groups cleared.")