Coverage for nexios\middlewares\csrf.py: 29%

82 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-21 20:31 +0100

1import secrets, re, typing 

2from nexios.config import get_config 

3from itsdangerous import URLSafeSerializer, BadSignature # type:ignore 

4from nexios.middlewares.base import BaseMiddleware 

5from nexios.http import Request, Response 

6 

7 

class CSRFMiddleware(BaseMiddleware):
    """
    Middleware to protect against Cross-Site Request Forgery (CSRF) attacks for Nexios.

    Every response gets a cookie holding a signed random token. For unsafe
    requests that need validation, the token submitted in a request header must
    match the token stored in that cookie. Both are verified and decoded with a
    serializer keyed on ``secret_key``. All settings are read once from
    ``get_config()`` when the middleware is constructed.
    """

    def __init__(self) -> None:
        app_config = get_config()
        self.use_csrf = app_config.csrf_enabled or False
        if not self.use_csrf:
            # Disabled: no further attributes are needed, because every hook
            # checks ``use_csrf`` before touching them.
            return
        # Raise explicitly instead of ``assert`` so the check survives ``python -O``.
        if app_config.secret_key is None:
            raise ValueError("CSRF protection requires a configured secret_key")
        self.serializer = URLSafeSerializer(
            app_config.secret_key, "csrftoken"
        )  # type:ignore
        self.required_urls: typing.List[str] = app_config.csrf_required_urls or []
        self.exempt_urls = app_config.csrf_exempt_urls
        self.sensitive_cookies = app_config.csrf_sensitive_cookies
        # Upper-case configured methods: they are compared against
        # ``request.method.upper()``.
        self.safe_methods = {
            method.upper()
            for method in (
                app_config.csrf_safe_methods or ["GET", "HEAD", "OPTIONS", "TRACE"]
            )
        }
        self.cookie_name = app_config.csrf_cookie_name or "csrftoken"
        self.cookie_path = app_config.csrf_cookie_path or "/"
        self.cookie_domain = app_config.csrf_cookie_domain
        self.cookie_secure = app_config.csrf_cookie_secure or False
        # ``value or True`` is always True. Default to True only when unset,
        # so an explicit False is honoured.
        # NOTE(review): with httponly=True, browser JavaScript cannot read this
        # cookie to echo it in the header; confirm how clients obtain the token.
        configured_httponly = app_config.csrf_cookie_httponly
        self.cookie_httponly = True if configured_httponly is None else configured_httponly
        self.cookie_samesite: typing.Literal["lax", "none", "strict"] = (
            app_config.csrf_cookie_samesite or "lax"
        )
        self.header_name = app_config.csrf_header_name or "X-CSRFToken"

    async def process_request(
        self,
        request: Request,
        response: Response,
        call_next: typing.Callable[..., typing.Awaitable[typing.Any]],
    ):
        """
        Validate the CSRF token for unsafe HTTP methods.

        Which requests are validated:

        * Safe methods always pass straight through.
        * An unsafe request is validated when its path matches
          ``required_urls``.
        * It is also validated when its path is *not* exempt and it carries a
          sensitive cookie. When ``sensitive_cookies`` is not configured, every
          request counts as carrying one.

        Returns:
            A 403 text response when validation fails. Otherwise ``call_next``
            is awaited and ``None`` is returned.
        """
        # ``or`` short-circuits, so ``safe_methods`` is never read when disabled.
        if not self.use_csrf or request.method.upper() in self.safe_methods:
            await call_next()
            return

        path = request.url.path
        must_validate = self._url_is_required(path) or (
            not self._url_is_exempt(path)
            and self._has_sensitive_cookies(request.cookies)
        )
        if must_validate:
            csrf_cookie = request.cookies.get(self.cookie_name)
            submitted_csrf_token = request.headers.get(self.header_name)
            if not csrf_cookie:
                return response.text("CSRF token missing from cookies", status_code=403)
            if not submitted_csrf_token:
                return response.text("CSRF token missing from headers", status_code=403)
            if not self._csrf_tokens_match(csrf_cookie, submitted_csrf_token):
                return response.text("CSRF token incorrect", status_code=403)
            # Discard the consumed token; process_response then sets a freshly
            # generated one on this response.
            response.delete_cookie(self.cookie_name, self.cookie_path, self.cookie_domain)
        await call_next()

    async def process_response(self, request: Request, response: Response):
        """
        Set a newly generated CSRF token cookie on the response.

        A new token is issued on every response, whether or not the request
        already carried one.
        """
        # NOTE(review): an earlier docstring said "if not already set", but
        # the token has always been rotated on every response; confirm which
        # behavior is intended.
        if not self.use_csrf:
            return
        response.set_cookie(
            key=self.cookie_name,
            value=self._generate_csrf_token(),
            path=self.cookie_path,
            domain=self.cookie_domain,
            secure=self.cookie_secure,
            httponly=self.cookie_httponly,
            samesite=self.cookie_samesite,
        )

    def _has_sensitive_cookies(self, cookies: typing.Dict[str, typing.Any]) -> bool:
        """Return True if any configured sensitive cookie is present.

        Also returns True when no sensitive cookies are configured.
        """
        if not self.sensitive_cookies:
            return True
        return any(name in cookies for name in self.sensitive_cookies)

    @staticmethod
    def _matches_any(patterns: typing.Optional[typing.Collection[str]], url: str) -> bool:
        """Return True if ``url`` fully matches one of the regex ``patterns``.

        The literal entry ``"*"`` matches every URL. It is handled before any
        regex compilation, because ``"*"`` is not a valid regular expression.
        """
        if not patterns:
            return False
        if "*" in patterns:
            return True
        # fullmatch (rather than match + group comparison) also accepts
        # matches that need backtracking, e.g. pattern "a|ab" against "ab".
        return any(re.fullmatch(pattern, url) for pattern in patterns)

    def _url_is_required(self, url: str) -> bool:
        """Check if the URL requires CSRF validation."""
        return self._matches_any(self.required_urls, url)

    def _url_is_exempt(self, url: str) -> bool:
        """Check if the URL is exempt from CSRF validation."""
        return self._matches_any(self.exempt_urls, url)

    def _generate_csrf_token(self) -> str:  # type:ignore
        """Return a signed token wrapping 32 bytes of URL-safe randomness."""
        return self.serializer.dumps(secrets.token_urlsafe(32))  # type:ignore

    def _csrf_tokens_match(self, token1: str, token2: str) -> bool:
        """Verify both signed tokens and compare their payloads in constant time.

        Returns False if either signature is invalid or a payload is not a string.
        """
        try:
            decoded1 = self.serializer.loads(token1)  # type:ignore
            decoded2 = self.serializer.loads(token2)  # type:ignore
        except BadSignature:
            return False
        # compare_digest raises TypeError on non-str/bytes or non-ASCII str
        # inputs. Reject non-strings, then compare UTF-8 bytes.
        if not (isinstance(decoded1, str) and isinstance(decoded2, str)):
            return False
        return secrets.compare_digest(decoded1.encode("utf-8"), decoded2.encode("utf-8"))