Coverage for nexios\websockets\channels.py: 31%

118 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1import sys 

2import os 

3import time 

4import uuid 

5from nexios import logging as nexios_logger 

6from .utils import ( 

7 ChannelAddStatusEnum, 

8 ChannelRemoveStatusEnum, 

9 GroupSendStatusEnum, 

10 PayloadTypeEnum, 

11 ChannelMessageDC, 

12) 

13import typing 

14from nexios.websockets import WebSocket 

15 

16logging = nexios_logger.getLogger("nexios") 

17 

18 

class Channel:
    """A single WebSocket connection that can be registered in ``ChannelBox``.

    Each channel wraps one websocket, remembers how payloads must be sent
    (json / text / bytes) and carries an optional time-to-live that is
    measured from the last send attempt.
    """

    def __init__(
        self,
        websocket: WebSocket,
        payload_type: str,
        expires: typing.Optional[int] = None,
    ) -> None:
        """Create a channel around ``websocket``.

        Args:
            websocket (WebSocket): The nexios websocket this channel sends on.
            payload_type (str): How payloads are sent; must be one of the
                ``PayloadTypeEnum`` values for JSON, TEXT or BYTES.
            expires (int, optional): Channel TTL in seconds. ``None`` (the
                default) or ``0`` means the channel never expires.

        Raises:
            AssertionError: If any argument fails validation.
        """
        assert isinstance(websocket, WebSocket)
        # ``expires`` defaults to None, so None has to pass validation;
        # otherwise the documented default could never be used.
        assert expires is None or isinstance(expires, int)
        assert isinstance(payload_type, str) and payload_type in [
            PayloadTypeEnum.JSON.value,
            PayloadTypeEnum.TEXT.value,
            PayloadTypeEnum.BYTES.value,
        ]

        self.websocket = websocket
        self.expires = expires
        self.payload_type = payload_type
        self.uuid = uuid.uuid4()  # unique identifier of this channel
        self.created = time.time()  # reference point for the TTL check

    async def _send(self, payload: typing.Any) -> None:
        """Send ``payload`` over the websocket using the configured type.

        A ``RuntimeError`` raised by the websocket is logged at debug level
        and suppressed (presumably raised when the connection is already
        closed -- confirm against the WebSocket implementation).

        The ``created`` timestamp is refreshed after every send attempt,
        successful or not, which restarts the channel's TTL window.
        """
        try:
            if self.payload_type == "json":
                await self.websocket.send_json(payload)
            elif self.payload_type == "text":
                await self.websocket.send_text(payload)
            elif self.payload_type == "bytes":
                await self.websocket.send_bytes(payload)
            else:
                # Only reachable if ``payload_type`` was changed after
                # construction, since __init__ validates it.
                await self.websocket.send(payload)
        except RuntimeError as error:
            logging.debug(error)

        self.created = time.time()

    async def _is_expired(self) -> bool:
        """Return True when the TTL has elapsed since the last send attempt.

        Channels without a TTL (``expires`` is ``None`` or ``0``) never
        expire.
        """
        if not self.expires:
            return False
        return (self.expires + int(self.created)) < time.time()

    def __repr__(self) -> str:
        """Return a debug representation with uuid, payload type and TTL."""
        return f"{self.__class__.__name__} {self.uuid=} {self.payload_type=} {self.expires=}"

71 

72 

class ChannelBox:
    """Class-level registry of channel groups and their message history.

    All state lives in class attributes and every operation is a
    classmethod, so the registry is shared by everything that uses this
    class within the process.
    """

    # group_name -> dict whose keys are the Channel instances in the group
    # (the dict values are unused placeholders).
    CHANNEL_GROUPS: typing.Dict[str, typing.Any] = {}
    # group_name -> list of ChannelMessageDC history entries.
    CHANNEL_GROUPS_HISTORY: typing.Dict[str, typing.Any] = {}
    # Threshold compared against sys.getsizeof() of a group's history list;
    # overridable through the CHANNEL_BOX_HISTORY_SIZE environment variable.
    HISTORY_SIZE: int = int(os.getenv("CHANNEL_BOX_HISTORY_SIZE", 1_048_576))

    @classmethod
    async def add_channel_to_group(
        cls,
        channel: Channel,
        group_name: str = "default",
    ) -> ChannelAddStatusEnum:
        """Add ``channel`` to ``group_name``, creating the group if needed.

        Args:
            channel (Channel): Instance of Channel class.
            group_name (str): Group name; must be non-empty.

        Returns:
            ChannelAddStatusEnum: ``CHANNEL_EXIST`` when the channel was
            already a member of the group, ``CHANNEL_ADDED`` otherwise.

        Raises:
            AssertionError: If ``group_name`` is empty.
        """
        assert group_name, "Group name must to be set."

        group = cls.CHANNEL_GROUPS.setdefault(group_name, {})
        # Report on the channel itself rather than on whether the group
        # already existed, so new members of an existing group are "added".
        if channel in group:
            return ChannelAddStatusEnum.CHANNEL_EXIST

        group[channel] = ...
        return ChannelAddStatusEnum.CHANNEL_ADDED

    @classmethod
    async def remove_channel_from_group(
        cls,
        channel: Channel,
        group_name: str,
    ) -> ChannelRemoveStatusEnum:
        """Remove ``channel`` from ``group_name`` and drop the group if empty.

        Expired channels in every group are purged afterwards.

        Args:
            channel (Channel): Instance of Channel class.
            group_name (str): Group name.

        Returns:
            ChannelRemoveStatusEnum: ``GROUP_DOES_NOT_EXIST`` when the group
            is unknown; ``GROUP_REMOVED`` when the group ended up empty and
            was deleted; ``CHANNEL_REMOVED`` when the channel was removed and
            other members remain; ``CHANNEL_DOES_NOT_EXIST`` when the channel
            was not a member of a non-empty group.
        """
        group = cls.CHANNEL_GROUPS.get(group_name)
        if group is None:
            channel_remove_status = ChannelRemoveStatusEnum.GROUP_DOES_NOT_EXIST
        else:
            if channel in group:
                del group[channel]
                channel_remove_status = ChannelRemoveStatusEnum.CHANNEL_REMOVED
            else:
                # Previously this path fell through and returned None.
                channel_remove_status = ChannelRemoveStatusEnum.CHANNEL_DOES_NOT_EXIST

            if not group:
                # Do not keep empty groups around.
                cls.CHANNEL_GROUPS.pop(group_name, None)
                channel_remove_status = ChannelRemoveStatusEnum.GROUP_REMOVED

        await cls._clean_expired()
        return channel_remove_status

    @classmethod
    async def group_send(
        cls,
        group_name: str = "default",
        payload: typing.Union[typing.Dict[str, typing.Any], str, bytes] = {},
        save_history: bool = False,
    ) -> GroupSendStatusEnum:
        """Send ``payload`` to every channel currently in ``group_name``.

        Args:
            group_name (str, optional): Group name; must be non-empty.
            payload (dict | str | bytes, optional): Payload to send. The
                default dict is shared between calls; it is never mutated
                here.
            save_history (bool, optional): Also append the payload to the
                group's history. Defaults to False.

        Returns:
            GroupSendStatusEnum: ``GROUP_SEND`` if at least one channel was
            sent to, otherwise ``NO_SUCH_GROUP``.

        Raises:
            AssertionError: If ``group_name`` is empty.
        """
        assert group_name, "Group name must to be set."

        if save_history:
            history = cls.CHANNEL_GROUPS_HISTORY.setdefault(group_name, [])
            history.append(
                ChannelMessageDC(
                    payload=payload,  # type:ignore
                )
            )
            # NOTE: sys.getsizeof() is shallow -- it measures the list object
            # only, not the stored payloads. Once the threshold is exceeded
            # the whole history of the group is discarded.
            if sys.getsizeof(history) > cls.HISTORY_SIZE:
                cls.CHANNEL_GROUPS_HISTORY[group_name] = []

        group_send_status = GroupSendStatusEnum.NO_SUCH_GROUP
        # Iterate over a snapshot: each send awaits, and another task may
        # add or remove channels meanwhile, which would otherwise raise
        # "dictionary changed size during iteration".
        for channel in list(cls.CHANNEL_GROUPS.get(group_name, {})):
            await channel._send(payload)
            group_send_status = GroupSendStatusEnum.GROUP_SEND

        return group_send_status

    @classmethod
    async def show_groups(cls) -> typing.Dict[str, typing.Any]:
        """Return the live group registry (not a copy)."""
        return cls.CHANNEL_GROUPS

    @classmethod
    async def flush_groups(cls) -> None:
        """Forget every group without closing the underlying websockets."""
        cls.CHANNEL_GROUPS = {}

    @classmethod
    async def show_history(
        cls,
        group_name: str = "",
    ) -> typing.Dict[str, typing.Any]:
        """Return saved history for one group, or for all groups.

        Args:
            group_name (str, optional): Group to look up. When empty, the
                whole history mapping is returned.

        Returns:
            The stored history entries of ``group_name`` (an empty dict if
            the group has no history), or the full mapping of group name to
            history when no group name is given.
        """
        return (
            cls.CHANNEL_GROUPS_HISTORY.get(group_name, {})
            if group_name
            else cls.CHANNEL_GROUPS_HISTORY
        )

    @classmethod
    async def flush_history(cls) -> None:
        """Discard the saved history of every group."""
        cls.CHANNEL_GROUPS_HISTORY = {}

    @classmethod
    async def _clean_expired(cls) -> None:
        """Remove expired channels from all groups and drop empty groups."""
        for group_name in list(cls.CHANNEL_GROUPS):
            # Snapshot the members: deleting from the dict being iterated
            # raises "dictionary changed size during iteration".
            for channel in list(cls.CHANNEL_GROUPS.get(group_name, {})):
                if await channel._is_expired():
                    try:
                        del cls.CHANNEL_GROUPS[group_name][channel]
                    except KeyError:
                        # The channel or its group went away in the meantime.
                        logging.debug("No such channel")

            if not cls.CHANNEL_GROUPS.get(group_name, {}):
                try:
                    del cls.CHANNEL_GROUPS[group_name]
                except KeyError:
                    logging.debug("No such group")

    @classmethod
    async def close_all_connections(cls) -> None:
        """Close all WebSocket connections in all groups, then clear groups.

        A failure to close one connection is logged and does not stop the
        remaining connections from being closed.
        """
        # Snapshot both levels: closing awaits, and the registry may be
        # modified by other tasks while this coroutine is suspended.
        for group_name, channels in list(cls.CHANNEL_GROUPS.items()):
            for channel in list(channels):
                try:
                    await channel.websocket.close()
                    logging.debug(
                        f"Closed connection for channel {channel.uuid} in group {group_name}"
                    )
                except Exception as e:
                    # Best effort: keep closing the other connections.
                    logging.error(
                        f"Failed to close connection for channel {channel.uuid}: {e}"
                    )

        cls.CHANNEL_GROUPS = {}
        logging.debug("All connections closed and groups cleared.")