Coverage for nexios\http\formparsers.py: 82%

185 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1from __future__ import annotations 

2 

3import typing 

4import urllib.parse 

5from dataclasses import dataclass, field 

6from enum import Enum 

7from tempfile import SpooledTemporaryFile 

8 

9 

10from nexios.structs import FormData, Headers, UploadedFile 

11 

12if typing.TYPE_CHECKING: 

13 import multipart # type:ignore 

14 from multipart.multipart import ( # type:ignore 

15 parse_options_header, 

16 ) 

17else: 

18 try: 

19 try: 

20 import python_multipart as multipart 

21 from python_multipart.multipart import parse_options_header 

22 except ModuleNotFoundError: 

23 import multipart 

24 from multipart.multipart import parse_options_header 

25 except ModuleNotFoundError: 

26 multipart = None 

27 parse_options_header = None 

28 

29 

class FormMessage(Enum):
    """Event kinds that ``FormParser``'s ``on_*`` callbacks record in
    ``FormParser.messages``."""

    FIELD_START = 1  # appended by on_field_start(), with an empty payload
    FIELD_NAME = 2  # appended by on_field_name(), payload is a slice of the name
    FIELD_DATA = 3  # appended by on_field_data(), payload is a slice of the value
    FIELD_END = 4  # appended by on_field_end(), with an empty payload
    END = 5  # appended by on_end(), with an empty payload

36 

37 

@dataclass
class MultipartPart:
    """Mutable per-part state filled in while ``MultiPartParser`` walks a body."""

    # Raw value of the part's Content-Disposition header, if one was seen.
    content_disposition: typing.Optional[bytes] = None
    # Decoded ``name`` option taken from Content-Disposition.
    field_name: str = ""
    # In-memory body bytes, used only for non-file fields.
    data: bytearray = field(default_factory=bytearray)
    # Upload target when the part declares a ``filename``; None for plain fields.
    file: typing.Optional[UploadedFile] = None
    # Every header of the part as (lower-cased name, raw value) pairs.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)

45 

46 

47def _user_safe_decode(src: typing.Union[bytes, bytearray], codec: str) -> str: 

48 try: 

49 return src.decode(codec) 

50 except (UnicodeDecodeError, LookupError): 

51 return src.decode("latin-1") 

52 

53 

class MultiPartException(Exception):
    """Raised by ``MultiPartParser`` when a body exceeds a configured limit or
    a part lacks the required Content-Disposition ``name`` option.

    The reason is kept on ``message``; it is also the exception's sole arg.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message

57 

58 

class FormParser:
    """Parse a request body into ``FormData``.

    ``multipart/form-data`` bodies are delegated to ``MultiPartParser``; any
    other content type is treated as ``application/x-www-form-urlencoded``.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Filled by the on_* callbacks below.
        # NOTE(review): parse() does not register these callbacks with any
        # parser in this file. Confirm whether external code relies on them
        # before removing.
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Record the start of a URL-encoded field."""
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Record the slice ``data[start:end]`` of a field name."""
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Record the slice ``data[start:end]`` of a field value."""
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        """Record the end of a URL-encoded field."""
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Record the end of the body."""
        message = (FormMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        """
        Parse the request stream as form data.

        Returns:
            FormData: The parsed form data. An empty body yields an empty form.

        Raises:
            MultiPartException: propagated from ``MultiPartParser`` for
                multipart bodies that violate its limits.
        """
        content_type = self.headers.get("content-type", "")
        # Media type names are case-insensitive (RFC 9110 section 8.3.1).
        # MultiPartParser lowercases via parse_options_header, so match the same way.
        if content_type.lstrip().lower().startswith("multipart/form-data"):
            multipart_parser = MultiPartParser(self.headers, self.stream)
            return await multipart_parser.parse()

        # Default to application/x-www-form-urlencoded.
        form = FormData()
        # join() is linear; repeated ``bytes +=`` is quadratic in the body size.
        body = b"".join([chunk async for chunk in self.stream if chunk])
        if not body:
            return form

        try:
            text = body.decode("utf-8")
        except UnicodeDecodeError:
            # latin-1 maps every byte to a code point, so this cannot fail.
            text = body.decode("latin-1")

        # parse_qsl already percent-decodes keys and values and turns '+' into
        # a space. Decoding the values a second time would corrupt literal
        # '%XX' text that the client sent percent-encoded.
        for key, value in urllib.parse.parse_qsl(text, keep_blank_values=True):
            form.append(key, value)

        return form

137 

138 

class MultiPartParser:
    """Streaming ``multipart/form-data`` parser driven by ``python-multipart``.

    ``multipart.MultipartParser`` invokes the synchronous ``on_*`` callbacks
    below. Plain fields are buffered in memory. File chunks are queued, and
    ``parse()`` writes them to their ``UploadedFile`` with ``await``, so the
    callbacks themselves never perform file I/O.

    Limits are class-level defaults and can be overridden per instance through
    the keyword-only constructor arguments:

    - ``max_file_size``: maximum bytes for one file part.
    - ``max_part_size``: maximum bytes for one non-file field.
    - ``max_fields``: maximum number of non-file fields.
    - ``max_files``: maximum number of file parts.
    """

    max_file_size = 1024 * 1024  # 1MB
    max_part_size = 1024 * 1024  # 1MB
    max_fields = 1000
    max_files = 1000

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_fields: typing.Optional[int] = None,
        max_files: typing.Optional[int] = None,
        max_part_size: typing.Optional[int] = None,
        max_file_size: typing.Optional[int] = None,
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files if max_files is not None else self.max_files
        self.max_fields = max_fields if max_fields is not None else self.max_fields
        self.max_part_size = (
            max_part_size if max_part_size is not None else self.max_part_size
        )
        self.max_file_size = (
            max_file_size if max_file_size is not None else self.max_file_size
        )
        self.items: list[tuple[str, typing.Union[str, UploadedFile]]] = []
        self._current_files = 0
        self._current_fields = 0
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        # Bytes received so far for the current file part. They are counted
        # when the callback fires, because the actual writes (and any size
        # bookkeeping inside UploadedFile) are deferred until parse() flushes.
        self._current_file_size = 0
        self._charset = ""
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        """Start collecting a fresh part."""
        self._current_part = MultipartPart()
        self._current_file_size = 0

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Buffer field bytes, or queue file bytes, enforcing size limits.

        Raises:
            MultiPartException: if a field exceeds ``max_part_size`` or a file
                exceeds ``max_file_size``.
        """
        message_bytes = data[start:end]
        part = self._current_part
        if part.file is None:
            if len(part.data) + len(message_bytes) > self.max_part_size:
                raise MultiPartException(
                    f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB."
                )
            part.data.extend(message_bytes)
            return

        # Count queued-but-unwritten bytes too. The file's own size only
        # reflects data already flushed by parse().
        self._current_file_size += len(message_bytes)
        if self._current_file_size > self.max_file_size:
            raise MultiPartException(
                f"File too large. Maximum size is {self.max_file_size} bytes"
            )
        self._file_parts_to_write.append((part, message_bytes))

    def on_part_end(self) -> None:
        """Publish the finished part to ``self.items``."""
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(
                        self._current_part.data, self._charset  # type: ignore
                    ),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method,
            # before self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a (possibly split) header name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a (possibly split) header value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset buffers."""
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the part as a field or a file and enforce count limits.

        Raises:
            MultiPartException: if the ``name`` option is missing, or the
                file/field count limit is exceeded.
        """
        _, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], self._charset  # type: ignore
            )
        except KeyError:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            ) from None
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(
                options[b"filename"], self._charset  # type: ignore
            )
            tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadedFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        """End-of-body callback; nothing to do."""
        pass

    async def _flush_file_parts(self) -> None:
        """Write queued file chunks, then rewind files whose part has ended."""
        # Writes are done here with ``await`` rather than in the synchronous
        # callbacks, so the UploadedFile methods are awaited instead of being
        # called as plain blocking calls from parser.write().
        for part, data in self._file_parts_to_write:
            assert part.file is not None  # for type checkers
            await part.file.write(data)
        for part in self._file_parts_to_finish:
            assert part.file is not None  # for type checkers
            await part.file.seek(0)
        self._file_parts_to_write.clear()
        self._file_parts_to_finish.clear()

    async def parse(self) -> FormData:
        """Parse the form data from the request body.

        Returns:
            FormData: the parsed fields and files. It is empty if the content
            type is not ``multipart/form-data`` or no boundary is given.

        Raises:
            MultiPartException: on limit violations or a missing part name.
                On this or any other error, every temp file created so far is
                closed before the exception propagates unchanged.
        """
        content_type = self.headers.get("content-type", "")
        content_type, params = parse_options_header(content_type)

        if content_type != b"multipart/form-data":
            return FormData()

        boundary = params.get(b"boundary")
        if not boundary:
            return FormData()

        charset = params.get(b"charset")
        self._charset = charset.decode("latin-1") if charset else "utf-8"

        callbacks: typing.Dict[str, typing.Callable[..., typing.Any]] = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        parser = multipart.MultipartParser(boundary, callbacks)  # type:ignore
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)  # type:ignore
                await self._flush_file_parts()
            parser.finalize()  # type:ignore
            # Flush again in case finalize() triggered any callbacks.
            await self._flush_file_parts()
        except BaseException:
            # Close the temp files on any failure, including parser errors,
            # stream errors and cancellation, not only MultiPartException.
            # The original exception is re-raised unchanged.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)