Coverage for nexios\middlewares\core\__init__.py: 82%

137 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1from __future__ import annotations 

2 

3import typing 

4 

5import anyio 

6from nexios.http.request import ClientDisconnect, Request 

7from nexios.http.response import NexiosResponse as Response 

8from nexios.types import ASGIApp, Message, Receive, Scope, Send 

9from nexios.websockets import WebSocket 

10from nexios._utils.async_helpers import collapse_excgroups 

11 

# Handler shape: receives a Request and resolves to a Response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

# Shape of a dispatch function: it is called as
# ``dispatch(request, response, call_next)`` where ``call_next`` takes no
# arguments and resolves to a Response.
DispatchFunction = typing.Callable[
    [
        Request,
        Response,
        typing.Callable[[], typing.Awaitable[Response]],
    ],
    typing.Awaitable[Response],
]

# Generic result type (used by the ``wrap`` helper inside BaseMiddleware).
T = typing.TypeVar("T")

18 

19import sys 

20from collections.abc import Iterator 

21from typing import Any, Protocol 

22 

23if sys.version_info >= (3, 10): # pragma: no cover 

24 from typing import ParamSpec 

25else: # pragma: no cover 

26 from typing_extensions import ParamSpec 

27 

28from nexios.types import ASGIApp 

29 

30P = ParamSpec("P") 

31 

32 

33class _MiddlewareFactory(Protocol[P]): 

34 def __call__( 

35 self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs 

36 ) -> ASGIApp: ... # pragma: no cover 

37 

38 

class Middleware:
    """Pairs a middleware factory with the extra arguments for it.

    ``cls`` must satisfy the ``_MiddlewareFactory`` protocol, i.e. be
    callable as ``cls(app, *args, **kwargs)``.  Iterating an instance yields
    ``cls``, ``args`` and ``kwargs`` in that order, so it can be unpacked as
    ``cls, args, kwargs = middleware``.
    """

    def __init__(
        self,
        cls: _MiddlewareFactory[P],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        self.cls, self.args, self.kwargs = cls, args, kwargs

    def __iter__(self) -> Iterator[Any]:
        """Yield ``cls``, ``args`` and ``kwargs`` (supports tuple unpacking)."""
        return iter((self.cls, self.args, self.kwargs))

    def __repr__(self) -> str:
        kind = self.__class__.__name__
        positional = [f"{arg!r}" for arg in self.args]
        keyword = [f"{key}={val!r}" for key, val in self.kwargs.items()]
        # Objects without a ``__name__`` (e.g. instances) render as "".
        label = getattr(self.cls, "__name__", "")
        inner = ", ".join([label, *positional, *keyword])
        return f"{kind}({inner})"

61 

62 

class _CachedRequest(Request):
    """Request whose ``wrapped_receive`` feeds the downstream app.

    If the user calls Request.body() from their dispatch function
    we cache the entire request body in memory and pass that to downstream middlewares,
    but if they call Request.stream() then all we do is send an
    empty body so that downstream things don't hang forever.
    """

    def __init__(self, scope: Scope, receive: Receive) -> None:
        super().__init__(scope, receive)
        # Set once wrapped_receive() has returned an "http.disconnect" message.
        self._wrapped_rcv_disconnected = False
        # Set once wrapped_receive() has returned the last body message
        # ("more_body": False).
        self._wrapped_rcv_consumed = False
        # NOTE(review): this attribute is not read anywhere in this module;
        # wrapped_receive() creates its own self.stream() generator.
        self._wrapped_rc_stream = self.stream()

    async def wrapped_receive(self) -> Message:
        """ASGI ``receive`` callable handed to the downstream app.

        Depending on what has already happened, returns:

        * ``http.disconnect`` if a disconnect was already delivered downstream;
        * once the body has been fully delivered, a synthetic disconnect if
          ``self._is_disconnected`` is already set, otherwise the next message
          from ``self.receive()`` (which must be a disconnect);
        * otherwise an ``http.request`` message carrying the cached ``_body``
          (if set), an empty body (if the stream was already consumed), or the
          next chunk pulled from ``self.stream()``.

        Raises:
            RuntimeError: if a non-disconnect message arrives after the body
                has been fully delivered.
        """
        # wrapped_rcv state 1: disconnected
        if self._wrapped_rcv_disconnected:
            # we've already sent a disconnect to the downstream app
            # we don't need to wait to get another one
            # (although most ASGI servers will just keep sending it)
            return {"type": "http.disconnect"}
        # wrapped_rcv state 2: consumed but not yet disconnected
        if self._wrapped_rcv_consumed:
            # since the downstream app has consumed us all that is left
            # is to send it a disconnect
            if self._is_disconnected:
                # the middleware has already seen the disconnect
                # since we know the client is disconnected no need to wait
                # for the message
                self._wrapped_rcv_disconnected = True
                return {"type": "http.disconnect"}
            # we don't know yet if the client is disconnected or not
            # so we'll wait until we get that message
            msg = await self.receive()
            if msg["type"] != "http.disconnect":  # pragma: no cover
                # at this point a disconnect is all that we should be receiving
                # if we get something else, things went wrong somewhere
                raise RuntimeError(f"Unexpected message received: {msg['type']}")
            self._wrapped_rcv_disconnected = True
            return msg

        # wrapped_rcv state 3: not yet consumed
        if getattr(self, "_body", None) is not None:
            # body() was called, we return it even if the client disconnected
            self._wrapped_rcv_consumed = True
            return {
                "type": "http.request",
                "body": self._body,
                "more_body": False,
            }
        elif self._stream_consumed:
            # stream() was called to completion
            # return an empty body so that downstream apps don't hang
            # waiting for a disconnect
            self._wrapped_rcv_consumed = True
            return {
                "type": "http.request",
                "body": b"",
                "more_body": False,
            }
        else:
            # body() was never called and stream() wasn't consumed
            try:
                # Pull a single chunk; we only mark ourselves consumed (and
                # report more_body=False) once the underlying stream says it
                # is exhausted.
                stream = self.stream()
                chunk = await stream.__anext__()
                self._wrapped_rcv_consumed = self._stream_consumed
                return {
                    "type": "http.request",
                    "body": chunk,
                    "more_body": not self._stream_consumed,
                }
            except ClientDisconnect:
                self._wrapped_rcv_disconnected = True
                return {"type": "http.disconnect"}

137 

138 

class BaseMiddleware:
    """Adapt a ``dispatch`` function into an ASGI middleware.

    For HTTP scopes, ``dispatch(request, response, call_next)`` is awaited
    with a ``_CachedRequest``, a freshly built response object and a
    ``call_next`` coroutine function that runs the wrapped app.  Other scope
    types are forwarded to the wrapped app unchanged.
    """

    def __init__(self, app: ASGIApp, dispatch: DispatchFunction) -> None:
        # app: the downstream ASGI application this middleware wraps.
        # dispatch: the user function awaited once per HTTP request.
        self.app = app
        self.dispatch_func = dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        Non-HTTP scopes go straight to ``self.app``.  For HTTP, the dispatch
        function's return value is ignored; ``response.get_response()`` (the
        response object handed to dispatch) is what gets sent to the client.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = _CachedRequest(scope, receive)
        # Pre-built response passed to dispatch; call_next() fills it in.
        response = Response(request=request)

        # Receive callable for downstream code that replays any body already
        # read by the dispatch function through ``request``.
        wrapped_receive = request.wrapped_receive
        # Set once the final response has been sent to the client.
        response_sent = anyio.Event()

        async def call_next() -> Response:
            """Run ``self.app`` in the background and wire its output into ``response``.

            Waits for the app's ``http.response.start`` message, then sets
            ``response`` up to stream the app's remaining body messages.

            Raises:
                Exception: the app's own exception, if it failed before
                    sending any message.
                RuntimeError: "No response returned." if the app finished
                    without sending any message.
            """
            # Exception raised by the downstream app, recorded by coro().
            app_exc: Exception | None = None

            async def receive_or_disconnect() -> Message:
                # Once the response is out, report a disconnect instead of
                # waiting on the real receive channel.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                # Race wrapped_receive() against response_sent.wait(): whichever
                # completes first cancels the task group (and thus the other).
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(wrapped_receive)

                # If response_sent won the race, ``message`` was never bound;
                # this early return keeps it from being read.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def send_no_error(message: Message) -> None:
                # Forward the downstream app's ``send`` messages into the
                # in-memory stream read by call_next() and body_stream().
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    # NOTE(review): a response *was* sent at this point, so the
                    # message below is misleading.  Unless the downstream app
                    # handles it, it lands in coro()'s ``except Exception``.
                    raise RuntimeError("No response returned")

            async def coro() -> None:
                nonlocal app_exc

                # Closing send_stream when the app returns or raises ends the
                # reader side (EndOfStream / end of ``async for``).
                with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        app_exc = exc

            # ``task_group`` is the group opened in __call__ around dispatch.
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                # Skip a leading "http.response.debug" message that carries "info".
                if message["type"] == "http.response.debug" and info is not None:
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app finished without sending anything.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                # Yield non-empty body chunks until the app signals the last
                # chunk (more_body false) or closes the stream.
                async for message in recv_stream:
                    assert message["type"] == "http.response.body"
                    body = message.get("body", b"")
                    if body:
                        yield body
                    if not message.get("more_body", False):
                        break

                # Surface a downstream failure once the body stream ends.
                if app_exc is not None:
                    raise app_exc

            response_ = response.stream(iterator=body_stream(), status_code=message["status"])  # type: ignore
            response_._response._headers = message["headers"]  # type: ignore
            return response

        # Channel between coro() (writer, via send_no_error) and
        # call_next()/body_stream() (reader); created with anyio's default
        # buffer size.
        streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()  # type: ignore
        send_stream, recv_stream = streams
        # collapse_excgroups presumably unwraps exception groups raised by the
        # task group — see nexios._utils.async_helpers to confirm.
        with recv_stream, send_stream, collapse_excgroups():
            async with anyio.create_task_group() as task_group:
                await self.dispatch_func(request, response, call_next)  # type: ignore
                await response.get_response()(scope, wrapped_receive, send)
                response_sent.set()
                # Closing the reader makes any further downstream send() fail
                # with BrokenResourceError (handled in send_no_error).
                recv_stream.close()

231 

232 

# Shape of a websocket dispatch function.
# NOTE(review): the second parameter is typed as a coroutine object rather
# than a callable; confirm against the websocket middleware implementation.
WebSocketDispatchFunction = typing.Callable[
    ["WebSocket", typing.Coroutine[None, None, typing.Any]], typing.Awaitable[None]
]


# Shape of a function-style HTTP middleware: it is called as
# ``middleware(request, response, call_next)``, where ``call_next`` takes no
# arguments and is awaited, and the middleware call itself is awaited.
# The ``call_next`` parameter is a zero-argument callable returning an
# awaitable, and the middleware returns an awaitable (matching
# DispatchFunction).
MiddlewareType = typing.Callable[
    [Request, Response, typing.Callable[[], typing.Awaitable[None]]],
    typing.Awaitable[None],
]

242 

243 

def wrap_middleware(middleware_function: DispatchFunction) -> Middleware:
    """Turn a dispatch function into a :class:`Middleware` spec.

    The returned spec targets :class:`BaseMiddleware` and passes
    *middleware_function* as its ``dispatch`` keyword argument.
    """
    spec = Middleware(BaseMiddleware, dispatch=middleware_function)
    return spec

246 

247 

# Names exported by a star-import of this module.
__all__ = [
    "BaseMiddleware",
]