Coverage for nexios\websockets\consumers.py: 24%
101 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1from .base import WebSocket
2from nexios import status
3import typing, json
4from .channels import Channel, ChannelBox, PayloadTypeEnum
5from nexios import logging
6import uuid
# Raw WebSocket event mapping as returned by ``WebSocket.receive()``.
# WebSocketConsumer reads its "type", "text", "bytes" and "code" keys.
Message = typing.MutableMapping[str, typing.Any]
class WebSocketConsumer:
    """Class-based WebSocket endpoint.

    Subclass and override ``on_connect``, ``on_receive`` and ``on_disconnect``
    to implement an endpoint, then register it with ``as_route``. Incoming
    frames are decoded according to ``encoding`` (see ``decode``) before they
    reach ``on_receive``. Group helpers (``join_group``, ``broadcast``, ...)
    delegate to ``ChannelBox``.
    """

    # Channel wrapping the current connection; created in ``__call__``.
    channel: typing.Optional[Channel] = None
    # Route middleware handed to ``WebsocketRoutes`` by ``as_route``. This list
    # lives on the class: subclasses should reassign it rather than mutate it in
    # place, otherwise the change leaks into every other consumer.
    middleware: typing.ClassVar[typing.List[typing.Any]] = []
    # One of "text", "bytes", "json", or None for raw passthrough (see ``decode``).
    encoding: typing.Optional[str] = None

    def __init__(
        self,
        logging_enabled: bool = True,
        logger: typing.Optional[logging.Logger] = None,
    ):
        """
        Args:
            logging_enabled: Whether logging is enabled for this endpoint.
            logger: A custom logger instance. If not provided, the "nexios" logger is used.
        """
        self.logging_enabled = logging_enabled
        self.logger = logger if logger else logging.getLogger("nexios")

    @classmethod
    def as_route(cls, path: str):
        """
        Convert the WebSocketConsumer class into a Route that can be registered with the app or router.

        A fresh consumer instance is created for every connection, so state kept
        on ``self`` is never shared between clients.

        Args:
            path: The URL path the route is mounted at.

        Returns:
            A ``WebsocketRoutes`` instance whose handler runs this consumer.
        """
        # Local import, presumably to avoid a circular import with nexios.routing.
        from nexios.routing import WebsocketRoutes

        async def handler(websocket: WebSocket, **kwargs) -> None:
            instance = cls()
            await instance(websocket, **kwargs)

        # Pass a copy so the route never aliases the shared class-level list.
        return WebsocketRoutes(path, handler, middlewares=list(cls.middleware))

    async def __call__(self, ws: WebSocket, **kwargs: typing.Any) -> None:
        """
        Drive a single WebSocket connection from connect to disconnect.

        Args:
            ws: The incoming WebSocket connection.
            **kwargs: Extra arguments forwarded by the ``as_route`` handler.
                They are accepted so that call cannot fail with a TypeError;
                they are not used here.

        Raises:
            Any exception raised while receiving, decoding or handling a message.
            ``on_disconnect`` still runs, with close code 1011.
        """
        self.websocket = ws

        self.channel = Channel(
            websocket=self.websocket,
            expires=3600,  # Set your desired TTL for the channel
            payload_type=(
                PayloadTypeEnum.JSON.value
                if self.encoding == "json"
                else PayloadTypeEnum.TEXT.value
            ),
        )
        await self.on_connect(self.websocket)

        close_code = status.WS_1000_NORMAL_CLOSURE

        try:
            while True:
                message = await self.websocket.receive()
                if message["type"] == "websocket.receive":
                    data = await self.decode(self.websocket, message)
                    await self.on_receive(self.websocket, data)
                elif message["type"] == "websocket.disconnect":
                    close_code = int(
                        message.get("code") or status.WS_1000_NORMAL_CLOSURE
                    )
                    break
        except Exception:
            close_code = status.WS_1011_INTERNAL_ERROR
            raise  # bare raise re-raises the active exception unchanged
        finally:
            await self.on_disconnect(self.websocket, close_code)

    async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
        """
        Extract the payload of a received message according to ``encoding``.

        A message may contain a "text" or "bytes" key whose value is None, so
        the payload is looked up with ``is None`` checks rather than key
        membership.

        Args:
            websocket: The connection; closed with 1003 when the payload is rejected.
            message: The received event mapping.

        Returns:
            The text (``"text"``), the raw bytes (``"bytes"``), the parsed JSON
            value (``"json"``), or whichever of text/bytes is present (None).

        Raises:
            RuntimeError: The frame type does not match ``encoding``, or the JSON
                payload is malformed or not valid UTF-8.
            AssertionError: ``encoding`` is not a supported value.
        """
        if self.encoding == "text":
            if message.get("text") is None:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected text websocket messages, but got bytes")
            return message["text"]

        if self.encoding == "bytes":
            if message.get("bytes") is None:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected bytes websocket messages, but got text")
            return message["bytes"]

        if self.encoding == "json":
            try:
                text = message.get("text")
                if text is None:
                    text = message["bytes"].decode("utf-8")
                return json.loads(text)
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Malformed JSON data received.") from exc

        assert (
            self.encoding is None
        ), f"Unsupported 'encoding' attribute {self.encoding}"
        # Compare against None so an empty text frame ("") is returned as-is
        # instead of falling through to a missing "bytes" key.
        text = message.get("text")
        return text if text is not None else message["bytes"]

    async def on_connect(self, websocket: WebSocket) -> None:
        """Override to handle an incoming websocket connection. The default accepts it."""
        await websocket.accept()
        if self.logging_enabled:
            self.logger.info("New WebSocket connection established")

    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
        """Override to handle an incoming websocket message (already decoded)."""
        if self.logging_enabled:
            self.logger.info(f"Received message: {data}")

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        """Override to handle a disconnecting websocket."""
        if self.logging_enabled:
            self.logger.info(f"WebSocket disconnected with code: {close_code}")

    # Channel and group management
    async def broadcast(
        self,
        payload: typing.Any,
        group_name: str = "default",
        save_history: bool = False,
    ) -> None:
        """
        Broadcast a message to all channels in a group.

        Args:
            payload: The message payload to broadcast.
            group_name: The name of the group to broadcast to.
            save_history: Whether to save the message in the group's history.
        """
        await ChannelBox.group_send(
            group_name=group_name, payload=payload, save_history=save_history
        )
        if self.logging_enabled:
            self.logger.info(f"Broadcasted message to group '{group_name}': {payload}")

    async def send_to(self, channel_id: uuid.UUID, payload: typing.Any) -> None:
        """
        Send a message to a specific channel by its ID.

        Every group is searched; the message goes to the first matching channel
        only. A warning is logged if no channel matches.

        Args:
            channel_id: The UUID of the target channel.
            payload: The message payload to send.
        """
        for channels in ChannelBox.CHANNEL_GROUPS.values():
            for channel in channels:
                if channel.uuid == channel_id:
                    await channel._send(payload)
                    if self.logging_enabled:
                        self.logger.info(
                            f"Sent message to channel {channel_id}: {payload}"
                        )
                    return
        if self.logging_enabled:
            self.logger.warning(f"Channel with ID {channel_id} not found.")

    async def group(self, group_name: str) -> typing.List[Channel]:
        """
        Get all channels in a specific group.

        Args:
            group_name: The name of the group.

        Returns:
            A list of channels in the group (empty if the group does not exist).
        """
        channels = list(ChannelBox.CHANNEL_GROUPS.get(group_name, {}).keys())
        if self.logging_enabled:
            self.logger.info(
                f"Retrieved channels in group '{group_name}': {len(channels)} channels"
            )
        return channels

    async def join_group(self, group_name: str) -> None:
        """
        Add the current channel to a group.

        This does nothing (apart from a warning) if no channel exists yet, i.e.
        before ``__call__`` has run.

        Args:
            group_name: The name of the group to join.
        """
        if not self.channel:
            if self.logging_enabled:
                self.logger.warning(
                    f"Cannot join group '{group_name}': no active channel"
                )
            return
        await ChannelBox.add_channel_to_group(self.channel, group_name=group_name)
        if self.logging_enabled:
            self.logger.info(f"Channel joined group '{group_name}'")

    async def leave_group(self, group_name: str) -> None:
        """
        Remove the current channel from a group.

        This does nothing (apart from a warning) if no channel exists yet.

        Args:
            group_name: The name of the group to leave.
        """
        if not self.channel:
            if self.logging_enabled:
                self.logger.warning(
                    f"Cannot leave group '{group_name}': no active channel"
                )
            return
        await ChannelBox.remove_channel_from_group(
            self.channel, group_name=group_name
        )
        if self.logging_enabled:
            self.logger.info(f"Channel left group '{group_name}'")