Coverage for nexios\websockets\base.py: 24%
137 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1import enum
2import json
3import typing
4from typing import AsyncIterator, Iterable, Optional
5from nexios.http.request import HTTPConnection
# ASGI typing aliases shared by the websocket classes in this module.

# Both an ASGI event and the connection scope are mutable, string-keyed mappings.
Message = typing.MutableMapping[str, typing.Any]
Scope = typing.MutableMapping[str, typing.Any]

# ``Send`` awaits delivery of one Message; ``Receive`` awaits the next Message.
Send = typing.Callable[[Message], typing.Awaitable[None]]
Receive = typing.Callable[[], typing.Awaitable[Message]]
@enum.unique
class WebSocketState(enum.Enum):
    """Lifecycle states used to track each side of a websocket connection."""

    # Initial state of a freshly created connection.
    CONNECTING = 0
    # The handshake message for this side has been seen / sent.
    CONNECTED = 1
    # A disconnect was received, or a close / final response body was sent.
    DISCONNECTED = 2
    # A "websocket.http.response.start" message was sent instead of an accept.
    RESPONSE = 3
class WebSocketDisconnect(Exception):
    """
    Raised to signal that a websocket connection has been closed.

    Attributes:
        code: The close code passed to the constructor (default ``1000``).
        reason: The close reason; a falsy value is normalised to ``""``.
    """

    def __init__(self, code: int = 1000, reason: Optional[str] = None) -> None:
        self.code = code
        self.reason = reason or ""
        # Forward both values to Exception so ``args``, ``str()`` and ``repr()``
        # carry the close code and reason instead of being empty.
        super().__init__(self.code, self.reason)
class WebSocket(HTTPConnection):
    """
    ASGI websocket connection that enforces a valid ordering of messages.

    Two state values are tracked independently: ``client_state`` advances as
    messages are received through ``receive``, and ``application_state``
    advances as messages are sent through ``send``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Wrap the ASGI ``scope`` / ``receive`` / ``send`` triple.

        The scope's ``"type"`` entry is expected to be ``"websocket"``.
        """
        super().__init__(scope, receive)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Returns:
            The raw ASGI message.

        Raises:
            RuntimeError: If the message type is not allowed in the current
                client state, or a disconnect message was already received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    f'Expected ASGI message "websocket.connect", but got {message_type!r}'
                )
            self.client_state = WebSocketState.CONNECTED
            return message

        if self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message

        raise RuntimeError(
            'Cannot call "receive" once a disconnect message has been received.'
        )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not allowed in the current
                application state, or a close message was already sent.
            WebSocketDisconnect: With code 1006 when the underlying send
                raises ``OSError`` while the application side is connected.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {
                "websocket.accept",
                "websocket.close",
                "websocket.http.response.start",
            }:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            # The state is updated before the message is handed to the server.
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # Translate the transport failure into a disconnect (1006 is
                # the "abnormal closure" code) and keep the original error
                # attached as the explicit cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(
                    f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}'
                )
            # The last body chunk (no "more_body") ends the response.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: Optional[str] = None,
        headers: Optional[Iterable[tuple[bytes, bytes]]] = None,
    ) -> None:
        """
        Accept the connection by sending a ``"websocket.accept"`` message.

        Args:
            subprotocol: Value for the message's ``"subprotocol"`` entry.
            headers: Value for the message's ``"headers"`` entry; a falsy
                value is replaced by an empty list.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """
        Raise ``WebSocketDisconnect`` if ``message`` is a disconnect event.

        A disconnect event without a ``"code"`` entry is reported with 1005,
        the default the ASGI specification gives for a missing close code,
        rather than failing with ``KeyError``.
        """
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    def _ensure_connected(self) -> None:
        """Raise ``RuntimeError`` unless the application state is CONNECTED."""
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )

    @staticmethod
    def _ensure_valid_mode(mode: str) -> None:
        """Raise ``RuntimeError`` unless ``mode`` is ``"text"`` or ``"binary"``."""
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')

    async def receive_text(self) -> str:
        """
        Receive one message and return its ``"text"`` entry.

        Raises:
            RuntimeError: If the application state is not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        self._ensure_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive one message and return its ``"bytes"`` entry.

        Raises:
            RuntimeError: If the application state is not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        self._ensure_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive one message and decode its payload as JSON.

        Args:
            mode: ``"text"`` reads the ``"text"`` entry; ``"binary"`` reads the
                ``"bytes"`` entry and decodes it as UTF-8.

        Raises:
            RuntimeError: If ``mode`` is invalid or the application state is
                not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        # The mode is validated before the connection state, so an invalid
        # mode is reported even on an unaccepted socket.
        self._ensure_valid_mode(mode)
        self._ensure_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")

        return json.loads(text)

    async def iter_text(self) -> AsyncIterator[str]:
        """Yield text payloads until a ``WebSocketDisconnect`` ends the loop."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect is the normal end of iteration.
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """Yield bytes payloads until a ``WebSocketDisconnect`` ends the loop."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect is the normal end of iteration.
            pass

    async def iter_json(self) -> AsyncIterator[typing.Any]:
        """Yield decoded JSON payloads until a ``WebSocketDisconnect`` ends the loop."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect is the normal end of iteration.
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as the ``"text"`` entry of a ``"websocket.send"`` message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as the ``"bytes"`` entry of a ``"websocket.send"`` message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialise ``data`` as compact JSON and send it.

        Args:
            data: Object passed to ``json.dumps``.
            mode: ``"text"`` sends the JSON string; ``"binary"`` sends it
                encoded as UTF-8 bytes.

        Raises:
            RuntimeError: If ``mode`` is neither ``"text"`` nor ``"binary"``.
        """
        self._ensure_valid_mode(mode)
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        """Send a ``"websocket.close"`` message; a falsy ``reason`` becomes ``""``."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )

    def is_connected(self) -> bool:
        """Return ``True`` only while both client and application states are CONNECTED."""
        return (
            self.client_state == WebSocketState.CONNECTED
            and self.application_state == WebSocketState.CONNECTED
        )