Coverage for nexios\middlewares\cors.py: 88%

113 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1import re 

2 

3# from typing_extensions import Annotated, Doc 

4from nexios.middlewares.base import BaseMiddleware 

5from nexios.http import Request, Response 

6from nexios.config import get_config 

7from typing import Callable, Optional, List, Dict, Any 

8import typing 

9from nexios.logging import getLogger 

10 

11 

logger = getLogger()

# Lower-case HTTP methods; used as ``allow_methods`` when the config sets none.
ALL_METHODS = tuple("delete get head options patch post put".split())

# CORS-safelisted request header names in their conventional capitalisation...
BASIC_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
# ...and the same names lower-cased, for case-insensitive comparisons.
SAFELISTED_HEADERS = {name.lower() for name in BASIC_HEADERS}

17 

18 

class CORSMiddleware(BaseMiddleware):
    """Apply the CORS policy from ``get_config().cors`` to incoming requests.

    Preflight requests (``OPTIONS`` carrying ``Access-Control-Request-Method``)
    are answered directly by this middleware. All other requests are passed on
    via ``call_next``. Afterwards, if their origin is allowed, they are
    decorated with ``Access-Control-*`` response headers.
    """

    def __init__(self):
        config = get_config().cors

        # NOTE(review): without a CORS config none of the attributes below are
        # set; process_request re-checks the config and passes requests
        # straight through in that case.
        if not config:
            return None
        self.allow_origins: List[str] = config.allow_origins or []
        self.blacklist_origins: List[str] = config.blacklist_origins or []
        self.allow_methods = config.allow_methods or ALL_METHODS
        self.blacklist_headers: List[str] = config.blacklist_headers or []
        # NOTE(review): credentials default to allowed; combined with "*" in
        # allow_origins this reflects *any* origin with credentials — confirm
        # this default is intended.
        self.allow_credentials = (
            config.allow_credentials if config.allow_credentials is not None else True
        )
        self.allow_origin_regex = (
            re.compile(config.allow_origin_regex) if config.allow_origin_regex else None
        )
        self.expose_headers: List[str] = config.expose_headers or []
        self.max_age = config.max_age or 600
        self.strict_origin_checking = config.strict_origin_checking or False
        self.dynamic_origin_validator: Optional[Callable[[Optional[str]], bool]] = (
            getattr(config, "dynamic_origin_validator", None)
        )
        self.debug = config.debug or False
        self.custom_error_status = config.custom_error_status or 400
        self.custom_error_messages = getattr(config, "custom_error_messages", {}) or {}

        # NOTE(review): simple_headers is not read by any method of this class;
        # presumably kept for external consumers — verify before removing.
        self.simple_headers: Dict[str, Any] = {}
        if self.allow_credentials:
            self.simple_headers["Access-Control-Allow-Credentials"] = "true"
        if self.expose_headers:
            self.simple_headers["Access-Control-Expose-Headers"] = ", ".join(
                self.expose_headers
            )

        # Static headers attached to every successful preflight response.
        self.preflight_headers = {
            "Access-Control-Allow-Methods": ", ".join(
                [x.upper() for x in self.allow_methods]
            ),
            "Access-Control-Max-Age": str(self.max_age),
        }
        if self.allow_credentials:
            self.preflight_headers["Access-Control-Allow-Credentials"] = "true"
        # Safelisted headers are always permitted; configured ones are added.
        if config.allow_headers:
            self.allow_headers: List[str] = [
                *list(SAFELISTED_HEADERS),
                *config.allow_headers,
            ]
        else:
            self.allow_headers = list(SAFELISTED_HEADERS)

    async def process_request(
        self,
        request: Request,
        response: Response,
        call_next: typing.Callable[..., typing.Awaitable[Any]],
    ):
        """Route a request to the preflight or simple-request CORS handler.

        A request without an ``Origin`` header is rejected with the configured
        error response when ``strict_origin_checking`` is enabled. Otherwise it
        is passed through untouched.
        """
        config = get_config().cors

        if not config:
            await call_next()
            return None

        origin = request.origin
        if not origin:
            # Must run before the pass-through: a missing Origin is exactly
            # the case strict origin checking exists to reject.
            if self.strict_origin_checking:
                return self._reject(
                    response,
                    "missing_origin",
                    "Request denied: Missing 'Origin' header.",
                )
            return await call_next()

        method = request.scope["method"]
        if (
            method.lower() == "options"
            and "access-control-request-method" in request.headers
        ):
            return await self.preflight_response(request, response)
        await self.simple_response(request, response, call_next)

    async def simple_response(
        self,
        request: Request,
        response: Response,
        call_next: typing.Callable[..., typing.Awaitable[Any]],
    ):
        """Run the downstream handler, then add CORS headers for allowed origins.

        Disallowed origins receive no CORS headers. The response itself is
        left as produced by ``call_next``.
        """
        config = get_config().cors
        await call_next()
        if not config:
            return None
        origin = request.origin

        if origin and self.is_allowed_origin(origin):
            response.set_header("Access-Control-Allow-Origin", origin, overide=True)

            if self.allow_credentials:
                response.set_header(
                    "Access-Control-Allow-Credentials", "true", overide=True
                )

            if self.expose_headers:
                response.set_header(
                    "Access-Control-Expose-Headers",
                    ", ".join(self.expose_headers),
                    overide=True,
                )

    def is_allowed_origin(self, origin: Optional[str]) -> bool:
        """Return whether *origin* may receive CORS headers.

        Precedence: ``blacklist_origins`` (deny) > ``"*"`` in
        ``allow_origins`` > explicit ``allow_origins`` entry >
        ``allow_origin_regex`` full match > ``dynamic_origin_validator``.
        """
        if origin in self.blacklist_origins:
            if self.debug:
                logger.error(f"Request denied: Origin '{origin}' is blacklisted.")
            return False

        if "*" in self.allow_origins:
            return True

        # Consult the static list before the dynamic validator so an
        # explicitly allowed origin cannot be vetoed by it.
        if origin in self.allow_origins:
            return True

        # fullmatch() raises TypeError on None, so guard it.
        if (
            origin is not None
            and self.allow_origin_regex
            and self.allow_origin_regex.fullmatch(origin)
        ):
            return True

        if self.dynamic_origin_validator and callable(self.dynamic_origin_validator):
            return bool(self.dynamic_origin_validator(origin))

        return False

    def is_allowed_method(self, method: Optional[str]) -> bool:
        """Return whether *method* (case-insensitive) is in ``allow_methods``."""
        if "*" in self.allow_methods:
            return True
        return (method or "").lower() in {m.lower() for m in self.allow_methods}

    def _reject(self, response: Response, error_type: str, log_message: str) -> Any:
        """Log *log_message* in debug mode and return the CORS error response."""
        if self.debug:
            logger.error(log_message)
        return response.json(
            self.get_error_message(error_type),
            status_code=self.custom_error_status,
        )

    async def preflight_response(self, request: Request, response: Response) -> Any:
        """Answer a CORS preflight request without calling the downstream handler.

        The checks run in order:

        1. the ``Origin`` header,
        2. ``Access-Control-Request-Method``,
        3. every name listed in ``Access-Control-Request-Headers``.

        The first failure returns the configured error response. On success
        the response is ``"OK"`` with status 201 and the preflight
        ``Access-Control-*`` headers.
        """
        origin = request.headers.get("origin")
        requested_method = request.headers.get("access-control-request-method")
        requested_headers = request.headers.get("access-control-request-headers")

        headers = self.preflight_headers.copy()

        if not self.is_allowed_origin(origin):
            return self._reject(
                response,
                "disallowed_origin",
                f"Preflight request denied: Origin '{origin}' is not allowed.",
            )

        headers["Access-Control-Allow-Origin"] = origin  # type:ignore

        if not self.is_allowed_method(requested_method):
            return self._reject(
                response,
                "disallowed_method",
                f"Preflight request denied: Method '{requested_method}' is not allowed.",
            )

        if requested_headers:
            # Skip empty entries, e.g. from a trailing comma.
            requested_header_list = [
                h.strip().lower() for h in requested_headers.split(",") if h.strip()
            ]
            # Header names are case-insensitive: normalise both lists once.
            allowed = {h.lower() for h in self.allow_headers}
            blocked = {h.lower() for h in self.blacklist_headers}
            allow_any = "*" in self.allow_headers
            for header in requested_header_list:
                # The blacklist wins even when every header is allowed.
                if header in blocked or not (allow_any or header in allowed):
                    return self._reject(
                        response,
                        "disallowed_header",
                        f"Preflight request denied: Header '{header}' is not allowed.",
                    )
            headers["Access-Control-Allow-Headers"] = (
                "*" if allow_any else requested_headers
            )
        return response.json("OK", status_code=201, headers=headers)

    def get_error_message(self, error_type: str) -> str:
        """Return the custom message for *error_type*, or a generic default."""
        return self.custom_error_messages.get(error_type, "CORS request denied.")