diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/middleware/csrf.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/middleware/csrf.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/middleware/csrf.py | 190 |
1 files changed, 190 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py b/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py new file mode 100644 index 0000000..94dd422 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import hashlib +import hmac +import secrets +from secrets import compare_digest +from typing import TYPE_CHECKING, Any + +from litestar.datastructures import MutableScopeHeaders +from litestar.datastructures.cookie import Cookie +from litestar.enums import RequestEncodingType, ScopeType +from litestar.exceptions import PermissionDeniedException +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) +from litestar.middleware.base import MiddlewareProtocol +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.config.csrf import CSRFConfig + from litestar.connection import Request + from litestar.types import ( + ASGIApp, + HTTPSendMessage, + Message, + Receive, + Scope, + Scopes, + Send, + ) + +__all__ = ("CSRFMiddleware",) + + +CSRF_SECRET_BYTES = 32 +CSRF_SECRET_LENGTH = CSRF_SECRET_BYTES * 2 + + +def generate_csrf_hash(token: str, secret: str) -> str: + """Generate an HMAC that signs the CSRF token. + + Args: + token: A hashed token. + secret: A secret value. + + Returns: + A CSRF hash. + """ + return hmac.new(secret.encode(), token.encode(), hashlib.sha256).hexdigest() + + +def generate_csrf_token(secret: str) -> str: + """Generate a CSRF token that includes a randomly generated string signed by an HMAC. + + Args: + secret: A secret string. + + Returns: + A unique CSRF token. + """ + token = secrets.token_hex(CSRF_SECRET_BYTES) + token_hash = generate_csrf_hash(token=token, secret=secret) + return token + token_hash + + +class CSRFMiddleware(MiddlewareProtocol): + """CSRF Middleware class. + + This Middleware protects against attacks by setting a CSRF cookie with a token and verifying it in request headers. + """ + + scopes: Scopes = {ScopeType.HTTP} + + def __init__(self, app: ASGIApp, config: CSRFConfig) -> None: + """Initialize ``CSRFMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: The CSRFConfig instance. + """ + self.app = app + self.config = config + self.exclude = build_exclude_path_pattern(exclude=config.exclude) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if scope["type"] != ScopeType.HTTP: + await self.app(scope, receive, send) + return + + request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive) + content_type, _ = request.content_type + csrf_cookie = request.cookies.get(self.config.cookie_name) + existing_csrf_token = request.headers.get(self.config.header_name) + + if not existing_csrf_token and content_type in { + RequestEncodingType.URL_ENCODED, + RequestEncodingType.MULTI_PART, + }: + form = await request.form() + existing_csrf_token = form.get("_csrf_token", None) + + connection_state = ScopeState.from_scope(scope) + if request.method in self.config.safe_methods or should_bypass_middleware( + scope=scope, + scopes=self.scopes, + exclude_opt_key=self.config.exclude_from_csrf_key, + exclude_path_pattern=self.exclude, + ): + token = connection_state.csrf_token = csrf_cookie or generate_csrf_token(secret=self.config.secret) + await self.app(scope, receive, self.create_send_wrapper(send=send, csrf_cookie=csrf_cookie, token=token)) + elif ( + existing_csrf_token is not None + and csrf_cookie is not None + and self._csrf_tokens_match(existing_csrf_token, csrf_cookie) + ): + connection_state.csrf_token = existing_csrf_token + await self.app(scope, receive, send) + else: + raise PermissionDeniedException("CSRF token verification failed") + + def create_send_wrapper(self, send: Send, token: str, csrf_cookie: str | None) -> Send: + """Wrap ``send`` to handle CSRF validation. + + Args: + token: The CSRF token. + send: The ASGI send function. + csrf_cookie: CSRF cookie. + + Returns: + An ASGI send function. + """ + + async def send_wrapper(message: Message) -> None: + """Send function that wraps the original send to inject a cookie. + + Args: + message: An ASGI ``Message`` + + Returns: + None + """ + if csrf_cookie is None and message["type"] == "http.response.start": + message.setdefault("headers", []) + self._set_cookie_if_needed(message=message, token=token) + await send(message) + + return send_wrapper + + def _set_cookie_if_needed(self, message: HTTPSendMessage, token: str) -> None: + headers = MutableScopeHeaders.from_message(message) + cookie = Cookie( + key=self.config.cookie_name, + value=token, + path=self.config.cookie_path, + secure=self.config.cookie_secure, + httponly=self.config.cookie_httponly, + samesite=self.config.cookie_samesite, + domain=self.config.cookie_domain, + ) + headers.add("set-cookie", cookie.to_header(header="")) + + def _decode_csrf_token(self, token: str) -> str | None: + """Decode a CSRF token and validate its HMAC.""" + if len(token) < CSRF_SECRET_LENGTH + 1: + return None + + token_secret = token[:CSRF_SECRET_LENGTH] + existing_hash = token[CSRF_SECRET_LENGTH:] + expected_hash = generate_csrf_hash(token=token_secret, secret=self.config.secret) + return token_secret if compare_digest(existing_hash, expected_hash) else None + + def _csrf_tokens_match(self, request_csrf_token: str, cookie_csrf_token: str) -> bool: + """Take the CSRF tokens from the request and the cookie and verify both are valid and identical.""" + decoded_request_token = self._decode_csrf_token(request_csrf_token) + decoded_cookie_token = self._decode_csrf_token(cookie_csrf_token) + if decoded_request_token is None or decoded_cookie_token is None: + return False + + return compare_digest(decoded_request_token, decoded_cookie_token) |