Coverage for nexios\testing\transport.py: 35%
251 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1import httpx
2import anyio
3import io
4from typing import Any, Tuple, Optional, List, Dict, AsyncGenerator, Union
5from urllib.parse import unquote
6from enum import Enum
# Raw ASGI header list: (name, value) pairs of bytes.
HeaderList = List[Tuple[bytes, bytes]]
# ASGI connection scopes and ASGI event messages are both plain dicts.
ASGIScope = Dict[str, Any]
Message = Dict[str, Any]
class WebSocketState(Enum):
    """States a WebSocketConnection moves through."""

    CONNECTING = 0  # initial state, handshake not finished yet
    CONNECTED = 1  # handshake done, messages may flow
    DISCONNECTED = 2  # closed by either side; never left once reached
class WebSocketDisconnect(Exception):
    """Raised when a WebSocket connection is (or already was) closed.

    Attributes:
        code: The WebSocket close code.
        reason: The close reason, or ``None`` when none was given.
    """

    def __init__(self, code: int = 1000, reason: Optional[str] = None):
        shown_reason = reason if reason else "No reason provided"
        super().__init__(
            f"WebSocket disconnected with code {code}: {shown_reason}"
        )
        self.code = code
        self.reason = reason
class WebSocketConnection:
    """Client side of an in-process WebSocket session with an ASGI app.

    :meth:`connect` opens an anyio task group, starts the application in it
    and waits for the handshake; :meth:`close` (or leaving ``async with``)
    tears it down again.  anyio task groups must be exited by the task that
    entered them, so ``connect`` and ``close`` have to run in the same task.

    Traffic flows through two unbounded in-memory streams:

    * client -> app: :meth:`send` writes to ``send_channel``; the app's
      ``receive`` callable reads from ``receive_channel``.
    * app -> client: the app's ``send`` callable writes to ``app_channel``;
      :meth:`receive` reads from ``client_channel``.
    """

    def __init__(
        self,
        app: Any,
        scope: ASGIScope,
        raise_app_exceptions: bool = True,
        timeout: float = 5.0,
    ):
        self.app = app
        self.scope = scope
        self.raise_app_exceptions = raise_app_exceptions
        # Seconds allowed for the handshake, for each receive(), and for the
        # app to return on its own during close().
        self.timeout = timeout

        # Connection state
        self.state = WebSocketState.CONNECTING
        self.close_code: Optional[int] = None
        self.close_reason: Optional[str] = None
        self.subprotocol: Optional[str] = None

        # Communication channels. anyio returns (send_stream, receive_stream).
        # Client -> application:
        self.send_channel, self.receive_channel = anyio.create_memory_object_stream(
            max_buffer_size=float("inf")
        )
        # Application -> client:
        self.app_channel, self.client_channel = anyio.create_memory_object_stream(
            max_buffer_size=float("inf")
        )

        # Control events
        self.connection_event = anyio.Event()  # set when the app accepts
        self.disconnection_event = anyio.Event()  # set when either side closes
        # Set once the handshake is decided (accepted or closed).
        self._handshake_event = anyio.Event()
        # Set when the application coroutine has returned, raised or been
        # cancelled.
        self._app_finished = anyio.Event()

        # The first ASGI receive() must yield "websocket.connect".
        self._connect_delivered = False
        # Exception raised by the app; re-raised on the client side when
        # raise_app_exceptions is true.
        self._app_exception: Optional[BaseException] = None

        # Background task group running the app; open between connect()
        # and close().
        self.task_group = None

    async def run_app(self) -> None:
        """Run the ASGI application for this connection.

        Application exceptions do not escape into the task group, where they
        would cancel the caller's own code. They close the connection with
        code 1011 and, when ``raise_app_exceptions`` is true, are stored so
        that :meth:`connect`, :meth:`receive` or :meth:`close` re-raise them.
        """
        try:
            await self.app(self.scope, self._asgi_receive, self._asgi_send)
        except Exception as exc:
            if self.raise_app_exceptions:
                self._app_exception = exc
                await self._handle_disconnect(1011, str(exc))
            else:
                await self._handle_disconnect(1011, "Internal error")
        else:
            # The app returned; a no-op if it already sent websocket.close.
            await self._handle_disconnect(1000, "Normal closure")
        finally:
            # Also reached on cancellation, which is not an Exception.
            # _handle_disconnect never awaits, so it runs even when cancelled.
            await self._handle_disconnect(1006, "Abnormal closure")
            self._app_finished.set()

    async def _asgi_receive(self) -> Message:
        """ASGI ``receive`` callable handed to the application.

        The first call yields ``websocket.connect``. Later calls yield the
        client's messages in order. Once the client side is closed and
        drained, they yield ``websocket.disconnect``.
        """
        if not self._connect_delivered:
            self._connect_delivered = True
            return {"type": "websocket.connect"}

        try:
            return await self.receive_channel.receive()
        except (anyio.EndOfStream, anyio.ClosedResourceError):
            return {
                "type": "websocket.disconnect",
                "code": self.close_code or 1006,
                "reason": self.close_reason,
            }

    async def _asgi_send(self, message: Message) -> None:
        """ASGI ``send`` callable handed to the application.

        Raises:
            RuntimeError: if the app sends data before accepting.
        """
        message_type = message["type"]

        if message_type == "websocket.accept":
            if self.state == WebSocketState.CONNECTING:
                self.subprotocol = message.get("subprotocol")
                self.state = WebSocketState.CONNECTED
                self.connection_event.set()
                self._handshake_event.set()

        elif message_type == "websocket.send":
            if self.state == WebSocketState.DISCONNECTED:
                # Nobody will read it any more; late data is discarded.
                return
            if self.state != WebSocketState.CONNECTED:
                raise RuntimeError("Cannot send data before the WebSocket is accepted")
            text = message.get("text")
            # Check for None explicitly so an empty text frame stays text.
            if text is not None:
                payload = {"type": "message", "data": text, "is_text": True}
            else:
                payload = {
                    "type": "message",
                    "data": message.get("bytes"),
                    "is_text": False,
                }
            await self.app_channel.send(payload)

        elif message_type == "websocket.close":
            # Before accept this is a rejection; connect() reports it.
            await self._handle_disconnect(
                message.get("code", 1000), message.get("reason")
            )

    async def _handle_disconnect(self, code: int, reason: Optional[str] = None) -> None:
        """Mark the connection closed and notify both sides (idempotent).

        This method contains no ``await``, so it also completes inside a
        cancelled scope (see :meth:`run_app`).
        """
        if self.state == WebSocketState.DISCONNECTED:
            return

        self.state = WebSocketState.DISCONNECTED
        self.close_code = code
        self.close_reason = reason

        # Client side: queue a close frame behind any data the app already
        # sent, then end the stream so later receive() calls fail fast.
        try:
            self.app_channel.send_nowait({"type": "close", "code": code, "reason": reason})
        except (anyio.BrokenResourceError, anyio.ClosedResourceError):
            pass  # the client side has already been torn down
        self.app_channel.close()
        # App side: ending this stream makes _asgi_receive report
        # websocket.disconnect once buffered client messages are drained.
        self.send_channel.close()

        self.disconnection_event.set()
        self._handshake_event.set()

    def _raise_app_exception(self) -> None:
        """Re-raise (at most once) the exception stored by run_app."""
        exc, self._app_exception = self._app_exception, None
        if exc is not None:
            raise exc

    async def _shutdown(self, graceful: bool = True) -> None:
        """Stop the app task, exit the task group and close all streams.

        With ``graceful`` the app first gets up to ``timeout`` seconds to
        return on its own; it is cancelled afterwards.
        """
        task_group, self.task_group = self.task_group, None
        if task_group is None:
            return
        if graceful:
            with anyio.move_on_after(self.timeout):
                await self._app_finished.wait()
        task_group.cancel_scope.cancel()
        await task_group.__aexit__(None, None, None)
        for stream in (
            self.send_channel,
            self.receive_channel,
            self.app_channel,
            self.client_channel,
        ):
            stream.close()

    async def connect(self) -> None:
        """Start the application and wait until it accepts the connection.

        Raises:
            RuntimeError: if called more than once, or if the app does not
                accept or close within ``timeout`` seconds.
            WebSocketDisconnect: if the app closes or returns before
                accepting.
            Exception: the app's own exception, if it crashed during the
                handshake and ``raise_app_exceptions`` is true.
        """
        if self.state != WebSocketState.CONNECTING or self.task_group is not None:
            raise RuntimeError("WebSocket is already connected or disconnected")

        # Entered by hand (not ``async with``) so the app keeps running after
        # connect() returns; _shutdown() exits it.
        task_group = anyio.create_task_group()
        await task_group.__aenter__()
        self.task_group = task_group
        task_group.start_soon(self.run_app)

        with anyio.move_on_after(self.timeout):
            await self._handshake_event.wait()

        if not self._handshake_event.is_set():
            await self._handle_disconnect(1006, "Connection timeout")
            await self._shutdown(graceful=False)
            raise RuntimeError("WebSocket connection timed out")

        # Check the accept event rather than ``state``: the app may already
        # have accepted *and* closed by the time we get here.
        if not self.connection_event.is_set():
            await self._shutdown()
            self._raise_app_exception()
            raise WebSocketDisconnect(
                self.close_code or 1006, self.close_reason or "Connection rejected"
            )

    async def send(self, data: Union[str, bytes]) -> None:
        """Send a text (``str``) or binary (bytes-like) message to the app.

        Raises:
            WebSocketDisconnect: if the connection is not open.
            TypeError: if ``data`` is neither ``str`` nor bytes-like.
        """
        if self.state != WebSocketState.CONNECTED:
            raise WebSocketDisconnect(
                self.close_code or 1006,
                self.close_reason or "Cannot send on closed connection",
            )

        if isinstance(data, str):
            message = {"type": "websocket.receive", "text": data, "bytes": None}
        elif isinstance(data, (bytes, bytearray, memoryview)):
            message = {"type": "websocket.receive", "text": None, "bytes": bytes(data)}
        else:
            raise TypeError(
                f"WebSocket data must be str or bytes, not {type(data).__name__}"
            )

        try:
            # The buffer is unbounded, so this never blocks; it still yields
            # control once, giving the app a chance to run.
            await self.send_channel.send(message)
        except (anyio.BrokenResourceError, anyio.ClosedResourceError) as exc:
            await self._handle_disconnect(1002, "Protocol error")
            raise WebSocketDisconnect(1002, "Send error") from exc

    async def receive(self) -> Union[str, bytes]:
        """Return the next text or binary payload sent by the application.

        Messages the app sent before closing are still delivered. After
        that, the close code and reason are reported.

        Raises:
            WebSocketDisconnect: with the app's close code and reason once
                the connection is closed. Raised with code 1002 if nothing
                arrives within ``timeout`` seconds, which also closes the
                connection.
            Exception: the app's own exception, if it crashed and
                ``raise_app_exceptions`` is true.
        """
        message: Optional[Message] = None
        with anyio.move_on_after(self.timeout):
            try:
                message = await self.client_channel.receive()
            except (anyio.EndOfStream, anyio.ClosedResourceError):
                self._raise_app_exception()
                raise WebSocketDisconnect(
                    self.close_code or 1006, self.close_reason or "Connection closed"
                ) from None

        if message is None:
            await self._handle_disconnect(1002, "Protocol error")
            raise WebSocketDisconnect(1002, "Receive timeout")

        if message["type"] == "close":
            self._raise_app_exception()
            raise WebSocketDisconnect(message["code"], message["reason"])

        return message["data"]

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        """Close the connection and stop the application.

        The app sees ``websocket.disconnect`` after any messages still
        queued for it. It then gets up to ``timeout`` seconds to return
        before it is cancelled. Safe to call more than once. Re-raises a
        pending app exception when ``raise_app_exceptions`` is true.
        """
        await self._handle_disconnect(code, reason)  # no-op if already closed
        await self._shutdown()
        self._raise_app_exception()

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
class NexiosAsyncTransport(httpx.AsyncBaseTransport):
    """httpx transport that serves requests from an in-process ASGI app.

    ``http``/``https`` requests are run through the app, and the response
    is buffered into an ``httpx.Response``. ``ws``/``wss`` requests get a
    synthetic ``101 Switching Protocols`` response. Its
    ``extensions["websocket"]`` holds a :class:`WebSocketConnection` that
    the caller still has to connect.
    """

    def __init__(
        self,
        app: Any,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: Tuple[str, int] = ("testclient", 5000),
        app_state: Optional[Dict[str, Any]] = None,
        websocket_timeout: float = 5.0,
    ):
        """
        Args:
            app: The ASGI application to call.
            raise_app_exceptions: Re-raise exceptions from the app instead of
                turning them into 500 responses.
            root_path: Value for ``scope["root_path"]``.
            client: ``(host, port)`` placed in ``scope["client"]``.
            app_state: Initial ``scope["state"]``; shallow-copied per request.
            websocket_timeout: ``timeout`` given to each WebSocketConnection.
        """
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client
        self.app_state = app_state or {}
        self.websocket_timeout = websocket_timeout

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch ``request`` to the WebSocket or HTTP handler by scheme."""
        scheme, netloc, path, raw_path, query = self._parse_url(request)
        host, port = self._get_host_port(netloc, scheme)
        headers = self._prepare_headers(request, host, port)

        if scheme in {"ws", "wss"}:
            return await self._handle_websocket(
                request, scheme, path, raw_path, query, headers, host, port
            )
        return await self._handle_http(
            request, scheme, path, raw_path, query, headers, host, port
        )

    async def _handle_websocket(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> httpx.Response:
        """Answer a WebSocket upgrade with a 101 response.

        The response carries an unconnected WebSocketConnection in its
        extensions. A 400 response is returned when the request has no
        ``Sec-WebSocket-Key`` header.
        """
        import base64
        import hashlib

        def calculate_accept(key: str) -> str:
            """Calculate Sec-WebSocket-Accept header value according to RFC 6455."""
            GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
            accept = hashlib.sha1((key + GUID).encode()).digest()
            return base64.b64encode(accept).decode()

        # Header names were lowercased by _prepare_headers.
        ws_key = None
        for k, v in headers:
            if k.lower() == b"sec-websocket-key":
                ws_key = v.decode()
                break

        if not ws_key:
            return httpx.Response(400, content=b"Missing WebSocket Key")

        ws_accept = calculate_accept(ws_key)

        scope = self._build_websocket_scope(
            request, scheme, path, raw_path, query, headers, host, port
        )

        websocket = WebSocketConnection(
            app=self.app,
            scope=scope,
            raise_app_exceptions=self.raise_app_exceptions,
            timeout=self.websocket_timeout,
        )

        response_headers = {
            "Upgrade": "websocket",
            "Connection": "Upgrade",
            "Sec-WebSocket-Accept": ws_accept,
        }

        # Echoes the client's requested subprotocol header verbatim.
        if "sec-websocket-protocol" in request.headers:
            response_headers["Sec-WebSocket-Protocol"] = request.headers[
                "sec-websocket-protocol"
            ]

        return httpx.Response(
            101,
            headers=response_headers,
            request=request,
            extensions={"websocket": websocket},
        )

    async def _handle_http(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> httpx.Response:
        """Handle a plain HTTP request."""
        scope = self._build_http_scope(
            request, scheme, path, raw_path, query, headers, host, port
        )
        return await self._send_http_request(scope, request)

    async def _send_http_request(
        self, scope: ASGIScope, request: httpx.Request
    ) -> httpx.Response:
        """Run the app for one HTTP request and buffer the whole response.

        Raises:
            Exception: whatever the app raised, if ``raise_app_exceptions``.
            RuntimeError: if the app returned without starting a response
                (only when ``raise_app_exceptions`` is true), or sent a body
                chunk before ``http.response.start``.
        """
        request_complete = False
        response_started = False
        response_complete = anyio.Event()
        response_body = io.BytesIO()
        # Raw (name, value) byte pairs exactly as sent by the app. Kept as a
        # list so repeated headers such as Set-Cookie are all preserved.
        response_headers: HeaderList = []
        status_code = 500

        async def receive() -> Message:
            nonlocal request_complete
            if request_complete:
                # Nothing more to read: park until the response is done,
                # then report the client as gone.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            body = await request.aread()
            request_complete = True
            return {"type": "http.request", "body": body, "more_body": False}

        async def send(message: Message) -> None:
            nonlocal response_started, status_code, response_headers
            if message["type"] == "http.response.start":
                status_code = message["status"]
                response_headers = list(message.get("headers", []))
                response_started = True
            elif message["type"] == "http.response.body":
                # Explicit check rather than assert, which -O would strip.
                if not response_started:
                    raise RuntimeError("Received body before response start")
                response_body.write(message.get("body", b""))
                if not message.get("more_body", False):
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        except Exception:
            # Only Exception: cancellation and KeyboardInterrupt must propagate.
            if self.raise_app_exceptions:
                raise
            # Discard any partial response, including its headers (e.g. a
            # Content-Length that no longer matches the body).
            status_code = 500
            response_headers = []
            response_body = io.BytesIO(b"Internal Server Error")

        if self.raise_app_exceptions and not response_started:
            raise RuntimeError("TestClient did not receive any response.")

        return httpx.Response(
            status_code,
            headers=response_headers,
            # getvalue() is independent of the stream position, so a body
            # that never got a final chunk is still returned.
            content=response_body.getvalue(),
            request=request,
        )

    def _parse_url(self, request: httpx.Request) -> Tuple[str, str, str, bytes, str]:
        """Return ``(scheme, netloc, path, raw_path, query)`` for ``request``."""
        return (
            request.url.scheme,
            request.url.netloc.decode("ascii"),
            request.url.path,
            request.url.raw_path,
            request.url.query.decode("ascii"),
        )

    def _get_host_port(self, netloc: str, scheme: str) -> Tuple[str, int]:
        """Split ``netloc`` into ``(host, port)``, defaulting the port by scheme.

        Bracketed IPv6 literals such as ``[::1]`` and ``[::1]:8000`` are
        supported.
        """
        default_ports = {"http": 80, "https": 443, "ws": 80, "wss": 443}
        default_port = default_ports.get(scheme, 80)
        host, sep, port = netloc.rpartition(":")
        # No colon at all, or the last colon lies inside an IPv6 literal.
        if not sep or "]" in port:
            return netloc, default_port
        return host, int(port) if port else default_port

    def _prepare_headers(
        self, request: httpx.Request, host: str, port: int
    ) -> HeaderList:
        """Build the ASGI header list: lowercase names, raw byte values.

        A ``host`` header is synthesized only when the request lacks one.
        """
        headers: HeaderList = []
        if "host" not in request.headers:
            headers.append((b"host", f"{host}:{port}".encode()))
        # ``raw`` keeps the original bytes instead of re-encoding the decoded
        # str values (which may not have been UTF-8).
        headers.extend((name.lower(), value) for name, value in request.headers.raw)
        return headers

    @staticmethod
    def _scope_paths(raw_path: bytes) -> Tuple[str, bytes]:
        """Return the ASGI ``(path, raw_path)`` pair for a raw request target.

        The query string is dropped and the path is percent-decoded exactly
        once. httpx's ``URL.path`` is already decoded, so unquoting that
        value again would turn e.g. ``%2520`` into a space.
        """
        target = raw_path.split(b"?", 1)[0]
        return unquote(target.decode("ascii")), target

    def _build_http_scope(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> ASGIScope:
        """Build the ASGI ``http`` scope (``path`` is derived from ``raw_path``)."""
        scope_path, scope_raw_path = self._scope_paths(raw_path)
        return {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "path": scope_path,
            "raw_path": scope_raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "state": self.app_state.copy(),
        }

    def _build_websocket_scope(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> ASGIScope:
        """Build the ASGI ``websocket`` scope (``path`` is derived from ``raw_path``)."""
        subprotocols = []
        if "sec-websocket-protocol" in request.headers:
            subprotocols = [
                p.strip() for p in request.headers["sec-websocket-protocol"].split(",")
            ]

        scope_path, scope_raw_path = self._scope_paths(raw_path)
        return {
            "type": "websocket",
            "asgi": {"version": "3.0"},
            "path": scope_path,
            "raw_path": scope_raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "subprotocols": subprotocols,
            "state": self.app_state.copy(),
            "app": self.app,  # non-standard key: exposes the app to handlers
        }