Coverage for nexios\websockets\consumers.py: 24%

101 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1from .base import WebSocket 

2from nexios import status 

3import typing, json 

4from .channels import Channel, ChannelBox, PayloadTypeEnum 

5from nexios import logging 

6import uuid 

7 

8 

# Mapping-shaped websocket event as returned by ``WebSocket.receive()``; the
# consumer reads keys such as "type", "text", "bytes" and "code" from it.
Message = typing.MutableMapping[str, typing.Any]

10 

11 

class WebSocketConsumer:
    """Class-based websocket endpoint.

    Subclasses override the ``on_connect`` / ``on_receive`` / ``on_disconnect``
    hooks and may set ``encoding`` to control how incoming frames are decoded
    ("text", "bytes", "json", or ``None`` to pass the raw payload through).
    Register a consumer with :meth:`as_route`; the generated handler builds a
    fresh instance for every connection.
    """

    # Channel wrapping this connection's websocket; assigned in ``__call__``.
    channel: typing.Optional[Channel] = None
    # Middlewares handed (as a copy) to the route built by ``as_route``.
    middleware: typing.List[typing.Any] = []

    # One of "text", "bytes", "json" or None; see ``decode``.
    encoding: typing.Optional[str] = None
    # Value passed as ``expires`` when this connection's Channel is created
    # (described as a TTL by the original code; units presumably seconds).
    channel_expires: int = 3600

    def __init__(
        self,
        logging_enabled: bool = True,
        logger: typing.Optional[logging.Logger] = None,
    ):
        """
        Args:
            logging_enabled: Whether logging is enabled for this endpoint.
            logger: A custom logger instance. If not provided, the "nexios"
                logger is used.
        """
        self.logging_enabled = logging_enabled
        self.logger = logger if logger is not None else logging.getLogger("nexios")

    @classmethod
    def as_route(cls, path: str):
        """
        Convert the WebSocketConsumer class into a Route that can be registered
        with the app or router.

        Args:
            path: The path the route is registered under.

        Returns:
            A ``WebsocketRoutes`` whose handler creates a new consumer instance
            (with default constructor arguments) for each connection and runs it.
        """
        # Local import; presumably avoids an import cycle with nexios.routing.
        from nexios.routing import WebsocketRoutes

        async def handler(websocket: WebSocket, **kwargs: typing.Any) -> None:
            instance = cls()
            await instance(websocket, **kwargs)

        # Copy so the route never aliases the class-level list shared by subclasses.
        return WebsocketRoutes(path, handler, middlewares=list(cls.middleware))

    async def __call__(self, ws: WebSocket, **kwargs: typing.Any) -> None:
        """Drive one websocket connection from connect to disconnect.

        Creates this connection's Channel and calls ``on_connect``. Each
        "websocket.receive" event is then passed through ``decode`` and
        ``on_receive`` until a "websocket.disconnect" event arrives. Once the
        receive loop exits, ``on_disconnect`` is always called. The close code
        passed to it is taken from the disconnect event
        (``WS_1000_NORMAL_CLOSURE`` if absent), or is ``WS_1011_INTERNAL_ERROR``
        if an exception escaped the loop.

        Args:
            ws: The websocket for this connection.
            **kwargs: Extra keyword arguments forwarded by the route handler
                (presumably path parameters — confirm against WebsocketRoutes);
                kept on ``self.route_kwargs`` for subclasses.

        Raises:
            Any exception raised while receiving, decoding or handling a
            message is re-raised after ``on_disconnect`` has run.
        """
        self.websocket = ws
        self.route_kwargs = kwargs

        self.channel = Channel(
            websocket=self.websocket,
            expires=self.channel_expires,
            payload_type=(
                PayloadTypeEnum.JSON.value
                if self.encoding == "json"
                else PayloadTypeEnum.TEXT.value
            ),
        )
        await self.on_connect(self.websocket)

        close_code = status.WS_1000_NORMAL_CLOSURE

        try:
            while True:
                message = await self.websocket.receive()
                if message["type"] == "websocket.receive":
                    data = await self.decode(self.websocket, message)
                    await self.on_receive(self.websocket, data)
                elif message["type"] == "websocket.disconnect":
                    close_code = int(
                        message.get("code") or status.WS_1000_NORMAL_CLOSURE
                    )
                    break
        except Exception:
            close_code = status.WS_1011_INTERNAL_ERROR
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            await self.on_disconnect(self.websocket, close_code)

    async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
        """Extract the payload of a receive event according to ``self.encoding``.

        A missing key and a ``None`` value are treated the same way for both
        "text" and "bytes".

        Args:
            websocket: Connection to close (with ``WS_1003_UNSUPPORTED_DATA``)
                when the payload does not match the configured encoding.
            message: The receive event.

        Returns:
            "text": the text payload. "bytes": the bytes payload. "json": the
            parsed JSON value (bytes payloads are UTF-8 decoded first). None:
            the text payload if present (including ""), else the bytes payload.

        Raises:
            RuntimeError: If the payload does not match the encoding or is
                malformed; the websocket is closed first.
            AssertionError: If ``self.encoding`` is not a supported value.
        """
        text = message.get("text")
        raw = message.get("bytes")

        if self.encoding == "text":
            if text is None:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected text websocket messages, but got bytes")
            return text

        if self.encoding == "bytes":
            if raw is None:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected bytes websocket messages, but got text")
            return raw

        if self.encoding == "json":
            if text is None:
                if raw is None:
                    await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                    raise RuntimeError(
                        "Expected text or bytes websocket message, but got neither"
                    )
                try:
                    text = raw.decode("utf-8")
                except UnicodeDecodeError as exc:
                    await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                    raise RuntimeError("Malformed JSON data received.") from exc
            try:
                return json.loads(text)
            except json.JSONDecodeError as exc:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Malformed JSON data received.") from exc

        # Explicit check (not ``assert``) so it still fires under ``python -O``.
        if self.encoding is not None:
            raise AssertionError(f"Unsupported 'encoding' attribute {self.encoding}")
        # ``is not None`` so an empty text frame ("") is returned as-is.
        return text if text is not None else raw

    async def on_connect(self, websocket: WebSocket) -> None:
        """Override to handle an incoming websocket connection.

        The default implementation accepts the websocket and logs the event.
        """
        await websocket.accept()
        if self.logging_enabled:
            self.logger.info("New WebSocket connection established")

    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
        """Override to handle an incoming websocket message.

        ``data`` is the value produced by ``decode``; the default only logs it.
        """
        if self.logging_enabled:
            # Lazy %-formatting: nothing is formatted when INFO is disabled.
            self.logger.info("Received message: %s", data)

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        """Override to handle a disconnecting websocket.

        Called by ``__call__`` once its receive loop ends; the default only
        logs the close code.
        """
        if self.logging_enabled:
            self.logger.info("WebSocket disconnected with code: %s", close_code)

    # Channel and group management helpers (thin wrappers over ChannelBox).
    async def broadcast(
        self,
        payload: typing.Any,
        group_name: str = "default",
        save_history: bool = False,
    ) -> None:
        """
        Broadcast a message to all channels in a group via ``ChannelBox.group_send``.

        Args:
            payload: The message payload to broadcast.
            group_name: The name of the group to broadcast to.
            save_history: Whether to save the message in the group's history.
        """
        await ChannelBox.group_send(
            group_name=group_name, payload=payload, save_history=save_history
        )
        if self.logging_enabled:
            self.logger.info(
                "Broadcasted message to group '%s': %s", group_name, payload
            )

    async def send_to(self, channel_id: uuid.UUID, payload: typing.Any) -> None:
        """
        Send a message to a specific channel by its ID.

        Searches every group in ``ChannelBox.CHANNEL_GROUPS`` and sends to the
        first channel whose ``uuid`` equals ``channel_id``. Logs a warning
        (when logging is enabled) if no channel matches.

        Args:
            channel_id: The UUID of the target channel.
            payload: The message payload to send.
        """
        for channels in ChannelBox.CHANNEL_GROUPS.values():
            for channel in channels:
                if channel.uuid == channel_id:
                    await channel._send(payload)
                    if self.logging_enabled:
                        self.logger.info(
                            "Sent message to channel %s: %s", channel_id, payload
                        )
                    return
        if self.logging_enabled:
            self.logger.warning("Channel with ID %s not found.", channel_id)

    async def group(self, group_name: str) -> typing.List[Channel]:
        """
        Get all channels in a specific group.

        Args:
            group_name: The name of the group.

        Returns:
            A list of the channels in the group (empty if the group is unknown).
        """
        channels = list(ChannelBox.CHANNEL_GROUPS.get(group_name, {}).keys())
        if self.logging_enabled:
            self.logger.info(
                "Retrieved channels in group '%s': %s channels",
                group_name,
                len(channels),
            )
        return channels

    async def join_group(self, group_name: str) -> None:
        """
        Add the current channel to a group.

        Does nothing (apart from a warning when logging is enabled) if this
        consumer has no channel yet, i.e. before ``__call__`` has run.

        Args:
            group_name: The name of the group to join.
        """
        if not self.channel:
            if self.logging_enabled:
                self.logger.warning(
                    "Cannot join group '%s': consumer has no channel", group_name
                )
            return
        await ChannelBox.add_channel_to_group(self.channel, group_name=group_name)
        if self.logging_enabled:
            self.logger.info("Channel joined group '%s'", group_name)

    async def leave_group(self, group_name: str) -> None:
        """
        Remove the current channel from a group.

        Does nothing (apart from a warning when logging is enabled) if this
        consumer has no channel yet, i.e. before ``__call__`` has run.

        Args:
            group_name: The name of the group to leave.
        """
        if not self.channel:
            if self.logging_enabled:
                self.logger.warning(
                    "Cannot leave group '%s': consumer has no channel", group_name
                )
            return
        await ChannelBox.remove_channel_from_group(
            self.channel, group_name=group_name
        )
        if self.logging_enabled:
            self.logger.info("Channel left group '%s'", group_name)