Coverage for nexios\websockets\base.py: 24%

137 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1import enum 

2import json 

3import typing 

4from typing import AsyncIterator, Iterable, Optional 

5from nexios.http.request import HTTPConnection 

6 

# Generic ASGI dictionary shape: string keys mapped to arbitrary values.
_AsgiMapping = typing.MutableMapping[str, typing.Any]

# The ASGI connection scope handed to the WebSocket wrapper.
Scope = _AsgiMapping
# A single ASGI event dictionary (received or sent); same shape as ``Scope``.
Message = _AsgiMapping

# Zero-argument callable whose awaitable yields the next incoming Message.
Receive = typing.Callable[[], typing.Awaitable[Message]]
# Callable taking an outgoing Message and returning an awaitable of None.
Send = typing.Callable[[Message], typing.Awaitable[None]]

12 

13 

class WebSocketState(enum.Enum):
    """Lifecycle stage of one side (client or application) of a WebSocket."""

    CONNECTING = 0  # initial stage, before the connect/accept step
    CONNECTED = 1  # handshake done; regular messages may flow
    DISCONNECTED = 2  # a close/disconnect has been sent or received
    RESPONSE = 3  # app replied with an HTTP response instead of accepting

19 

20 

class WebSocketDisconnect(Exception):
    """
    Raised when the WebSocket connection is closed by the peer or lost.

    Attributes:
        code: WebSocket close code (defaults to 1000).
        reason: Close reason text; ``None`` is normalised to ``""``.
    """

    def __init__(self, code: int = 1000, reason: Optional[str] = None) -> None:
        self.code = code
        self.reason = reason or ""

    def __repr__(self) -> str:
        # The inherited exception repr only reflects positional args, so
        # ``WebSocketDisconnect(code=1006)`` would render as
        # ``WebSocketDisconnect()``. Show the actual close details instead.
        return f"{type(self).__name__}(code={self.code!r}, reason={self.reason!r})"

25 

26 

class WebSocket(HTTPConnection):
    """
    ASGI WebSocket connection wrapper.

    Two independent state machines are tracked:

    * ``client_state`` advances as messages are pulled through ``receive``.
    * ``application_state`` advances as messages are pushed through ``send``.

    Both start in ``WebSocketState.CONNECTING``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope, receive)
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # so callers that rely on AssertionError keep working.
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        The first message must be ``websocket.connect``. After that only
        ``websocket.receive`` or ``websocket.disconnect`` are accepted; a
        disconnect moves ``client_state`` to DISCONNECTED.

        Raises:
            RuntimeError: on an unexpected message type, or when called after
                a ``websocket.disconnect`` message was already received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    f'Expected ASGI message "websocket.connect", but got {message_type!r}'
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Allowed message types per ``application_state``:

        * CONNECTING: ``websocket.accept``, ``websocket.close`` or
          ``websocket.http.response.start``.
        * CONNECTED: ``websocket.send`` or ``websocket.close``.
        * RESPONSE: ``websocket.http.response.body``; the last chunk (no
          ``more_body``) moves the state to DISCONNECTED.

        Raises:
            RuntimeError: on a message type not allowed in the current state,
                or when called after the application side is DISCONNECTED.
            WebSocketDisconnect: with code 1006 when the underlying send
                raises ``OSError`` while CONNECTED.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {
                "websocket.accept",
                "websocket.close",
                "websocket.http.response.start",
            }:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                # Mark closed before awaiting so a concurrent send sees the
                # close as already issued.
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # 1006 = abnormal closure (RFC 6455). Chain the OS error so
                # the original cause is preserved in tracebacks.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(
                    f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}'
                )
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: Optional[str] = None,
        headers: Optional[Iterable[tuple[bytes, bytes]]] = None,
    ) -> None:
        """
        Accept the WebSocket handshake.

        If the client's ``websocket.connect`` message has not been consumed
        yet, it is received first. Then a ``websocket.accept`` message is sent.

        Args:
            subprotocol: Negotiated subprotocol to report, or ``None``.
            headers: Extra ``(name, value)`` byte pairs; ``None`` means none.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise WebSocketDisconnect if *message* is a disconnect event."""
        if message["type"] == "websocket.disconnect":
            # The ASGI spec says servers should use 1005 ("no status received")
            # when the client sent no close code; don't KeyError if it's absent.
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    def _require_connected(self) -> None:
        """Raise RuntimeError unless the application side is CONNECTED."""
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )

    async def receive_text(self) -> str:
        """
        Receive the next message and return its ``"text"`` payload.

        Raises:
            RuntimeError: if the application side is not CONNECTED.
            WebSocketDisconnect: if the next message is a disconnect.
            KeyError: if the received message has no ``"text"`` key.
        """
        self._require_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive the next message and return its ``"bytes"`` payload.

        Raises:
            RuntimeError: if the application side is not CONNECTED.
            WebSocketDisconnect: if the next message is a disconnect.
            KeyError: if the received message has no ``"bytes"`` key.
        """
        self._require_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive the next message and decode it as JSON.

        Args:
            mode: ``"text"`` parses the ``"text"`` payload; ``"binary"``
                decodes the ``"bytes"`` payload as UTF-8 and then parses it.

        Raises:
            RuntimeError: on an invalid *mode* or if not CONNECTED.
            WebSocketDisconnect: if the next message is a disconnect.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        self._require_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")

        return json.loads(text)

    async def iter_text(self) -> AsyncIterator[str]:
        """Yield text messages until the peer disconnects (ends quietly)."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """Yield binary messages until the peer disconnects (ends quietly)."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self, mode: str = "text") -> AsyncIterator[typing.Any]:
        """
        Yield JSON-decoded messages until the peer disconnects.

        Args:
            mode: Forwarded to ``receive_json`` (``"text"`` or ``"binary"``).
        """
        try:
            while True:
                yield await self.receive_json(mode)
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        """Send *data* as a text frame."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send *data* as a binary frame."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialize *data* as compact JSON and send it.

        Args:
            data: JSON-serializable value.
            mode: ``"text"`` sends a text frame; ``"binary"`` sends the
                UTF-8 encoded JSON as a binary frame.

        Raises:
            RuntimeError: on an invalid *mode*.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        """Send a ``websocket.close`` message; ``None`` reason becomes ``""``."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )

    def is_connected(self) -> bool:
        """Return True only when both client and application sides are CONNECTED."""
        return (
            self.client_state == WebSocketState.CONNECTED
            and self.application_state == WebSocketState.CONNECTED
        )