Coverage for nexios\middlewares\csrf.py: 29%
82 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-21 20:31 +0100
1import secrets, re, typing
2from nexios.config import get_config
3from itsdangerous import URLSafeSerializer, BadSignature # type:ignore
4from nexios.middlewares.base import BaseMiddleware
5from nexios.http import Request, Response
class CSRFMiddleware(BaseMiddleware):
    """
    Middleware to protect against Cross-Site Request Forgery (CSRF) attacks for Nexios.

    A signed token is stored in a cookie on every response.  For requests
    using an unsafe HTTP method, the client must echo that token back in a
    request header; the cookie value and the header value are both verified
    with the application's secret key and compared.

    All settings are read from the application config returned by
    ``get_config()``.  When ``csrf_enabled`` is falsy the middleware is a
    pass-through and none of the other attributes are initialised.
    """

    def __init__(self) -> None:
        """
        Load the CSRF settings from the application config.

        Raises:
            AssertionError: if CSRF is enabled but ``secret_key`` is ``None``.
        """
        app_config = get_config()
        self.use_csrf = app_config.csrf_enabled or False
        if not self.use_csrf:
            return

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.  The exception type is unchanged.
        if app_config.secret_key is None:
            raise AssertionError(
                "CSRFMiddleware requires `secret_key` to be configured "
                "when `csrf_enabled` is set"
            )

        # Second positional argument is passed through unchanged from the
        # original code (itsdangerous treats it as the signing salt).
        self.serializer = URLSafeSerializer(
            app_config.secret_key, "csrftoken"
        )  # type:ignore
        self.required_urls: typing.List[str] = app_config.csrf_required_urls or []
        self.exempt_urls = app_config.csrf_exempt_urls
        self.sensitive_cookies = app_config.csrf_sensitive_cookies

        # Normalise to upper case because process_request compares against
        # ``request.method.upper()``.
        configured_methods = app_config.csrf_safe_methods or [
            "GET",
            "HEAD",
            "OPTIONS",
            "TRACE",
        ]
        self.safe_methods = {method.upper() for method in configured_methods}

        self.cookie_name = app_config.csrf_cookie_name or "csrftoken"
        self.cookie_path = app_config.csrf_cookie_path or "/"
        self.cookie_domain = app_config.csrf_cookie_domain
        self.cookie_secure = app_config.csrf_cookie_secure or False

        # ``value or True`` is always truthy, which made it impossible to turn
        # HttpOnly off.  Default to True only when the option is unset.
        # NOTE(review): assumes an unset option reads as None — confirm
        # against the config object.
        configured_httponly = app_config.csrf_cookie_httponly
        self.cookie_httponly = (
            True if configured_httponly is None else configured_httponly
        )

        self.cookie_samesite: typing.Literal["lax", "none", "strict"] = (
            app_config.csrf_cookie_samesite or "lax"
        )
        self.header_name = app_config.csrf_header_name or "X-CSRFToken"

    async def process_request(
        self,
        request: Request,
        response: Response,
        call_next: typing.Callable[..., typing.Awaitable[typing.Any]],
    ):
        """
        Process the incoming request to validate the CSRF token for unsafe HTTP methods.

        Requests pass straight through when CSRF is disabled or the method is
        one of ``safe_methods``.  Otherwise the token is validated when the
        path matches ``required_urls``, or when the path is *not* exempt and
        the request carries a sensitive cookie.

        Returns:
            The result of ``response.text(..., status_code=403)`` when
            validation fails; ``None`` after ``call_next()`` otherwise.
        """
        if not self.use_csrf:
            await call_next()
            return

        if request.method.upper() in self.safe_methods:
            await call_next()
            return

        path = request.url.path
        # Exempt URLs must skip validation; the exemption test is therefore
        # negated here.
        must_validate = self._url_is_required(path) or (
            not self._url_is_exempt(path)
            and self._has_sensitive_cookies(request.cookies)
        )

        if must_validate:
            csrf_cookie = request.cookies.get(self.cookie_name)
            if not csrf_cookie:
                return response.text("CSRF token missing from cookies", status_code=403)

            submitted_csrf_token = request.headers.get(self.header_name)
            if not submitted_csrf_token:
                return response.text("CSRF token missing from headers", status_code=403)

            if not self._csrf_tokens_match(csrf_cookie, submitted_csrf_token):
                return response.text("CSRF token incorrect", status_code=403)

        # NOTE(review): the listing this was reconstructed from had lost its
        # indentation; the cookie reset is assumed to run for every unsafe
        # request that reaches the handler — confirm against the repository.
        response.delete_cookie(self.cookie_name, self.cookie_path, self.cookie_domain)
        await call_next()

    async def process_response(self, request: Request, response: Response):
        """
        Attach a freshly generated CSRF token cookie to the response.

        A new token is issued on every response while CSRF is enabled; the
        method does not check whether the request already carried one.
        """
        if not self.use_csrf:
            return

        csrf_token = self._generate_csrf_token()
        response.set_cookie(
            key=self.cookie_name,
            value=csrf_token,
            path=self.cookie_path,
            domain=self.cookie_domain,
            secure=self.cookie_secure,
            httponly=self.cookie_httponly,
            samesite=self.cookie_samesite,
        )

    def _has_sensitive_cookies(self, cookies: typing.Dict[str, typing.Any]) -> bool:
        """
        Check if the request contains sensitive cookies.

        When no sensitive cookie names are configured every request is
        treated as sensitive and ``True`` is returned.
        """
        if not self.sensitive_cookies:
            return True
        return any(name in cookies for name in self.sensitive_cookies)

    @staticmethod
    def _matches_any(patterns: typing.Iterable[str], url: str) -> bool:
        """Return True if any regular expression in *patterns* matches all of *url*."""
        return any(re.fullmatch(pattern, url) for pattern in patterns)

    def _url_is_required(self, url: str) -> bool:
        """
        Check if the URL requires CSRF validation.

        A literal ``"*"`` entry in ``required_urls`` matches every URL; all
        other entries are regular expressions that must match the whole URL.
        """
        if not self.required_urls:
            return False
        # Checked before the regex loop: a bare "*" is not a valid pattern.
        if "*" in self.required_urls:
            return True
        return self._matches_any(self.required_urls, url)

    def _url_is_exempt(self, url: str) -> bool:
        """
        Check if the URL is exempt from CSRF validation.

        Entries in ``exempt_urls`` are regular expressions that must match
        the whole URL.
        """
        if not self.exempt_urls:
            return False
        return self._matches_any(self.exempt_urls, url)

    def _generate_csrf_token(self) -> str:  # type:ignore
        """Generate a secure CSRF token: a random value signed by the serializer."""
        return self.serializer.dumps(secrets.token_urlsafe(32))  # type:ignore

    def _csrf_tokens_match(self, token1: str, token2: str) -> bool:
        """
        Compare two CSRF tokens securely.

        Both tokens are verified and decoded with the serializer, then the
        decoded values are compared with ``secrets.compare_digest``.

        Returns:
            ``False`` when either token fails signature verification,
            otherwise the result of the comparison.
        """
        try:
            decoded1 = self.serializer.loads(token1)  # type:ignore
            decoded2 = self.serializer.loads(token2)  # type:ignore
        except BadSignature:
            return False
        return secrets.compare_digest(decoded1, decoded2)  # type:ignore