Coverage for nexios\http\formparsers.py: 82%
185 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1from __future__ import annotations
3import typing
4import urllib.parse
5from dataclasses import dataclass, field
6from enum import Enum
7from tempfile import SpooledTemporaryFile
10from nexios.structs import FormData, Headers, UploadedFile
12if typing.TYPE_CHECKING:
13 import multipart # type:ignore
14 from multipart.multipart import ( # type:ignore
15 parse_options_header,
16 )
17else:
18 try:
19 try:
20 import python_multipart as multipart
21 from python_multipart.multipart import parse_options_header
22 except ModuleNotFoundError:
23 import multipart
24 from multipart.multipart import parse_options_header
25 except ModuleNotFoundError:
26 multipart = None
27 parse_options_header = None
class FormMessage(Enum):
    """Event kinds recorded by the ``FormParser`` field callbacks.

    Each callback appends a ``(FormMessage, bytes)`` tuple to the parser's
    message list; the member identifies which callback fired.
    """

    # Appended by ``on_field_start`` with an empty payload.
    FIELD_START = 1
    # Appended by ``on_field_name`` with a slice of the field name.
    FIELD_NAME = 2
    # Appended by ``on_field_data`` with a slice of the field value.
    FIELD_DATA = 3
    # Appended by ``on_field_end`` with an empty payload.
    FIELD_END = 4
    # Appended by ``on_end`` with an empty payload.
    END = 5
@dataclass
class MultipartPart:
    """Mutable per-part state collected while a multipart body is parsed."""

    # Raw value of the part's Content-Disposition header; None until seen.
    content_disposition: typing.Optional[bytes] = None

    # Decoded ``name`` option taken from the Content-Disposition header.
    field_name: str = ""

    # Body bytes buffered in memory for parts that are not files.
    data: bytearray = field(default_factory=bytearray)

    # Upload wrapper for parts that carry a ``filename``; None otherwise.
    file: typing.Optional[UploadedFile] = None

    # Every (lower-cased header name, header value) pair seen for this part.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
def _user_safe_decode(src: typing.Union[bytes, bytearray], codec: str) -> str:
    """Decode ``src`` using ``codec``, retrying with latin-1 on failure.

    The retry happens both when the bytes are invalid for ``codec``
    (``UnicodeDecodeError``) and when ``codec`` itself is unknown
    (``LookupError``).
    """
    try:
        text = src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        # latin-1 assigns a character to every byte value, so this succeeds.
        text = src.decode("latin-1")
    return text
class MultiPartException(Exception):
    """Error raised while parsing a multipart body.

    Used for a missing ``name`` in Content-Disposition and for exceeded
    size or count limits.

    Attributes:
        message: Human-readable description of the failure.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to ``Exception`` so ``str(exc)``, ``exc.args``
        # and tracebacks carry it; previously they were empty.
        super().__init__(message)
        self.message = message
class FormParser:
    """Parse a request body as form data.

    Bodies whose content type starts with ``multipart/form-data`` are handed
    to ``MultiPartParser``; everything else is treated as
    ``application/x-www-form-urlencoded``.

    The ``on_*`` callbacks append ``(FormMessage, bytes)`` tuples to
    ``self.messages``. ``parse()`` in this class does not read that list; the
    callbacks are kept so existing users of them keep working.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Store the request headers and the async body stream.

        Raises:
            AssertionError: If neither ``python_multipart`` nor ``multipart``
                could be imported.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if multipart is None:
            raise AssertionError(
                "The `python-multipart` library must be installed to use form parsing."
            )
        self.headers = headers
        self.stream = stream
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Record that a field started."""
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Record the ``data[start:end]`` slice of a field name."""
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Record the ``data[start:end]`` slice of a field value."""
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        """Record that the current field ended."""
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        """Record that the whole body ended."""
        self.messages.append((FormMessage.END, b""))

    @staticmethod
    def _decode_urlencoded(content: bytes) -> list[tuple[str, str]]:
        """Turn a urlencoded body into decoded ``(key, value)`` pairs.

        The body is read as UTF-8, falling back to latin-1 when it is not
        valid UTF-8. Blank values are kept.
        """
        try:
            text = content.decode("utf-8")
        except UnicodeDecodeError:
            text = content.decode("latin-1")
        # ``parse_qsl`` already percent-decodes keys and values and maps "+"
        # to a space. Decoding the values a second time would corrupt data
        # that legitimately contains "%", e.g. "100%2525" must stay "100%25".
        return urllib.parse.parse_qsl(text, keep_blank_values=True)

    async def parse(self) -> FormData:
        """
        Parse the request stream as form data.

        Returns:
            FormData: The parsed form data. Empty when the body is empty or
            cannot be parsed as urlencoded data.
        """
        content_type = self.headers.get("content-type", "")
        if content_type.startswith("multipart/form-data"):
            return await MultiPartParser(self.headers, self.stream).parse()

        # Default to application/x-www-form-urlencoded.
        form = FormData()

        # Join once instead of repeated ``bytes +=``, which copies the whole
        # buffer for every chunk.
        content = b"".join([chunk async for chunk in self.stream if chunk])
        if not content:
            return form

        try:
            pairs = self._decode_urlencoded(content)
        except ValueError:
            # Best effort, as before: an unparsable body yields an empty form.
            return form

        for key, value in pairs:
            form.append(key, value)
        return form
class MultiPartParser:
    """Streaming ``multipart/form-data`` parser built on ``python-multipart``.

    The ``on_*`` methods are synchronous callbacks passed to the underlying
    parser. They only queue file data; the queued writes are awaited inside
    ``parse()``.

    Class attributes (overridable per subclass or instance):
        max_file_size: Maximum number of bytes accepted for one file part;
            also used as the ``SpooledTemporaryFile`` in-memory threshold.
        max_part_size: Maximum number of bytes accepted for one non-file part.
        max_fields: Maximum number of non-file parts.
        max_files: Maximum number of file parts.
    """

    max_file_size = 1024 * 1024  # 1MB
    max_part_size = 1024 * 1024  # 1MB
    max_fields = 1000
    max_files = 1000

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_fields: typing.Optional[int] = None,
        max_files: typing.Optional[int] = None,
    ) -> None:
        """Store the request headers/stream and reset all parsing state.

        Args:
            headers: Request headers; ``content-type`` is read in ``parse()``.
            stream: Async iterator yielding the raw body chunks.
            max_fields: Overrides the class-level ``max_fields`` when given.
            max_files: Overrides the class-level ``max_files`` when given.

        Raises:
            AssertionError: If neither ``python_multipart`` nor ``multipart``
                could be imported.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if multipart is None:
            raise AssertionError(
                "The `python-multipart` library must be installed to use form parsing."
            )
        self.headers = headers
        self.stream = stream
        self.max_files = max_files if max_files is not None else self.max_files
        self.max_fields = max_fields if max_fields is not None else self.max_fields
        self.items: list[tuple[str, typing.Union[str, UploadedFile]]] = []
        self._current_files = 0
        self._current_fields = 0
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        # Bytes accepted so far for the file part currently being read,
        # including data that is queued but not yet written.
        self._current_file_size = 0
        self._charset = ""
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        """Start collecting a new part."""
        self._current_part = MultipartPart()
        self._current_file_size = 0

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Buffer field data or queue file data from ``data[start:end]``.

        Raises:
            MultiPartException: If the part or file would exceed its limit.
        """
        message_bytes = data[start:end]
        part = self._current_part

        if part.file is None:
            if len(part.data) + len(message_bytes) > self.max_part_size:
                raise MultiPartException(
                    f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB."
                )
            part.data.extend(message_bytes)
            return

        # Count the bytes here rather than reading ``file.size``: the file is
        # only written after the whole chunk has been fed to the parser, so a
        # size read from the file would miss data queued earlier in the chunk.
        new_size = self._current_file_size + len(message_bytes)
        if new_size > self.max_file_size:
            raise MultiPartException(
                f"File too large. Maximum size is {self.max_file_size} bytes"
            )
        self._current_file_size = new_size
        self._file_parts_to_write.append((part, message_bytes))

    def on_part_end(self) -> None:
        """Add the finished part to ``self.items``."""
        part = self._current_part
        if part.file is None:
            self.items.append(
                (
                    part.field_name,
                    _user_safe_decode(part.data, self._charset),  # type: ignore
                )
            )
            return

        self._file_parts_to_finish.append(part)
        # The file can be added to the items right now even though it is not
        # finished yet: it is finished in ``parse()`` before ``self.items`` is
        # used in the return value.
        self.items.append((part.field_name, part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a slice of the current header name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a slice of the current header value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset buffers."""
        name = self._current_partial_header_name.lower()
        value = self._current_partial_header_value
        if name == b"content-disposition":
            self._current_part.content_disposition = value
        self._current_part.item_headers.append((name, value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a file or a plain field.

        Raises:
            MultiPartException: If Content-Disposition has no ``name`` option,
                or if the file/field count limit is exceeded.
        """
        part = self._current_part
        _, options = parse_options_header(part.content_disposition)

        try:
            raw_name = options[b"name"]
        except KeyError as exc:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            ) from exc
        part.field_name = _user_safe_decode(raw_name, self._charset)  # type: ignore

        if b"filename" not in options:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            part.file = None
            return

        self._current_files += 1
        if self._current_files > self.max_files:
            raise MultiPartException(
                f"Too many files. Maximum number of files is {self.max_files}."
            )
        filename = _user_safe_decode(
            options[b"filename"], self._charset  # type: ignore
        )
        tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
        # Remembered so ``parse()`` can close it if parsing fails.
        self._files_to_close_on_error.append(tempfile)
        self._current_file_size = 0
        part.file = UploadedFile(
            file=tempfile,  # type: ignore[arg-type]
            size=0,
            filename=filename,
            headers=Headers(raw=part.item_headers),
        )

    def on_end(self) -> None:
        """Callback for the end of the body; nothing to do."""
        pass

    async def _write_pending_file_data(self) -> None:
        """Await the file writes and rewinds queued by the sync callbacks."""
        # NOTE(review): the original comment states that the UploadedFile
        # methods run the file calls in a threadpool, which is why the I/O is
        # awaited here instead of being done in the callbacks. Not verifiable
        # from this file.
        for part, data in self._file_parts_to_write:
            await part.file.write(data)  # type: ignore
        for part in self._file_parts_to_finish:
            assert part.file  # for type checkers
            await part.file.seek(0)
        self._file_parts_to_write.clear()
        self._file_parts_to_finish.clear()

    async def parse(self) -> FormData:
        """Parse the form data from the request body.

        Returns:
            FormData: The collected items, or an empty ``FormData`` when the
            content type is not ``multipart/form-data`` or has no boundary.

        Raises:
            MultiPartException: Raised by the callbacks on invalid parts or
                exceeded limits. Any other exception from the parser or the
                stream is propagated unchanged.
        """
        header_value = self.headers.get("content-type", "")
        content_type, params = parse_options_header(header_value)

        if content_type != b"multipart/form-data":
            return FormData()

        boundary = params.get(b"boundary")
        if not boundary:
            return FormData()

        charset = params.get(b"charset")
        self._charset = charset.decode("latin-1") if charset else "utf-8"

        callbacks: typing.Dict[str, typing.Callable[..., typing.Any]] = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        parser = multipart.MultipartParser(boundary, callbacks)  # type: ignore
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)  # type: ignore
                await self._write_pending_file_data()
            parser.finalize()  # type: ignore
        except BaseException:
            # Close the temporary files on any failure, not only on
            # MultiPartException, so that parser errors, stream errors and
            # cancellation do not leak them. The exception is re-raised as is.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)