Coverage for nexios\http\request.py: 70%

267 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1from __future__ import annotations 

2 

3import json 

4import typing 

5from http import cookies as http_cookies 

6 

7import anyio 

8from nexios._utils.async_helpers import ( 

9 AwaitableOrContextManager, 

10 AwaitableOrContextManagerWrapper, 

11) 

12from nexios.session.base import BaseSessionInterface 

13from nexios.http.formparsers import FormParser, MultiPartParser 

14from nexios.structs import URL, Address, FormData, Headers, QueryParams, State 

15from .formparsers import FormParser, MultiPartException, MultiPartParser 

16 

17try: 

18 from python_multipart.multipart import parse_options_header # type:ignore 

19 

20except ImportError: 

21 parse_options_header = None 

# --- Typing aliases for the connection plumbing used throughout this module ---

# Mapping describing the connection; read via ``scope[...]`` / ``scope.get(...)``.
Scope = typing.MutableMapping[str, typing.Any]
# A single message exchanged over the receive/send channels.
Message = typing.MutableMapping[str, typing.Any]

# Awaitable callable that produces the next incoming message.
Receive = typing.Callable[[], typing.Awaitable[Message]]
# Awaitable callable that delivers an outgoing message.
Send = typing.Callable[[Message], typing.Awaitable[None]]

# Values that ``json.loads`` may produce at the top level of a document.
JSONType = typing.Union[
    str,
    int,
    float,
    bool,
    None,
    typing.Dict[str, typing.Any],
    typing.List[typing.Any],
]

# Request header names that ``Request.send_push_promise`` copies from the
# incoming request into the "http.response.push" message.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}

38 

39 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    The parsing is deliberately lenient and follows what browsers do rather
    than the letter of RFC 6265:

    * pairs are separated on ``;`` and split on the first ``=`` only;
    * a chunk without ``=`` is stored under the empty name ``""``
      (see https://bugzilla.mozilla.org/show_bug.cgi?id=169091);
    * surrounding whitespace is stripped from names and values;
    * chunks whose name and value are both empty are skipped;
    * when a name repeats, the last occurrence wins.

    Adapted from Django 3.1.0. ``SimpleCookie.load`` is intentionally _NOT_
    used because it is based on an outdated spec and rejects many inputs we
    want to accept.

    :param cookie_string: Raw header value, e.g. ``"a=1; b=2"``.
    :return: Mapping of cookie names to unquoted values.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" at all: treat the whole chunk as a value with an empty name.
            name, value = "", piece
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        # Unquote using Python's own cookie-value algorithm.
        parsed[name] = http_cookies._unquote(value)  # type:ignore
    return parsed

65 

66 

class ClientDisconnect(Exception):
    """Raised while reading a request body when an "http.disconnect" message arrives."""

69 

70 

71T = typing.TypeVar("T") 

72 

73 

class HTTPConnection(object):
    """
    Common base for incoming connections, shared by `Request` and `WebSocket`.

    The instance wraps the connection ``scope`` mapping, behaves like a
    read-only mapping over it (``conn[key]``, ``iter(conn)``, ``len(conn)``),
    and exposes derived views (URL, headers, query params, cookies, state)
    that are built on first access and cached on the instance.
    """

    def __init__(self, scope: Scope, receive: Receive) -> None:
        # ``receive`` is part of the signature but is not stored by this base
        # class; only the scope is kept.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # -- mapping protocol over the raw scope ---------------------------------

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Identity-based equality and hashing, taken straight from ``object``.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    # -- application references ----------------------------------------------

    @property
    def app(self) -> typing.Any:
        """The object stored under ``scope["app"]``."""
        return self.scope["app"]

    @property
    def base_app(self) -> "NexiosApp":  # type: ignore
        """The object stored under ``scope["base_app"]``."""
        return self.scope["base_app"]

    # -- URL related views ---------------------------------------------------

    @property
    def url(self) -> URL:
        """`URL` built from the scope; created once and cached."""
        if hasattr(self, "_url"):  # pragma: no branch
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        `URL` of the application root; created once and cached.

        Built from a copy of the scope whose path is the root path (taken
        from ``app_root_path``, falling back to ``root_path``, then ``""``)
        with a trailing slash, and whose query string is empty.
        """
        if hasattr(self, "_base_url"):
            return self._base_url

        root_scope = dict(self.scope)
        root_path = root_scope.get("app_root_path", root_scope.get("root_path", ""))
        root_scope["path"] = root_path if root_path.endswith("/") else root_path + "/"
        root_scope["query_string"] = b""
        root_scope["root_path"] = root_path

        self._base_url = URL(scope=root_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """`Headers` built from the scope; created once and cached."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def path(self) -> str:
        """Path component of :attr:`url`."""
        return self.url.path

    @property
    def query_params(self) -> QueryParams:
        """`QueryParams` built from ``scope["query_string"]``; cached."""
        if hasattr(self, "_query_params"):  # pragma: no branch
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, typing.Any]:
        """Value of ``scope["route_params"]``, or an empty dict when absent."""
        return self.scope.get("route_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies parsed from the ``cookie`` header; ``{}`` when there is none."""
        if hasattr(self, "_cookies"):
            return self._cookies
        raw_cookie = self.headers.get("cookie")
        self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> typing.Union[Address, None]:
        """`Address` built from ``scope["client"]``, or ``None`` when missing."""
        peer = self.scope.get("client")
        return None if peer is None else Address(*peer)

    @property
    def state(self) -> State:
        """
        `State` wrapper around ``scope["state"]``; created once and cached.

        The ``"state"`` entry is created as an empty dict if the scope does
        not have one yet, and that same dict is handed to `State`.
        """
        if hasattr(self, "_state"):
            return self._state
        self._state = State(self.scope.setdefault("state", {}))
        return self._state

    # -- convenience header accessors ----------------------------------------

    @property
    def origin(self):
        """Value of the ``Origin`` header as returned by ``headers.get``."""
        return self.headers.get("Origin")

    @property
    def user_agent(self) -> str:
        """Value of the ``user-agent`` header, or ``""`` when it is absent."""
        return self.headers.get("user-agent", "")

    def build_absolute_uri(
        self, path: str = "", query_params: typing.Optional[dict[str, str]] = None
    ) -> str:
        """
        Join ``path`` onto the base URL and optionally add a query string.

        :param path: Path to append; a leading ``/`` is optional.
        :param query_params: Mapping encoded with ``urllib.parse.urlencode``
            and appended after ``?`` when non-empty.
        :return: The resulting absolute URI as a string.
        """
        root = str(self.base_url).rstrip("/")
        joiner = "" if path.startswith("/") else "/"
        uri = f"{root}{joiner}{path}"

        if query_params:
            from urllib.parse import urlencode

            uri = f"{uri}?{urlencode(query_params)}"

        return uri

207 

208 

async def empty_receive() -> typing.NoReturn:
    """Default ``receive`` for `Request`; always raises because no channel was given."""
    message = "Receive channel has not been made available"
    raise RuntimeError(message)

211 

212 

async def empty_send(message: Message) -> typing.NoReturn:
    """Default ``send`` for `Request`; always raises because no channel was given."""
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)

215 

216 

class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Adds body access (raw stream, bytes, text, JSON, forms, files) and a few
    scope-backed helpers (method, session, user, ``url_for``) on top of
    `HTTPConnection`. Parsed representations are cached on the instance.
    """

    # Parsed form payload: ``None`` until parsed; a plain dict is stored when
    # multipart parsing fails (see ``_get_form``).
    _form: typing.Union[FormData, None, typing.Dict[str, typing.Any]]  # type: ignore

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        """
        :param scope: Connection scope; ``scope["type"]`` must be ``"http"``.
        :param receive: Awaitable callable producing incoming messages.
        :param send: Awaitable callable delivering outgoing messages.
        """
        super().__init__(scope, receive)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None  # type: ignore

    @property
    def method(self) -> str:
        """HTTP method taken from ``scope["method"]``."""
        return self.scope["method"]

    @property
    def receive(self):
        """The ``receive`` callable this request was created with."""
        return self._receive

    @property
    def content_type(self) -> typing.Optional[str]:
        """
        Main value of the ``Content-Type`` header, without its parameters,
        as returned by ``parse_options_header``.

        NOTE(review): ``_get_form`` compares the same parsed value against
        bytes literals, so this presumably yields ``bytes`` despite the
        ``str`` annotation - confirm before relying on the type.
        NOTE(review): if ``python-multipart`` is not installed,
        ``parse_options_header`` is ``None`` and this call fails.
        """
        content_type_header = self.headers.get("Content-Type")
        content_type: str
        content_type, _ = parse_options_header(content_type_header)  # type:ignore
        return content_type  # type:ignore

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, followed by a final ``b""``.

        If the body was already read through :meth:`body`, the cached bytes
        are yielded instead of reading from the receive channel.

        :raises RuntimeError: if the stream was already consumed and no
            cached body is available.
        :raises ClientDisconnect: if an "http.disconnect" message arrives.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                # The last chunk is the one without ``more_body`` set.
                if not message.get("more_body", False):
                    self._stream_consumed = True
                if body:
                    yield body
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Read the whole body via :meth:`stream`, cache it, and return it."""
        if not hasattr(self, "_body"):
            chunks: list[bytes] = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    @property
    async def json(self) -> typing.Union[JSONType, dict[str, typing.Any]]:
        """
        Body decoded as JSON (use as ``await request.json``).

        Returns ``{}`` when the body cannot be decoded as text or is not
        valid JSON. A successfully parsed (or invalid-JSON) result is cached.
        """
        if not hasattr(self, "_json"):
            _body = await self.body()
            try:
                body = _body.decode()
            except UnicodeDecodeError:
                return {}
            try:
                self._json: JSONType = json.loads(body)
            except json.JSONDecodeError:
                self._json = {}
        return self._json

    @property
    async def text(self) -> str:
        """
        Body decoded as text (use as ``await request.text``).

        Decodes as UTF-8 and falls back to latin-1 when that fails. The
        result is cached.

        Returns:
            str: The decoded text content of the request body.
        """
        if not hasattr(self, "_text"):
            body = await self.body()
            try:
                self._text = body.decode("utf-8")
            except UnicodeDecodeError:
                self._text = body.decode("latin-1")
        return self._text

    async def _get_form(
        self,
        *,
        max_files: typing.Optional[int] = 1000,
        max_fields: typing.Optional[int] = 1000,
    ) -> FormData:
        """
        Parse the body as form data once and cache the result in ``_form``.

        ``multipart/form-data`` goes through `MultiPartParser`,
        ``application/x-www-form-urlencoded`` through `FormParser`; any other
        content type yields an empty `FormData`.

        :param max_files: Passed to `MultiPartParser`.
        :param max_fields: Passed to `MultiPartParser`.
        :return: The cached parse result.
        """
        if self._form is None:  # type:ignore
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)  # type:ignore
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException:
                    # Best effort: a malformed multipart body results in an
                    # empty plain dict rather than an error.
                    self._form = {}  # type: ignore
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form  # type:ignore

    @property
    def form_data(
        self,
        *,
        max_files: typing.Optional[int] = 1000,
        max_fields: typing.Optional[int] = 1000,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Parsed form data, wrapped in `AwaitableOrContextManagerWrapper`.

        Because this is a property, the keyword parameters can never be
        supplied by a caller; the defaults shown are always used.
        """
        return AwaitableOrContextManagerWrapper(
            self._get_form(max_files=max_files, max_fields=max_fields)
        )

    async def close(self) -> None:
        """
        Close the parsed form, if there is one and it can be closed.

        ``_form`` may be ``None`` (never parsed) or a plain dict (multipart
        parsing failed); neither has ``close()``, so both are skipped instead
        of raising ``AttributeError``.
        """
        closer = getattr(self._form, "close", None)
        if closer is not None:
            await closer()

    async def is_disconnected(self) -> bool:
        """
        Report whether an "http.disconnect" message has been seen.

        Polls the receive channel once without waiting; the result is sticky
        once a disconnect has been observed.
        """
        if not self._is_disconnected:
            message: typing.Dict[str, typing.Any] = {}

            # If message isn't immediately available, move on
            with anyio.CancelScope() as cs:  # type: ignore
                cs.cancel()  # type: ignore
                message = await self._receive()  # type:ignore

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an "http.response.push" message for ``path``.

        Does nothing unless ``"http.response.push"`` is listed in the scope's
        ``extensions``. The headers named in ``SERVER_PUSH_HEADERS_TO_COPY``
        are copied from this request into the message.
        """
        if "http.response.push" in self.scope.get("extensions", {}):
            raw_headers: list[tuple[bytes, bytes]] = []
            for name in SERVER_PUSH_HEADERS_TO_COPY:
                for value in self.headers.getlist(name):
                    raw_headers.append(
                        (name.encode("latin-1"), value.encode("latin-1"))
                    )
            await self._send(
                {"type": "http.response.push", "path": path, "headers": raw_headers}
            )

    @property
    async def files(self) -> typing.Dict[str, typing.Any]:
        """
        Uploaded files found in the form data (use as ``await request.files``).

        An entry counts as a file when it has a ``filename`` attribute. For a
        list/tuple value, the last item with a ``filename`` is the one kept.
        """
        form_data: FormData = await self.form_data
        files_dict: typing.Dict[str, typing.Any] = {}
        for key, value in form_data.items():
            if isinstance(value, (list, tuple)):
                for item in value:  # type: ignore
                    if hasattr(item, "filename"):  # type: ignore
                        files_dict[key] = item
            elif hasattr(value, "filename"):
                files_dict[key] = value
        return files_dict

    @property
    async def form(self) -> FormData:
        """
        Parse and return form data from the request body
        (use as ``await request.form``).

        Handles both URL-encoded and multipart form data by delegating to
        :attr:`form_data`, which caches its result in ``_form``.
        """
        # ``_form`` always exists (``__init__`` sets it to ``None``), so the
        # "not parsed yet" state must be detected by value, not by ``hasattr``.
        if self._form is None:
            self._form = await self.form_data
        return self._form

    def valid(self) -> bool:
        """
        Return ``True`` when the method is one of GET, POST, PUT, DELETE,
        PATCH, HEAD or OPTIONS and the headers object is truthy.
        """
        return self.method in {
            "GET",
            "POST",
            "PUT",
            "DELETE",
            "PATCH",
            "HEAD",
            "OPTIONS",
        } and bool(self.headers)

    @property
    def session(self) -> BaseSessionInterface:
        """Session stored in ``scope["session"]``; asserts that the key exists."""
        assert "session" in self.scope.keys(), "No Session Middleware Installed"
        return self.scope["session"]

    @property
    def user(self):
        """Value of ``scope["user"]``, or ``None`` when it has not been set."""
        return self.scope.get("user", None)

    @user.setter
    def user(self, value: str):
        self.scope["user"] = value

    def url_for(self, _name: str, **path_params: typing.Dict[str, typing.Any]) -> str:
        """Delegate to ``base_app.url_for`` with the given name and parameters."""
        return self.base_app.url_for(_name, **path_params)

    def __str__(self) -> str:
        return f"<Request {self.method} {self.url}>"