Coverage for nexios\middlewares\cors.py: 88%
113 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1import re
3# from typing_extensions import Annotated, Doc
4from nexios.middlewares.base import BaseMiddleware
5from nexios.http import Request, Response
6from nexios.config import get_config
7from typing import Callable, Optional, List, Dict, Any
8import typing
9from nexios.logging import getLogger
logger = getLogger()

# Methods allowed when the CORS config does not list any (see CORSMiddleware).
ALL_METHODS = ("delete", "get", "head", "options", "patch", "post", "put")

# CORS-safelisted request header names, in their canonical capitalisation.
BASIC_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}

# The same safelisted names in lower case; CORSMiddleware seeds its list of
# allowed request headers from this set.
SAFELISTED_HEADERS = {name.lower() for name in BASIC_HEADERS}
class CORSMiddleware(BaseMiddleware):
    """Enforce Cross-Origin Resource Sharing (CORS) rules.

    Settings are read from ``get_config().cors`` when the middleware is
    constructed. For each request, :meth:`process_request`:

    * calls straight through when no ``cors`` config is present at request time;
    * for requests without an ``Origin`` header, rejects them when
      ``strict_origin_checking`` is enabled and otherwise passes them through;
    * answers preflight requests (``OPTIONS`` carrying an
      ``Access-Control-Request-Method`` header) via :meth:`preflight_response`;
    * for every other request, runs the downstream handler and then adds CORS
      headers via :meth:`simple_response`.

    Rejections are JSON bodies produced by :meth:`get_error_message`, sent with
    status ``custom_error_status`` (default 400).
    """

    def __init__(self):
        config = get_config().cors

        # No ``cors`` section: leave the instance unconfigured. process_request
        # re-reads the config on every call and simply calls through then.
        if not config:
            return None
        self.allow_origins: List[str] = config.allow_origins or []
        self.blacklist_origins: List[str] = config.blacklist_origins or []
        self.allow_methods = config.allow_methods or ALL_METHODS
        self.blacklist_headers: List[str] = config.blacklist_headers or []
        # Credentials are allowed unless the config explicitly disables them.
        self.allow_credentials = (
            config.allow_credentials if config.allow_credentials is not None else True
        )
        self.allow_origin_regex = (
            re.compile(config.allow_origin_regex) if config.allow_origin_regex else None
        )
        self.expose_headers: List[str] = config.expose_headers or []
        # ``max_age=0`` is meaningful (no preflight caching), so only an unset
        # value falls back to the 600-second default; ``or`` would discard 0.
        self.max_age = config.max_age if config.max_age is not None else 600
        self.strict_origin_checking = config.strict_origin_checking or False
        self.dynamic_origin_validator: Optional[Callable[[Optional[str]], bool]] = (
            getattr(config, "dynamic_origin_validator", None)
        )
        self.debug = config.debug or False
        self.custom_error_status = config.custom_error_status or 400
        self.custom_error_messages = getattr(config, "custom_error_messages", {}) or {}

        # Precomputed headers for simple (non-preflight) responses. No method
        # in this class reads this dict; simple_response sets headers directly.
        self.simple_headers: Dict[str, Any] = {}
        if self.allow_credentials:
            self.simple_headers["Access-Control-Allow-Credentials"] = "true"
        if self.expose_headers:
            self.simple_headers["Access-Control-Expose-Headers"] = ", ".join(
                self.expose_headers
            )

        # Base headers copied into every successful preflight response.
        self.preflight_headers = {
            "Access-Control-Allow-Methods": ", ".join(
                [x.upper() for x in self.allow_methods]
            ),
            "Access-Control-Max-Age": str(self.max_age),
        }
        if self.allow_credentials:
            self.preflight_headers["Access-Control-Allow-Credentials"] = "true"
        # Safelisted headers are always allowed; configured ones are appended.
        if config.allow_headers:
            self.allow_headers: List[str] = [
                *list(SAFELISTED_HEADERS),
                *config.allow_headers,
            ]
        else:
            self.allow_headers = list(SAFELISTED_HEADERS)

    async def process_request(
        self,
        request: Request,
        response: Response,
        call_next: typing.Callable[..., typing.Awaitable[Any]],
    ):
        """Route a request through the CORS checks.

        Args:
            request: The incoming request.
            response: The response object that CORS headers or errors are written to.
            call_next: Invokes the downstream handler.

        Returns:
            The result of ``call_next`` for pass-through requests, the JSON
            error or preflight response when the middleware answers itself,
            and ``None`` after a simple (non-preflight) cross-origin request.
        """
        config = get_config().cors

        if not config:
            await call_next()
            return None

        origin = request.origin
        if not origin:
            # Strict mode must be checked *before* the pass-through below;
            # otherwise an origin-less request can never be rejected.
            if self.strict_origin_checking:
                if self.debug:
                    logger.error("Request denied: Missing 'Origin' header.")
                return response.json(
                    self.get_error_message("missing_origin"),
                    status_code=self.custom_error_status,
                )
            return await call_next()

        method = request.scope["method"]
        if (
            method.lower() == "options"
            and "access-control-request-method" in request.headers
        ):
            return await self.preflight_response(request, response)
        await self.simple_response(request, response, call_next)

    async def simple_response(
        self,
        request: Request,
        response: Response,
        call_next: typing.Callable[..., typing.Awaitable[Any]],
    ):
        """Run the downstream handler, then add CORS headers for allowed origins.

        The request's origin is reflected back in ``Access-Control-Allow-Origin``
        together with the credentials and expose-headers settings. Nothing is
        added when the origin is missing or not allowed.
        """
        config = get_config().cors
        await call_next()
        if not config:
            return None
        origin = request.origin

        if origin and self.is_allowed_origin(origin):
            response.set_header("Access-Control-Allow-Origin", origin, overide=True)

            if self.allow_credentials:
                response.set_header(
                    "Access-Control-Allow-Credentials", "true", overide=True
                )

            if self.expose_headers:
                response.set_header(
                    "Access-Control-Expose-Headers",
                    ", ".join(self.expose_headers),
                    overide=True,
                )

    def is_allowed_origin(self, origin: Optional[str]) -> bool:
        """Return whether ``origin`` may make cross-origin requests.

        Checks run in this order:

        1. A blacklisted origin is always refused.
        2. ``"*"`` in ``allow_origins`` allows everything else.
        3. An origin that fully matches ``allow_origin_regex`` is allowed.
        4. If a dynamic validator is configured, its verdict is final.
        5. Otherwise the origin must appear in ``allow_origins``.
        """
        if origin in self.blacklist_origins:
            if self.debug:
                logger.error(f"Request denied: Origin '{origin}' is blacklisted.")

            return False

        if "*" in self.allow_origins:
            return True

        # Guard against a missing origin: re.Pattern.fullmatch(None) raises TypeError.
        if (
            origin
            and self.allow_origin_regex
            and self.allow_origin_regex.fullmatch(origin)
        ):
            return True

        if self.dynamic_origin_validator and callable(self.dynamic_origin_validator):
            return self.dynamic_origin_validator(origin)

        return origin in self.allow_origins

    def is_allowed_method(self, method: Optional[str]) -> bool:
        """Return whether ``method`` is allowed, compared case-insensitively.

        ``"*"`` in ``allow_methods`` allows any method. A missing method is
        treated as the empty string.
        """
        if "*" in self.allow_methods:
            return True
        return (method or "").lower() in [x.lower() for x in self.allow_methods]

    async def preflight_response(self, request: Request, response: Response) -> Any:
        """Answer a CORS preflight request without calling the downstream handler.

        The request is rejected (JSON error, ``custom_error_status``) when the
        origin, the requested method, or any requested header is not allowed.
        Header names are compared case-insensitively, and blacklisted headers
        are refused even when ``"*"`` is among the allowed headers.
        """
        origin = request.headers.get("origin")
        requested_method = request.headers.get("access-control-request-method")
        requested_headers = request.headers.get("access-control-request-headers")

        headers = self.preflight_headers.copy()

        if not self.is_allowed_origin(origin):
            if self.debug:
                logger.error(
                    f"Preflight request denied: Origin '{origin}' is not allowed."
                )
            return response.json(
                self.get_error_message("disallowed_origin"),
                status_code=self.custom_error_status,
            )

        headers["Access-Control-Allow-Origin"] = origin  # type:ignore

        if not self.is_allowed_method(requested_method):
            if self.debug:
                logger.error(
                    f"Preflight request denied: Method '{requested_method}' is not allowed."
                )
            return response.json(
                self.get_error_message("disallowed_method"),
                status_code=self.custom_error_status,
            )

        if requested_headers:
            requested_header_list = [
                h.strip().lower() for h in requested_headers.split(",") if h.strip()
            ]
            # Build the lookup sets once, lower-cased so they compare cleanly
            # against the lower-cased requested names.
            allowed = {h.lower() for h in self.allow_headers}
            blocked = {h.lower() for h in self.blacklist_headers}
            allow_any = "*" in self.allow_headers
            for header in requested_header_list:
                if header in blocked or not (allow_any or header in allowed):
                    if self.debug:
                        logger.error(
                            f"Preflight request denied: Header '{header}' is not allowed."
                        )
                    return response.json(
                        self.get_error_message("disallowed_header"),
                        status_code=self.custom_error_status,
                    )
            # Echo the requested names instead of answering "*". Browsers do not
            # honour the wildcard for credentialed requests, and this middleware
            # sends Access-Control-Allow-Credentials by default.
            headers["Access-Control-Allow-Headers"] = requested_headers
        # NOTE: any 2xx status satisfies the browser's preflight check; 201 is
        # kept unchanged so the observable status code stays the same.
        return response.json("OK", status_code=201, headers=headers)

    def get_error_message(self, error_type: str) -> str:
        """Return the configured message for ``error_type``, or a generic default."""
        return self.custom_error_messages.get(error_type, "CORS request denied.")