Coverage for nexios\middlewares\core\__init__.py: 82%
137 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1from __future__ import annotations
3import typing
5import anyio
6from nexios.http.request import ClientDisconnect, Request
7from nexios.http.response import NexiosResponse as Response
8from nexios.types import ASGIApp, Message, Receive, Scope, Send
9from nexios.websockets import WebSocket
10from nexios._utils.async_helpers import collapse_excgroups
# Type variable for the generic ``wrap`` helper used inside BaseMiddleware.
T = typing.TypeVar("T")

# An endpoint: takes a Request and asynchronously produces a Response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

# The dispatch function BaseMiddleware awaits as
# ``dispatch(request, response, call_next)``, where ``call_next`` is a
# zero-argument callable whose awaitable yields the downstream Response.
DispatchFunction = typing.Callable[
    [Request, Response, typing.Callable[[], typing.Awaitable[Response]]],
    typing.Awaitable[Response],
]
19import sys
20from collections.abc import Iterator
21from typing import Any, Protocol
23if sys.version_info >= (3, 10): # pragma: no cover
24 from typing import ParamSpec
25else: # pragma: no cover
26 from typing_extensions import ParamSpec
28from nexios.types import ASGIApp
30P = ParamSpec("P")
33class _MiddlewareFactory(Protocol[P]):
34 def __call__(
35 self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs
36 ) -> ASGIApp: ... # pragma: no cover
class Middleware:
    """Deferred middleware specification: a factory plus the arguments for it.

    Nothing is instantiated here; the factory and its arguments are just
    stored. Instances unpack as ``(cls, args, kwargs)``.
    """

    def __init__(
        self,
        cls: _MiddlewareFactory[P],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        # Allows ``cls, args, kwargs = middleware``.
        triple = (self.cls, self.args, self.kwargs)
        return iter(triple)

    def __repr__(self) -> str:
        positional = [repr(value) for value in self.args]
        keyword = [f"{key}={value!r}" for key, value in self.kwargs.items()]
        # Factories without a ``__name__`` (e.g. arbitrary callables) render as "".
        factory_name = getattr(self.cls, "__name__", "")
        rendered = ", ".join([factory_name, *positional, *keyword])
        return "{}({})".format(type(self).__name__, rendered)
class _CachedRequest(Request):
    """Request whose body can be handed on to the downstream app.

    BaseMiddleware gives the downstream app ``wrapped_receive`` instead of the
    raw ASGI ``receive``. The dispatch function may already have consumed the
    body, which is then handled as follows:

    * read via ``Request.body()``: the cached body is replayed downstream in
      memory;
    * read via ``Request.stream()``: downstream gets an empty final body so it
      does not hang waiting for more data.
    """

    def __init__(self, scope: Scope, receive: Receive):
        super().__init__(scope, receive)
        # Flags tracking what wrapped_receive has already delivered downstream.
        self._wrapped_rcv_disconnected = False
        self._wrapped_rcv_consumed = False
        # NOTE(review): this generator is created but never read in this class;
        # wrapped_receive calls self.stream() afresh on each call instead.
        self._wrapped_rc_stream = self.stream()

    async def wrapped_receive(self) -> Message:
        """Return the next ASGI message to deliver to the downstream app.

        Raises:
            RuntimeError: if, after the body was fully delivered, the client
                sends anything other than ``http.disconnect``.
        """
        # State 1: a disconnect was already delivered downstream. Repeat it
        # rather than waiting on the client again.
        if self._wrapped_rcv_disconnected:
            return {"type": "http.disconnect"}

        # State 2: the body was fully delivered; only a disconnect can follow.
        if self._wrapped_rcv_consumed:
            if not self._is_disconnected:
                # The disconnect has not been observed yet, so wait for it.
                msg = await self.receive()
                if msg["type"] != "http.disconnect":  # pragma: no cover
                    # Nothing but a disconnect should arrive at this point.
                    raise RuntimeError(f"Unexpected message received: {msg['type']}")
                self._wrapped_rcv_disconnected = True
                return msg
            # The disconnect was already seen upstream; no need to wait.
            self._wrapped_rcv_disconnected = True
            return {"type": "http.disconnect"}

        # State 3: nothing has been delivered downstream yet.
        cached_body = getattr(self, "_body", None)
        if cached_body is not None:
            # body() was called upstream: replay it, even if the client has
            # since disconnected.
            self._wrapped_rcv_consumed = True
            return {"type": "http.request", "body": cached_body, "more_body": False}

        if self._stream_consumed:
            # stream() was drained upstream: send an empty final chunk so the
            # downstream app does not wait for a disconnect.
            self._wrapped_rcv_consumed = True
            return {"type": "http.request", "body": b"", "more_body": False}

        # Neither body() nor a full stream() happened: forward one chunk.
        try:
            stream = self.stream()
            chunk = await stream.__anext__()
        except ClientDisconnect:
            self._wrapped_rcv_disconnected = True
            return {"type": "http.disconnect"}
        else:
            self._wrapped_rcv_consumed = self._stream_consumed
            return {
                "type": "http.request",
                "body": chunk,
                "more_body": not self._stream_consumed,
            }
class BaseMiddleware:
    """Adapt a ``dispatch(request, response, call_next)`` function into ASGI middleware.

    Non-HTTP scopes (e.g. websocket, lifespan) are passed straight through to
    the wrapped app. For HTTP:

    * the downstream app runs in a background task;
    * its response messages are relayed over an in-memory object stream;
    * the dispatch function can therefore inspect or alter ``response`` before
      it is sent to the client.

    Args:
        app: The downstream ASGI application.
        dispatch: Coroutine function awaited once per HTTP request with the
            request, a fresh response object, and ``call_next``.
    """

    def __init__(self, app: ASGIApp, dispatch: DispatchFunction) -> None:
        self.app = app
        self.dispatch_func = dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = _CachedRequest(scope, receive)
        response = Response(request=request)

        wrapped_receive = request.wrapped_receive
        # Set once the final response has been sent to the client. After that
        # the downstream app only ever receives "http.disconnect".
        response_sent = anyio.Event()

        async def call_next() -> Response:
            """Run the downstream app and stream its output into ``response``.

            Raises:
                RuntimeError: if the downstream app finishes without sending
                    any message.
                Exception: whatever the downstream app raised, re-raised here
                    (or at the end of the body stream).
            """
            app_exc: Exception | None = None

            async def receive_or_disconnect() -> Message:
                # Receive callable for the downstream app. It races the next
                # client message against "response already sent".
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        result = await func()
                        # Whichever waiter finishes first cancels the other.
                        task_group.cancel_scope.cancel()
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(wrapped_receive)

                # If response_sent won the race, ``message`` was never bound;
                # this check returns before it is read.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def send_no_error(message: Message) -> None:
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been
                    # set. The client already has its response, so further
                    # messages are dropped instead of failing the downstream
                    # app with a misleading error.
                    return

            async def coro() -> None:
                nonlocal app_exc

                # Closing send_stream on exit makes recv_stream report
                # EndOfStream (or end iteration) once the app finishes.
                with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        # Stored and re-raised below, so the failure surfaces
                        # in the dispatch function's context.
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                # An "http.response.debug" message carrying "info" may precede
                # "http.response.start"; skip past it.
                if message["type"] == "http.response.debug" and info is not None:
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The downstream app finished without sending anything.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                async for message in recv_stream:
                    assert message["type"] == "http.response.body"
                    body = message.get("body", b"")
                    if body:
                        yield body
                    if not message.get("more_body", False):
                        break

                # Surface an error the downstream app raised while streaming.
                if app_exc is not None:
                    raise app_exc

            response_ = response.stream(iterator=body_stream(), status_code=message["status"])  # type: ignore
            response_._response._headers = message["headers"]  # type: ignore
            return response

        streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()  # type: ignore
        send_stream, recv_stream = streams
        with recv_stream, send_stream, collapse_excgroups():
            async with anyio.create_task_group() as task_group:
                # The value returned by dispatch is not used; what gets sent
                # is ``response``, which dispatch and call_next mutate.
                await self.dispatch_func(request, response, call_next)  # type: ignore
                await response.get_response()(scope, wrapped_receive, send)
                response_sent.set()
                # Unblock a downstream app still waiting in send_no_error, so
                # the task group can exit instead of waiting on it forever.
                recv_stream.close()
# Signature for websocket middleware: takes the WebSocket and a coroutine
# (presumably the rest of the handler chain — confirm against callers).
WebSocketDispatchFunction = typing.Callable[
    ["WebSocket", typing.Coroutine[None, None, typing.Any]], typing.Awaitable[None]
]


# Signature for a function-style HTTP middleware, i.e.
# ``async def mw(request, response, call_next) -> None``. ``call_next`` is a
# zero-argument callable returning an awaitable, and the middleware itself is
# awaited. This matches how BaseMiddleware invokes its dispatch function.
# (Previously the ``call_next`` parameter and the return type were swapped.)
MiddlewareType = typing.Callable[
    [Request, Response, typing.Callable[[], typing.Awaitable[None]]],
    typing.Awaitable[None],
]
def wrap_middleware(middleware_function: DispatchFunction) -> Middleware:
    """Package a dispatch function as a ``Middleware`` spec for ``BaseMiddleware``.

    The returned spec stores ``BaseMiddleware`` as its factory and
    ``middleware_function`` as the ``dispatch`` keyword argument.
    """
    spec = Middleware(BaseMiddleware, dispatch=middleware_function)
    return spec


__all__ = ["BaseMiddleware"]