Coverage for nexios\testing\transport.py: 35%

251 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1import httpx 

2import anyio 

3import io 

4from typing import Any, Tuple, Optional, List, Dict, AsyncGenerator, Union 

5from urllib.parse import unquote 

6from enum import Enum 

7 

# Raw ASGI header list: (name, value) byte pairs, names lower-cased.
HeaderList = List[Tuple[bytes, bytes]]
# An ASGI connection scope (the dict passed as the app's first argument).
ASGIScope = Dict[str, Any]
# A single ASGI event exchanged through the app's receive/send callables.
Message = Dict[str, Any]

11 

12 

class WebSocketState(Enum):
    """Lifecycle stage of a test WebSocket connection."""

    # Handshake not finished yet.
    CONNECTING = 0
    # Handshake accepted; frames may flow both ways.
    CONNECTED = 1
    # Closed by either side (terminal state).
    DISCONNECTED = 2

17 

18 

class WebSocketDisconnect(Exception):
    """Raised when a WebSocket session is closed or cannot be used.

    Attributes:
        code: WebSocket close code (defaults to 1000).
        reason: Optional close reason; a placeholder is shown in the
            message when it is empty or ``None``.
    """

    def __init__(self, code: int = 1000, reason: Optional[str] = None):
        shown_reason = reason if reason else "No reason provided"
        super().__init__(
            f"WebSocket disconnected with code {code}: {shown_reason}"
        )
        self.code = code
        self.reason = reason

26 

27 

class WebSocketConnection:
    """In-process WebSocket session between a test client and an ASGI app.

    The application runs as a background task started by :meth:`connect`
    and exchanges messages with the client over two unbounded in-memory
    streams (``anyio.create_memory_object_stream`` returns
    ``(send_stream, receive_stream)``):

    * client -> app: the client writes to ``receive_channel`` and the app's
      ASGI ``receive`` callable reads from ``send_channel``;
    * app -> client: the app's ASGI ``send`` callable writes to
      ``app_channel`` and the client reads from ``client_channel``.

    Use it as ``async with connection: ...`` so the background task is
    started and torn down by the same task.
    """

    def __init__(
        self,
        app: Any,
        scope: ASGIScope,
        raise_app_exceptions: bool = True,
        timeout: float = 5.0,
    ):
        self.app = app
        self.scope = scope
        self.raise_app_exceptions = raise_app_exceptions
        self.timeout = timeout

        # Connection state
        self.state = WebSocketState.CONNECTING
        self.close_code: Optional[int] = None
        self.close_reason: Optional[str] = None
        self.subprotocol: Optional[str] = None

        # Client -> app: receive_channel is the send end, send_channel the receive end.
        self.receive_channel, self.send_channel = anyio.create_memory_object_stream(
            max_buffer_size=float("inf")
        )
        # App -> client: app_channel is the send end, client_channel the receive end.
        self.app_channel, self.client_channel = anyio.create_memory_object_stream(
            max_buffer_size=float("inf")
        )

        # Control events
        self.connection_event = anyio.Event()
        self.disconnection_event = anyio.Event()

        # Background task group hosting the app; kept open from connect() to
        # close() through _exit_stack.
        self.task_group = None
        self._exit_stack = None
        self._app_done = anyio.Event()
        self._app_exception: Optional[BaseException] = None
        # Whether the initial "websocket.connect" event was handed to the app.
        self._connect_delivered = False

    async def run_app(self) -> None:
        """Run the ASGI application for this connection.

        Records how the session ended: 1000 when the app returns normally,
        1011 when it raises, 1006 if it is cancelled or otherwise interrupted.
        Re-raises the app's exception when ``raise_app_exceptions`` is set.
        """
        try:
            await self.app(self.scope, self._asgi_receive, self._asgi_send)
        except Exception as exc:
            if self.raise_app_exceptions:
                await self._handle_disconnect(1011, str(exc))
                raise
            await self._handle_disconnect(1011, "Internal error")
        else:
            # No-op if the app already closed the socket itself.
            await self._handle_disconnect(1000, "Normal closure")
        finally:
            # Still open here only when cancelled / interrupted.
            if self.state != WebSocketState.DISCONNECTED:
                await self._handle_disconnect(1006, "Abnormal closure")

    async def _run_app_task(self) -> None:
        """Background-task wrapper around :meth:`run_app`.

        An exception escaping into the task group would cancel the client's
        task at an arbitrary await, so it is stored instead and re-raised by
        :meth:`_shutdown` (called from ``connect``/``close``).
        """
        try:
            await self.run_app()
        except Exception as exc:
            self._app_exception = exc
        finally:
            self._app_done.set()

    async def _asgi_receive(self) -> Message:
        """ASGI ``receive`` callable handed to the application."""
        if not self._connect_delivered:
            self._connect_delivered = True
            return {"type": "websocket.connect"}

        if self.state == WebSocketState.DISCONNECTED:
            # Deliver what the client queued before going away, then report
            # the disconnect on every further call.
            try:
                return self.send_channel.receive_nowait()
            except (anyio.WouldBlock, anyio.EndOfStream, anyio.ClosedResourceError):
                return {"type": "websocket.disconnect", "code": self.close_code}

        with anyio.move_on_after(self.timeout):
            try:
                return await self.send_channel.receive()
            except (anyio.EndOfStream, anyio.ClosedResourceError):
                # The connection was closed while the app was waiting.
                return {
                    "type": "websocket.disconnect",
                    "code": self.close_code or 1006,
                }

        # Only reached when the timeout expired.
        await self._handle_disconnect(1002, "Receive timeout")
        return {"type": "websocket.disconnect", "code": 1002}

    async def _asgi_send(self, message: Message) -> None:
        """ASGI ``send`` callable handed to the application."""
        message_type = message["type"]

        if (
            message_type == "websocket.close"
            and self.state == WebSocketState.DISCONNECTED
        ):
            # Closing an already-closed socket (e.g. the client left first)
            # is a no-op rather than an error.
            return

        try:
            if message_type == "websocket.accept":
                self.subprotocol = message.get("subprotocol")
                if self.state == WebSocketState.CONNECTING:
                    self.state = WebSocketState.CONNECTED
                # Consumed by connect(), never surfaced by receive().
                await self.app_channel.send({"type": "accept"})
                self.connection_event.set()

            elif message_type == "websocket.send":
                text = message.get("text")
                # Test for None explicitly: "" is a valid text frame, and ASGI
                # allows the unused key to be present with a None value.
                if text is not None:
                    frame = {"type": "message", "data": text, "is_text": True}
                else:
                    frame = {
                        "type": "message",
                        "data": message.get("bytes"),
                        "is_text": False,
                    }
                await self.app_channel.send(frame)

            elif message_type == "websocket.close":
                code = message.get("code", 1000)
                reason = message.get("reason")
                # Queue the close frame for the client before tearing down.
                await self.app_channel.send(
                    {"type": "close", "code": code, "reason": reason}
                )
                await self._handle_disconnect(code, reason)
        except Exception as e:
            # If sending fails, ensure we disconnect
            await self._handle_disconnect(1006, str(e))
            raise

    async def _handle_disconnect(self, code: int, reason: Optional[str] = None) -> None:
        """Mark the connection closed with ``code``/``reason``; first call wins.

        Closes the send end of both streams so each side can still drain
        what is already queued and then sees end-of-stream instead of
        blocking. Contains no await, so it is safe in a cancelled ``finally``.
        """
        if self.state == WebSocketState.DISCONNECTED:
            return

        self.state = WebSocketState.DISCONNECTED
        self.close_code = code
        self.close_reason = reason
        self.disconnection_event.set()
        self.receive_channel.close()
        self.app_channel.close()

    async def _shutdown(self) -> None:
        """Stop the application task and exit its task group (idempotent).

        Waits up to ``timeout`` seconds for the app to finish on its own,
        cancels it otherwise, then re-raises any exception the app raised
        (recorded only when ``raise_app_exceptions`` is set).
        """
        exit_stack, self._exit_stack = self._exit_stack, None
        if exit_stack is not None:
            with anyio.move_on_after(self.timeout):
                await self._app_done.wait()
            if not self._app_done.is_set():
                self.task_group.cancel_scope.cancel()
            await exit_stack.aclose()

        exc, self._app_exception = self._app_exception, None
        if exc is not None:
            raise exc

    async def connect(self) -> None:
        """Start the application and wait until it accepts the connection.

        Raises:
            RuntimeError: if already connected/closed, or if the app neither
                accepts nor closes within ``timeout`` seconds.
            WebSocketDisconnect: if the app closes (rejects) the connection
                before accepting it.
            Exception: the app's own exception if it crashed before
                accepting and ``raise_app_exceptions`` is set.
        """
        from contextlib import AsyncExitStack

        if self.state != WebSocketState.CONNECTING or self._exit_stack is not None:
            raise RuntimeError("WebSocket is already connected or disconnected")

        # The task group must outlive this call (the app keeps running until
        # close()), so it is entered through an exit stack rather than an
        # ``async with`` block that would wait here for the app to finish.
        self._exit_stack = AsyncExitStack()
        self.task_group = await self._exit_stack.enter_async_context(
            anyio.create_task_group()
        )
        self.task_group.start_soon(self._run_app_task)

        first: Optional[Message] = None
        with anyio.move_on_after(self.timeout):
            try:
                first = await self.client_channel.receive()
            except anyio.EndOfStream:
                # The app finished or crashed without accepting.
                first = {
                    "type": "close",
                    "code": self.close_code or 1006,
                    "reason": self.close_reason,
                }

        if first is None:
            await self._handle_disconnect(1006, "Connection timeout")
            await self._shutdown()
            raise RuntimeError("WebSocket connection timed out")

        if first["type"] != "accept":
            await self._shutdown()  # re-raises the app's exception, if any
            raise WebSocketDisconnect(first.get("code", 1006), first.get("reason"))

    async def send(self, data: Union[str, bytes]) -> None:
        """Send a text (``str``) or binary (bytes-like) frame to the app.

        Raises:
            WebSocketDisconnect: if the connection is not open.
            TypeError: if ``data`` is neither ``str`` nor bytes-like.
        """
        if self.state != WebSocketState.CONNECTED:
            raise WebSocketDisconnect(
                self.close_code or 1006,
                self.close_reason or "Cannot send on closed connection",
            )

        if isinstance(data, str):
            text, payload = data, None
        elif isinstance(data, (bytes, bytearray, memoryview)):
            text, payload = None, bytes(data)
        else:
            raise TypeError(
                f"Cannot send {type(data).__name__!r}; expected str or bytes"
            )

        message = {"type": "websocket.receive", "text": text, "bytes": payload}
        try:
            # Unbounded buffer: never blocks, so no timeout is needed.
            await self.receive_channel.send(message)
        except (anyio.ClosedResourceError, anyio.BrokenResourceError):
            await self._handle_disconnect(1002, "Protocol error")
            raise WebSocketDisconnect(1002, "Send error") from None

    async def receive(self) -> Union[str, bytes]:
        """Return the next text or binary frame sent by the app.

        Frames the app sent before closing are still delivered; once they
        are exhausted the close is reported.

        Raises:
            WebSocketDisconnect: with the app's code/reason when it closed the
                socket, with the recorded close details when the connection
                is otherwise closed, or with 1002 when nothing arrives within
                ``timeout`` seconds.
        """
        while True:
            message = None
            with anyio.move_on_after(self.timeout):
                try:
                    message = await self.client_channel.receive()
                except (anyio.EndOfStream, anyio.ClosedResourceError):
                    raise WebSocketDisconnect(
                        self.close_code or 1006,
                        self.close_reason or "Connection closed",
                    ) from None

            if message is None:
                await self._handle_disconnect(1002, "Protocol error")
                raise WebSocketDisconnect(1002, "Receive timeout")

            if message["type"] == "message":
                return message["data"]
            if message["type"] == "close":
                raise WebSocketDisconnect(message["code"], message["reason"])
            # "accept" frames are consumed by connect(); skip any stray one.

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        """Close the connection from the client side and stop the app.

        The app receives ``websocket.disconnect`` with ``code`` after any
        frames already queued for it. Safe to call more than once. Re-raises
        an exception the app raised when ``raise_app_exceptions`` is set.
        """
        if self.state != WebSocketState.DISCONNECTED:
            try:
                self.receive_channel.send_nowait(
                    {"type": "websocket.disconnect", "code": code}
                )
            except (anyio.ClosedResourceError, anyio.BrokenResourceError):
                # Best effort: the app side is already gone.
                pass
            await self._handle_disconnect(code, reason)

        await self._shutdown()

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

253 

254 

class NexiosAsyncTransport(httpx.AsyncBaseTransport):
    """httpx transport that dispatches requests directly to an ASGI app.

    ``http``/``https`` requests are executed against the app in-process.
    ``ws``/``wss`` requests yield a ``101 Switching Protocols`` response whose
    ``extensions["websocket"]`` holds a not-yet-connected
    ``WebSocketConnection`` for the caller to drive.
    """

    def __init__(
        self,
        app: Any,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: Tuple[str, int] = ("testclient", 5000),
        app_state: Optional[Dict[str, Any]] = None,
        websocket_timeout: float = 5.0,
    ):
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client
        self.app_state = app_state or {}
        self.websocket_timeout = websocket_timeout

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Handle the incoming request and route it appropriately."""
        scheme, netloc, path, raw_path, query = self._parse_url(request)
        host, port = self._get_host_port(netloc, scheme)
        headers = self._prepare_headers(request, host, port)

        if scheme in {"ws", "wss"}:
            return await self._handle_websocket(
                request, scheme, path, raw_path, query, headers, host, port
            )
        return await self._handle_http(
            request, scheme, path, raw_path, query, headers, host, port
        )

    async def _handle_websocket(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> httpx.Response:
        """Answer a WebSocket upgrade with 101 plus an unconnected session.

        Returns 400 when the request has no ``Sec-WebSocket-Key`` header.
        """
        import base64
        import hashlib

        def calculate_accept(key: str) -> str:
            """Calculate Sec-WebSocket-Accept header value according to RFC 6455."""
            GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
            accept = hashlib.sha1((key + GUID).encode()).digest()
            return base64.b64encode(accept).decode()

        ws_key = next(
            (v.decode("latin-1") for k, v in headers if k.lower() == b"sec-websocket-key"),
            None,
        )
        if not ws_key:
            return httpx.Response(400, content=b"Missing WebSocket Key", request=request)

        scope = self._build_websocket_scope(
            request, scheme, path, raw_path, query, headers, host, port
        )
        websocket = WebSocketConnection(
            app=self.app,
            scope=scope,
            raise_app_exceptions=self.raise_app_exceptions,
            timeout=self.websocket_timeout,
        )

        response_headers = {
            "Upgrade": "websocket",
            "Connection": "Upgrade",
            "Sec-WebSocket-Accept": calculate_accept(ws_key),
        }
        # NOTE: the app only chooses a subprotocol when it accepts, which has
        # not happened yet, so the client's offer is echoed back unchanged.
        if "sec-websocket-protocol" in request.headers:
            response_headers["Sec-WebSocket-Protocol"] = request.headers[
                "sec-websocket-protocol"
            ]

        return httpx.Response(
            101,
            headers=response_headers,
            request=request,
            extensions={"websocket": websocket},
        )

    async def _handle_http(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> httpx.Response:
        """Handle HTTP requests."""
        scope = self._build_http_scope(
            request, scheme, path, raw_path, query, headers, host, port
        )
        return await self._send_http_request(scope, request)

    async def _send_http_request(
        self, scope: ASGIScope, request: httpx.Request
    ) -> httpx.Response:
        """Run the app for one HTTP request and collect its response.

        With ``raise_app_exceptions`` the app's exception propagates, and an
        app that never starts a response raises RuntimeError. Without it, an
        app exception yields a 500 "Internal Server Error" response.
        """
        request_complete = False
        response_started = False
        response_complete = anyio.Event()
        response_body = io.BytesIO()
        response_headers: List[Tuple[Any, Any]] = []
        status_code = 500

        async def receive() -> Message:
            nonlocal request_complete
            if request_complete:
                # Body already delivered: park until the response is done,
                # then report the client as gone.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            body = await request.aread()
            request_complete = True
            return {"type": "http.request", "body": body, "more_body": False}

        async def send(message: Message) -> None:
            nonlocal response_started, status_code, response_headers
            if message["type"] == "http.response.start":
                status_code = message["status"]
                # Keep the pairs as given (bytes): preserves repeated headers
                # such as Set-Cookie and never fails on non-UTF-8 values.
                response_headers = [(k, v) for k, v in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                if not response_started:
                    raise RuntimeError("Received body before response start")
                response_body.write(message.get("body", b""))
                if not message.get("more_body", False):
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        except Exception:
            # Exception, not BaseException: cancellation must propagate.
            if self.raise_app_exceptions:
                raise
            status_code = 500
            # Drop partial headers so e.g. a stale Content-Length cannot
            # contradict the replacement body.
            response_headers = []
            response_body = io.BytesIO(b"Internal Server Error")

        if self.raise_app_exceptions and not response_started:
            raise RuntimeError("TestClient did not receive any response.")

        return httpx.Response(
            status_code,
            headers=response_headers,
            content=response_body.getvalue(),
            request=request,
        )

    def _parse_url(self, request: httpx.Request) -> Tuple[str, str, str, bytes, str]:
        """Return (scheme, netloc, decoded path, raw path+query, query string)."""
        return (
            request.url.scheme,
            request.url.netloc.decode("ascii"),
            request.url.path,
            request.url.raw_path,
            request.url.query.decode("ascii"),
        )

    def _get_host_port(self, netloc: str, scheme: str) -> Tuple[str, int]:
        """Split ``netloc`` into host and port, defaulting the port by scheme.

        Splits on the last colon so bracketed IPv6 literals work:
        ``"[::1]:8000"`` -> ``("[::1]", 8000)`` and ``"[::1]"`` keeps the
        scheme's default port.
        """
        default_ports = {"http": 80, "https": 443, "ws": 80, "wss": 443}
        host, sep, port = netloc.rpartition(":")
        if (
            sep
            and port.isascii()
            and port.isdigit()
            and (not host.startswith("[") or host.endswith("]"))
        ):
            return host, int(port)
        return netloc, default_ports.get(scheme, 80)

    def _prepare_headers(
        self, request: httpx.Request, host: str, port: int
    ) -> HeaderList:
        """Build the ASGI header list (lower-cased names, raw byte values)."""
        headers: HeaderList = []
        if "host" not in request.headers:
            headers.append((b"host", f"{host}:{port}".encode()))
        # httpx's raw byte pairs reach the app exactly as sent, with no
        # decode/re-encode round trip.
        headers.extend((key.lower(), value) for key, value in request.headers.raw)
        return headers

    def _build_http_scope(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> ASGIScope:
        """Build the ASGI ``http`` connection scope."""
        return {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            # httpx's URL.path is already percent-decoded; decoding it again
            # would corrupt paths containing an encoded "%" (e.g. "%2525").
            "path": path,
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "state": self.app_state.copy(),
        }

    def _build_websocket_scope(
        self,
        request: httpx.Request,
        scheme: str,
        path: str,
        raw_path: bytes,
        query: str,
        headers: HeaderList,
        host: str,
        port: int,
    ) -> ASGIScope:
        """Build the ASGI ``websocket`` connection scope."""
        subprotocols: List[str] = []
        if "sec-websocket-protocol" in request.headers:
            offered = request.headers["sec-websocket-protocol"].split(",")
            subprotocols = [p.strip() for p in offered if p.strip()]

        return {
            "type": "websocket",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            # Already percent-decoded by httpx; see _build_http_scope.
            "path": path,
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "subprotocols": subprotocols,
            "state": self.app_state.copy(),
            "app": self.app,  # Include app in scope
        }