diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/middleware')
48 files changed, 3167 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/__init__.py new file mode 100644 index 0000000..7024e54 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__init__.py @@ -0,0 +1,17 @@ +from litestar.middleware.authentication import ( + AbstractAuthenticationMiddleware, + AuthenticationResult, +) +from litestar.middleware.base import ( + AbstractMiddleware, + DefineMiddleware, + MiddlewareProtocol, +) + +__all__ = ( + "AbstractAuthenticationMiddleware", + "AbstractMiddleware", + "AuthenticationResult", + "DefineMiddleware", + "MiddlewareProtocol", +) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c807c7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..95f4515 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/allowed_hosts.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/allowed_hosts.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..48bfb2a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/allowed_hosts.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/authentication.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/authentication.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..193563c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/authentication.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6a0ef6f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/cors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/cors.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f4277f4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/cors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/csrf.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/csrf.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4679eea --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/csrf.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/logging.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/logging.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..20db6ff --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/logging.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/rate_limit.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/rate_limit.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..83090e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/rate_limit.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/response_cache.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/response_cache.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1672eb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/response_cache.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/_utils.py b/venv/lib/python3.11/site-packages/litestar/middleware/_utils.py new file mode 100644 index 0000000..778a508 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/_utils.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Pattern, Sequence + +from litestar.exceptions import ImproperlyConfiguredException + +__all__ = ("build_exclude_path_pattern", "should_bypass_middleware") + + +if TYPE_CHECKING: + from litestar.types import Method, Scope, Scopes + + +def build_exclude_path_pattern(*, exclude: str | list[str] | None = None) -> Pattern | None: + """Build single path pattern from list of patterns to opt-out from middleware processing. + + Args: + exclude: A pattern or a list of patterns. + + Returns: + An optional pattern to match against scope["path"] to opt-out from middleware processing. + """ + if exclude is None: + return None + + try: + return re.compile("|".join(exclude)) if isinstance(exclude, list) else re.compile(exclude) + except re.error as e: # pragma: no cover + raise ImproperlyConfiguredException( + "Unable to compile exclude patterns for middleware. Please make sure you passed a valid regular expression." + ) from e + + +def should_bypass_middleware( + *, + exclude_http_methods: Sequence[Method] | None = None, + exclude_opt_key: str | None = None, + exclude_path_pattern: Pattern | None = None, + scope: Scope, + scopes: Scopes, +) -> bool: + """Determine weather a middleware should be bypassed. + + Args: + exclude_http_methods: A sequence of http methods that do not require authentication. + exclude_opt_key: Key in ``opt`` with which a route handler can "opt-out" of a middleware. + exclude_path_pattern: If this pattern matches scope["path"], the middleware should be bypassed. + scope: The ASGI scope. + scopes: A set with the ASGI scope types that are supported by the middleware. + + Returns: + A boolean indicating if a middleware should be bypassed + """ + if scope["type"] not in scopes: + return True + + if exclude_opt_key and scope["route_handler"].opt.get(exclude_opt_key): + return True + + if exclude_http_methods and scope.get("method") in exclude_http_methods: + return True + + return bool( + exclude_path_pattern + and exclude_path_pattern.findall( + scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"] + ) + ) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py b/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py new file mode 100644 index 0000000..0172176 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Pattern + +from litestar.datastructures import URL, MutableScopeHeaders +from litestar.middleware.base import AbstractMiddleware +from litestar.response.base import ASGIResponse +from litestar.response.redirect import ASGIRedirectResponse +from litestar.status_codes import HTTP_400_BAD_REQUEST + +__all__ = ("AllowedHostsMiddleware",) + + +if TYPE_CHECKING: + from litestar.config.allowed_hosts import AllowedHostsConfig + from litestar.types import ASGIApp, Receive, Scope, Send + + +class AllowedHostsMiddleware(AbstractMiddleware): + """Middleware ensuring the host of a request originated in a trusted host.""" + + def __init__(self, app: ASGIApp, config: AllowedHostsConfig) -> None: + """Initialize ``AllowedHostsMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of AllowedHostsConfig. + """ + + super().__init__(app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes=config.scopes) + + self.allowed_hosts_regex: Pattern | None = None + self.redirect_domains: Pattern | None = None + + if any(host == "*" for host in config.allowed_hosts): + return + + allowed_hosts: set[str] = { + rf".*\.{host.replace('*.', '')}$" if host.startswith("*.") else host for host in config.allowed_hosts + } + + self.allowed_hosts_regex = re.compile("|".join(sorted(allowed_hosts))) # pyright: ignore + + if config.www_redirect and ( + redirect_domains := {host.replace("www.", "") for host in config.allowed_hosts if host.startswith("www.")} + ): + self.redirect_domains = re.compile("|".join(sorted(redirect_domains))) # pyright: ignore + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if self.allowed_hosts_regex is None: + await self.app(scope, receive, send) + return + + headers = MutableScopeHeaders(scope=scope) + if host := headers.get("host", headers.get("x-forwarded-host", "")).split(":")[0]: + if self.allowed_hosts_regex.fullmatch(host): + await self.app(scope, receive, send) + return + + if self.redirect_domains is not None and self.redirect_domains.fullmatch(host): + url = URL.from_scope(scope) + redirect_url = url.with_replacements(netloc=f"www.{url.netloc}") + redirect_response = ASGIRedirectResponse(path=str(redirect_url)) + await redirect_response(scope, receive, send) + return + + response = ASGIResponse(body=b'{"message":"invalid host header"}', status_code=HTTP_400_BAD_REQUEST) + await response(scope, receive, send) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/authentication.py b/venv/lib/python3.11/site-packages/litestar/middleware/authentication.py new file mode 100644 index 0000000..9502df0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/authentication.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Sequence + +from litestar.connection import ASGIConnection +from litestar.enums import HttpMethod, ScopeType +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) + +__all__ = ("AbstractAuthenticationMiddleware", "AuthenticationResult") + + +if TYPE_CHECKING: + from litestar.types import ASGIApp, Method, Receive, Scope, Scopes, Send + + +@dataclass +class AuthenticationResult: + """Pydantic model for authentication data.""" + + __slots__ = ("user", "auth") + + user: Any + """The user model, this can be any value corresponding to a user of the API.""" + auth: Any + """The auth value, this can for example be a JWT token.""" + + +class AbstractAuthenticationMiddleware(ABC): + """Abstract AuthenticationMiddleware that allows users to create their own AuthenticationMiddleware by extending it + and overriding :meth:`AbstractAuthenticationMiddleware.authenticate_request`. + """ + + __slots__ = ( + "app", + "exclude", + "exclude_http_methods", + "exclude_opt_key", + "scopes", + ) + + def __init__( + self, + app: ASGIApp, + exclude: str | list[str] | None = None, + exclude_from_auth_key: str = "exclude_from_auth", + exclude_http_methods: Sequence[Method] | None = None, + scopes: Scopes | None = None, + ) -> None: + """Initialize ``AbstractAuthenticationMiddleware``. + + Args: + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + exclude: A pattern or list of patterns to skip in the authentication middleware. + exclude_from_auth_key: An identifier to use on routes to disable authentication for a particular route. + exclude_http_methods: A sequence of http methods that do not require authentication. + scopes: ASGI scopes processed by the authentication middleware. + """ + self.app = app + self.exclude = build_exclude_path_pattern(exclude=exclude) + self.exclude_http_methods = (HttpMethod.OPTIONS,) if exclude_http_methods is None else exclude_http_methods + self.exclude_opt_key = exclude_from_auth_key + self.scopes = scopes or {ScopeType.HTTP, ScopeType.WEBSOCKET} + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if not should_bypass_middleware( + exclude_http_methods=self.exclude_http_methods, + exclude_opt_key=self.exclude_opt_key, + exclude_path_pattern=self.exclude, + scope=scope, + scopes=self.scopes, + ): + auth_result = await self.authenticate_request(ASGIConnection(scope)) + scope["user"] = auth_result.user + scope["auth"] = auth_result.auth + await self.app(scope, receive, send) + + @abstractmethod + async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult: + """Receive the http connection and return an :class:`AuthenticationResult`. + + Notes: + - This method must be overridden by subclasses. + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Raises: + NotAuthorizedException | PermissionDeniedException: if authentication fails. + + Returns: + An instance of :class:`AuthenticationResult <litestar.middleware.authentication.AuthenticationResult>`. + """ + raise NotImplementedError("authenticate_request must be overridden by subclasses") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/base.py b/venv/lib/python3.11/site-packages/litestar/middleware/base.py new file mode 100644 index 0000000..43106c9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/base.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable + +from litestar.enums import ScopeType +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) + +__all__ = ("AbstractMiddleware", "DefineMiddleware", "MiddlewareProtocol") + + +if TYPE_CHECKING: + from litestar.types import Scopes + from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send + + +@runtime_checkable +class MiddlewareProtocol(Protocol): + """Abstract middleware protocol.""" + + __slots__ = ("app",) + + app: ASGIApp + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Execute the ASGI middleware. + + Called by the previous middleware in the stack if a response is not awaited prior. + + Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its + closure. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + +class DefineMiddleware: + """Container enabling passing ``*args`` and ``**kwargs`` to Middleware class constructors and factory functions.""" + + __slots__ = ("middleware", "args", "kwargs") + + def __init__(self, middleware: Callable[..., ASGIApp], *args: Any, **kwargs: Any) -> None: + """Initialize ``DefineMiddleware``. + + Args: + middleware: A callable that returns an ASGIApp. + *args: Positional arguments to pass to the callable. + **kwargs: Key word arguments to pass to the callable. + + Notes: + The callable will be passed a kwarg ``app``, which is the next ASGI app to call in the middleware stack. + It therefore must define such a kwarg. + """ + self.middleware = middleware + self.args = args + self.kwargs = kwargs + + def __call__(self, app: ASGIApp) -> ASGIApp: + """Call the middleware constructor or factory. + + Args: + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + + Returns: + Calls :class:`DefineMiddleware.middleware <.DefineMiddleware>` and returns the ASGIApp created. + """ + + return self.middleware(*self.args, app=app, **self.kwargs) + + +class AbstractMiddleware: + """Abstract middleware providing base functionality common to all middlewares, for dynamically engaging/bypassing + the middleware based on paths, ``opt``-keys and scope types. + + When implementing new middleware, this class should be used as a base. + """ + + scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET} + exclude: str | list[str] | None = None + exclude_opt_key: str | None = None + + def __init__( + self, + app: ASGIApp, + exclude: str | list[str] | None = None, + exclude_opt_key: str | None = None, + scopes: Scopes | None = None, + ) -> None: + """Initialize the middleware. + + Args: + app: The ``next`` ASGI app to call. + exclude: A pattern or list of patterns to match against a request's path. + If a match is found, the middleware will be skipped. + exclude_opt_key: An identifier that is set in the route handler + ``opt`` key which allows skipping the middleware. + scopes: ASGI scope types, should be a set including + either or both 'ScopeType.HTTP' and 'ScopeType.WEBSOCKET'. + """ + self.app = app + self.scopes = scopes or self.scopes + self.exclude_opt_key = exclude_opt_key or self.exclude_opt_key + self.exclude_pattern = build_exclude_path_pattern(exclude=(exclude or self.exclude)) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + original__call__ = cls.__call__ + + async def wrapped_call(self: AbstractMiddleware, scope: Scope, receive: Receive, send: Send) -> None: + if should_bypass_middleware( + scope=scope, + scopes=self.scopes, + exclude_path_pattern=self.exclude_pattern, + exclude_opt_key=self.exclude_opt_key, + ): + await self.app(scope, receive, send) + else: + await original__call__(self, scope, receive, send) # pyright: ignore + + # https://github.com/python/mypy/issues/2427#issuecomment-384229898 + setattr(cls, "__call__", wrapped_call) + + @abstractmethod + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Execute the ASGI middleware. + + Called by the previous middleware in the stack if a response is not awaited prior. + + Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its + closure. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + raise NotImplementedError("abstract method must be implemented") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__init__.py new file mode 100644 index 0000000..0885932 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__init__.py @@ -0,0 +1,4 @@ +from litestar.middleware.compression.facade import CompressionFacade +from litestar.middleware.compression.middleware import CompressionMiddleware + +__all__ = ("CompressionMiddleware", "CompressionFacade") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..80ea058 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/brotli_facade.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/brotli_facade.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7378c0f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/brotli_facade.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/facade.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/facade.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d336c8f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/facade.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/gzip_facade.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/gzip_facade.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..66e1df4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/gzip_facade.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a683673 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/brotli_facade.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/brotli_facade.py new file mode 100644 index 0000000..3d01950 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/brotli_facade.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from litestar.enums import CompressionEncoding +from litestar.exceptions import MissingDependencyException +from litestar.middleware.compression.facade import CompressionFacade + +try: + from brotli import MODE_FONT, MODE_GENERIC, MODE_TEXT, Compressor +except ImportError as e: + raise MissingDependencyException("brotli") from e + + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + + +class BrotliCompression(CompressionFacade): + __slots__ = ("compressor", "buffer", "compression_encoding") + + encoding = CompressionEncoding.BROTLI + + def __init__( + self, + buffer: BytesIO, + compression_encoding: Literal[CompressionEncoding.BROTLI] | str, + config: CompressionConfig, + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + modes: dict[Literal["generic", "text", "font"], int] = { + "text": int(MODE_TEXT), + "font": int(MODE_FONT), + "generic": int(MODE_GENERIC), + } + self.compressor = Compressor( + quality=config.brotli_quality, + mode=modes[config.brotli_mode], + lgwin=config.brotli_lgwin, + lgblock=config.brotli_lgblock, + ) + + def write(self, body: bytes) -> None: + self.buffer.write(self.compressor.process(body)) + self.buffer.write(self.compressor.flush()) + + def close(self) -> None: + self.buffer.write(self.compressor.finish()) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/facade.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/facade.py new file mode 100644 index 0000000..0074b57 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/facade.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Protocol + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + from litestar.enums import CompressionEncoding + + +class CompressionFacade(Protocol): + """A unified facade offering a uniform interface for different compression libraries.""" + + encoding: ClassVar[str] + """The encoding of the compression.""" + + def __init__( + self, buffer: BytesIO, compression_encoding: CompressionEncoding | str, config: CompressionConfig + ) -> None: + """Initialize ``CompressionFacade``. + + Args: + buffer: A bytes IO buffer to write the compressed data into. + compression_encoding: The compression encoding used. + config: The app compression config. + """ + ... + + def write(self, body: bytes) -> None: + """Write compressed bytes. + + Args: + body: Message body to process + + Returns: + None + """ + ... + + def close(self) -> None: + """Close the compression stream. + + Returns: + None + """ + ... diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/gzip_facade.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/gzip_facade.py new file mode 100644 index 0000000..b10ef73 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/gzip_facade.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from gzip import GzipFile +from typing import TYPE_CHECKING, Literal + +from litestar.enums import CompressionEncoding +from litestar.middleware.compression.facade import CompressionFacade + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + + +class GzipCompression(CompressionFacade): + __slots__ = ("compressor", "buffer", "compression_encoding") + + encoding = CompressionEncoding.GZIP + + def __init__( + self, buffer: BytesIO, compression_encoding: Literal[CompressionEncoding.GZIP] | str, config: CompressionConfig + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + self.compressor = GzipFile(mode="wb", fileobj=buffer, compresslevel=config.gzip_compress_level) + + def write(self, body: bytes) -> None: + self.compressor.write(body) + self.compressor.flush() + + def close(self) -> None: + self.compressor.close() diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/middleware.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/middleware.py new file mode 100644 index 0000000..7ea7853 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/middleware.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from io import BytesIO +from typing import TYPE_CHECKING, Any, Literal + +from litestar.datastructures import Headers, MutableScopeHeaders +from litestar.enums import CompressionEncoding, ScopeType +from litestar.middleware.base import AbstractMiddleware +from litestar.middleware.compression.gzip_facade import GzipCompression +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.config.compression import CompressionConfig + from litestar.middleware.compression.facade import CompressionFacade + from litestar.types import ( + ASGIApp, + HTTPResponseStartEvent, + Message, + Receive, + Scope, + Send, + ) + + try: + from brotli import Compressor + except ImportError: + Compressor = Any + + +class CompressionMiddleware(AbstractMiddleware): + """Compression Middleware Wrapper. + + This is a wrapper allowing for generic compression configuration / handler middleware + """ + + def __init__(self, app: ASGIApp, config: CompressionConfig) -> None: + """Initialize ``CompressionMiddleware`` + + Args: + app: The ``next`` ASGI app to call. + config: An instance of CompressionConfig. + """ + super().__init__( + app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP} + ) + self.config = config + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + accept_encoding = Headers.from_scope(scope).get("accept-encoding", "") + config = self.config + + if config.compression_facade.encoding in accept_encoding: + await self.app( + scope, + receive, + self.create_compression_send_wrapper( + send=send, compression_encoding=config.compression_facade.encoding, scope=scope + ), + ) + return + + if config.gzip_fallback and CompressionEncoding.GZIP in accept_encoding: + await self.app( + scope, + receive, + self.create_compression_send_wrapper( + send=send, compression_encoding=CompressionEncoding.GZIP, scope=scope + ), + ) + return + + await self.app(scope, receive, send) + + def create_compression_send_wrapper( + self, + send: Send, + compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP] | str, + scope: Scope, + ) -> Send: + """Wrap ``send`` to handle brotli compression. + + Args: + send: The ASGI send function. + compression_encoding: The compression encoding used. + scope: The ASGI connection scope + + Returns: + An ASGI send function. + """ + bytes_buffer = BytesIO() + + facade: CompressionFacade + # We can't use `self.config.compression_facade` directly if the compression is `gzip` since + # it may be being used as a fallback. + if compression_encoding == CompressionEncoding.GZIP: + facade = GzipCompression(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config) + else: + facade = self.config.compression_facade( + buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config + ) + + initial_message: HTTPResponseStartEvent | None = None + started = False + + connection_state = ScopeState.from_scope(scope) + + async def send_wrapper(message: Message) -> None: + """Handle and compresses the HTTP Message with brotli. + + Args: + message (Message): An ASGI Message. + """ + nonlocal started + nonlocal initial_message + + if message["type"] == "http.response.start": + initial_message = message + return + + if initial_message is not None and value_or_default(connection_state.is_cached, False): + await send(initial_message) + await send(message) + return + + if initial_message and message["type"] == "http.response.body": + body = message["body"] + more_body = message.get("more_body") + + if not started: + started = True + if more_body: + headers = MutableScopeHeaders(initial_message) + headers["Content-Encoding"] = compression_encoding + headers.extend_header_value("vary", "Accept-Encoding") + del headers["Content-Length"] + connection_state.response_compressed = True + + facade.write(body) + + message["body"] = bytes_buffer.getvalue() + bytes_buffer.seek(0) + bytes_buffer.truncate() + await send(initial_message) + await send(message) + + elif len(body) >= self.config.minimum_size: + facade.write(body) + facade.close() + body = bytes_buffer.getvalue() + + headers = MutableScopeHeaders(initial_message) + headers["Content-Encoding"] = compression_encoding + headers["Content-Length"] = str(len(body)) + headers.extend_header_value("vary", "Accept-Encoding") + message["body"] = body + connection_state.response_compressed = True + + await send(initial_message) + await send(message) + + else: + await send(initial_message) + await send(message) + + else: + facade.write(body) + if not more_body: + facade.close() + + message["body"] = bytes_buffer.getvalue() + + bytes_buffer.seek(0) + bytes_buffer.truncate() + + if not more_body: + bytes_buffer.close() + + await send(message) + + return send_wrapper diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/cors.py b/venv/lib/python3.11/site-packages/litestar/middleware/cors.py new file mode 100644 index 0000000..6c4de31 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/cors.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.datastructures import Headers, MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.middleware.base import AbstractMiddleware + +__all__ = ("CORSMiddleware",) + + +if TYPE_CHECKING: + from litestar.config.cors import CORSConfig + from litestar.types import ASGIApp, Message, Receive, Scope, Send + + +class CORSMiddleware(AbstractMiddleware): + """CORS Middleware.""" + + __slots__ = ("config",) + + def __init__(self, app: ASGIApp, config: CORSConfig) -> None: + """Middleware that adds CORS validation to the application. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of :class:`CORSConfig <litestar.config.cors.CORSConfig>` + """ + super().__init__(app=app, scopes={ScopeType.HTTP}) + self.config = config + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + headers = Headers.from_scope(scope=scope) + if origin := headers.get("origin"): + await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers)) + else: + await self.app(scope, receive, send) + + def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send: + """Wrap ``send`` to ensure that state is not disconnected. + + Args: + has_cookie: Boolean flag dictating if the connection has a cookie set. + origin: The value of the ``Origin`` header. + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + + async def wrapped_send(message: Message) -> None: + if message["type"] == "http.response.start": + message.setdefault("headers", []) + headers = MutableScopeHeaders.from_message(message=message) + headers.update(self.config.simple_headers) + + if (self.config.is_allow_all_origins and has_cookie) or ( + not self.config.is_allow_all_origins and self.config.is_origin_allowed(origin=origin) + ): + headers["Access-Control-Allow-Origin"] = origin + headers["Vary"] = "Origin" + + # We don't want to overwrite this for preflight requests. + allow_headers = headers.get("Access-Control-Allow-Headers") + if not allow_headers and self.config.allow_headers: + headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers))) + + allow_methods = headers.get("Access-Control-Allow-Methods") + if not allow_methods and self.config.allow_methods: + headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods))) + + await send(message) + + return wrapped_send diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py b/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py new file mode 100644 index 0000000..94dd422 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import hashlib +import hmac +import secrets +from secrets import compare_digest +from typing import TYPE_CHECKING, Any + +from litestar.datastructures import MutableScopeHeaders +from litestar.datastructures.cookie import Cookie +from litestar.enums import RequestEncodingType, ScopeType +from litestar.exceptions import PermissionDeniedException +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) +from litestar.middleware.base import MiddlewareProtocol +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.config.csrf import CSRFConfig + from litestar.connection import Request + from litestar.types import ( + ASGIApp, + HTTPSendMessage, + Message, + Receive, + Scope, + Scopes, + Send, + ) + +__all__ = ("CSRFMiddleware",) + + +CSRF_SECRET_BYTES = 32 +CSRF_SECRET_LENGTH = CSRF_SECRET_BYTES * 2 + + +def generate_csrf_hash(token: str, secret: str) -> str: + """Generate an HMAC that signs the CSRF token. + + Args: + token: A hashed token. + secret: A secret value. + + Returns: + A CSRF hash. + """ + return hmac.new(secret.encode(), token.encode(), hashlib.sha256).hexdigest() + + +def generate_csrf_token(secret: str) -> str: + """Generate a CSRF token that includes a randomly generated string signed by an HMAC. + + Args: + secret: A secret string. + + Returns: + A unique CSRF token. + """ + token = secrets.token_hex(CSRF_SECRET_BYTES) + token_hash = generate_csrf_hash(token=token, secret=secret) + return token + token_hash + + +class CSRFMiddleware(MiddlewareProtocol): + """CSRF Middleware class. + + This Middleware protects against attacks by setting a CSRF cookie with a token and verifying it in request headers. + """ + + scopes: Scopes = {ScopeType.HTTP} + + def __init__(self, app: ASGIApp, config: CSRFConfig) -> None: + """Initialize ``CSRFMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: The CSRFConfig instance. + """ + self.app = app + self.config = config + self.exclude = build_exclude_path_pattern(exclude=config.exclude) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if scope["type"] != ScopeType.HTTP: + await self.app(scope, receive, send) + return + + request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive) + content_type, _ = request.content_type + csrf_cookie = request.cookies.get(self.config.cookie_name) + existing_csrf_token = request.headers.get(self.config.header_name) + + if not existing_csrf_token and content_type in { + RequestEncodingType.URL_ENCODED, + RequestEncodingType.MULTI_PART, + }: + form = await request.form() + existing_csrf_token = form.get("_csrf_token", None) + + connection_state = ScopeState.from_scope(scope) + if request.method in self.config.safe_methods or should_bypass_middleware( + scope=scope, + scopes=self.scopes, + exclude_opt_key=self.config.exclude_from_csrf_key, + exclude_path_pattern=self.exclude, + ): + token = connection_state.csrf_token = csrf_cookie or generate_csrf_token(secret=self.config.secret) + await self.app(scope, receive, self.create_send_wrapper(send=send, csrf_cookie=csrf_cookie, token=token)) + elif ( + existing_csrf_token is not None + and csrf_cookie is not None + and self._csrf_tokens_match(existing_csrf_token, csrf_cookie) + ): + connection_state.csrf_token = existing_csrf_token + await self.app(scope, receive, send) + else: + raise PermissionDeniedException("CSRF token verification failed") + + def create_send_wrapper(self, send: Send, token: str, csrf_cookie: str | None) -> Send: + """Wrap ``send`` to handle CSRF validation. + + Args: + token: The CSRF token. + send: The ASGI send function. + csrf_cookie: CSRF cookie. + + Returns: + An ASGI send function. + """ + + async def send_wrapper(message: Message) -> None: + """Send function that wraps the original send to inject a cookie. + + Args: + message: An ASGI ``Message`` + + Returns: + None + """ + if csrf_cookie is None and message["type"] == "http.response.start": + message.setdefault("headers", []) + self._set_cookie_if_needed(message=message, token=token) + await send(message) + + return send_wrapper + + def _set_cookie_if_needed(self, message: HTTPSendMessage, token: str) -> None: + headers = MutableScopeHeaders.from_message(message) + cookie = Cookie( + key=self.config.cookie_name, + value=token, + path=self.config.cookie_path, + secure=self.config.cookie_secure, + httponly=self.config.cookie_httponly, + samesite=self.config.cookie_samesite, + domain=self.config.cookie_domain, + ) + headers.add("set-cookie", cookie.to_header(header="")) + + def _decode_csrf_token(self, token: str) -> str | None: + """Decode a CSRF token and validate its HMAC.""" + if len(token) < CSRF_SECRET_LENGTH + 1: + return None + + token_secret = token[:CSRF_SECRET_LENGTH] + existing_hash = token[CSRF_SECRET_LENGTH:] + expected_hash = generate_csrf_hash(token=token_secret, secret=self.config.secret) + return token_secret if compare_digest(existing_hash, expected_hash) else None + + def _csrf_tokens_match(self, request_csrf_token: str, cookie_csrf_token: str) -> bool: + """Take the CSRF tokens from the request and the cookie and verify both are valid and identical.""" + decoded_request_token = self._decode_csrf_token(request_csrf_token) + decoded_cookie_token = self._decode_csrf_token(cookie_csrf_token) + if decoded_request_token is None or decoded_cookie_token is None: + return False + + return compare_digest(decoded_request_token, decoded_cookie_token) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__init__.py new file mode 100644 index 0000000..5328adf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__init__.py @@ -0,0 +1,3 @@ +from litestar.middleware.exceptions.middleware import ExceptionHandlerMiddleware + +__all__ = ("ExceptionHandlerMiddleware",) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c443e00 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/_debug_response.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/_debug_response.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b41fc85 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/_debug_response.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2259206 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/_debug_response.py b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/_debug_response.py new file mode 100644 index 0000000..99e8c87 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/_debug_response.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from html import escape +from inspect import getinnerframes +from pathlib import Path +from traceback import format_exception +from typing import TYPE_CHECKING, Any + +from litestar.enums import MediaType +from litestar.response import Response +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.utils import get_name + +__all__ = ( + "create_debug_response", + "create_exception_html", + "create_frame_html", + "create_html_response_content", + "create_line_html", + "create_plain_text_response_content", + "get_symbol_name", +) + + +if TYPE_CHECKING: + from inspect import FrameInfo + + from litestar.connection import Request + from litestar.types import TypeEncodersMap + +tpl_dir = Path(__file__).parent / "templates" + + +def get_symbol_name(frame: FrameInfo) -> str: + """Return full name of the function that is being executed by the given frame. + + Args: + frame: An instance of [FrameInfo](https://docs.python.org/3/library/inspect.html#inspect.FrameInfo). + + Notes: + - class detection assumes standard names (self and cls) of params. + - if current class name can not be determined only function (method) name will be returned. + - we can not distinguish static methods from ordinary functions at the moment. + + Returns: + A string containing full function name. + """ + + locals_dict = frame.frame.f_locals + # this piece assumes that the code uses standard names "self" and "cls" + # in instance and class methods + instance_or_cls = inst if (inst := locals_dict.get("self")) is not None else locals_dict.get("cls") + + classname = f"{get_name(instance_or_cls)}." if instance_or_cls is not None else "" + + return f"{classname}{frame.function}" + + +def create_line_html( + line: str, + line_no: int, + frame_index: int, + idx: int, +) -> str: + """Produce HTML representation of a line including real line number in the source code. + + Args: + line: A string representing the current line. + line_no: The line number associated with the executed line. + frame_index: Index of the executed line in the code context. + idx: Index of the current line in the code context. + + Returns: + A string containing HTML representation of the given line. + """ + template = '<tr class="{line_class}"><td class="line_no">{line_no}</td><td class="code_line">{line}</td></tr>' + data = { + # line_no - frame_index produces actual line number of the very first line in the frame code context. + # so adding index (aka relative number) of a line in the code context we can calculate its actual number in the source file, + "line_no": line_no - frame_index + idx, + "line": escape(line).replace(" ", " "), + "line_class": "executed-line" if idx == frame_index else "", + } + return template.format(**data) + + +def create_frame_html(frame: FrameInfo, collapsed: bool) -> str: + """Produce HTML representation of the given frame object including filename containing source code and name of the + function being executed. + + Args: + frame: An instance of [FrameInfo](https://docs.python.org/3/library/inspect.html#inspect.FrameInfo). + collapsed: Flag controlling whether frame should be collapsed on the page load. + + Returns: + A string containing HTML representation of the execution frame. + """ + frame_tpl = (tpl_dir / "frame.html").read_text() + + code_lines: list[str] = [ + create_line_html(line, frame.lineno, frame.index or 0, idx) for idx, line in enumerate(frame.code_context or []) + ] + data = { + "file": escape(frame.filename), + "line": frame.lineno, + "symbol_name": escape(get_symbol_name(frame)), + "code": "".join(code_lines), + "frame_class": "collapsed" if collapsed else "", + } + return frame_tpl.format(**data) + + +def create_exception_html(exc: BaseException, line_limit: int) -> str: + """Produce HTML representation of the exception frames. + + Args: + exc: An Exception instance to generate. + line_limit: Number of lines of code context to return, which are centered around the executed line. + + Returns: + A string containing HTML representation of the execution frames related to the exception. + """ + frames = getinnerframes(exc.__traceback__, line_limit) if exc.__traceback__ else [] + result = [create_frame_html(frame=frame, collapsed=idx > 0) for idx, frame in enumerate(reversed(frames))] + return "".join(result) + + +def create_html_response_content(exc: Exception, request: Request, line_limit: int = 15) -> str: + """Given an exception, produces its traceback in HTML. + + Args: + exc: An Exception instance to render debug response from. + request: A :class:`Request <litestar.connection.Request>` instance. + line_limit: Number of lines of code context to return, which are centered around the executed line. + + Returns: + A string containing HTML page with exception traceback. + """ + exception_data: list[str] = [create_exception_html(exc, line_limit)] + cause = exc.__cause__ + while cause: + cause_data = create_exception_html(cause, line_limit) + cause_header = '<h4 class="cause-header">The above exception was caused by</h4>' + cause_error_description = f"<h3><span>{escape(str(cause))}</span></h3>" + cause_error = f"<h4><span>{escape(cause.__class__.__name__)}</span></h4>" + exception_data.append( + f'<div class="cause-wrapper">{cause_header}{cause_error}{cause_error_description}{cause_data}</div>' + ) + cause = cause.__cause__ + + scripts = (tpl_dir / "scripts.js").read_text() + styles = (tpl_dir / "styles.css").read_text() + body_tpl = (tpl_dir / "body.html").read_text() + return body_tpl.format( + scripts=scripts, + styles=styles, + error=f"<span>{escape(exc.__class__.__name__)}</span> on {request.method} {escape(request.url.path)}", + error_description=escape(str(exc)), + exception_data="".join(exception_data), + ) + + +def create_plain_text_response_content(exc: Exception) -> str: + """Given an exception, produces its traceback in plain text. + + Args: + exc: An Exception instance to render debug response from. + + Returns: + A string containing exception traceback. + """ + return "".join(format_exception(type(exc), value=exc, tb=exc.__traceback__)) + + +def create_debug_response(request: Request, exc: Exception) -> Response: + """Create debug response either in plain text or HTML depending on client capabilities. + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + exc: An Exception instance to render debug response from. + + Returns: + A response with a rendered exception traceback. + """ + if MediaType.HTML in request.headers.get("accept", ""): + content: Any = create_html_response_content(exc=exc, request=request) + media_type = MediaType.HTML + elif MediaType.JSON in request.headers.get("accept", ""): + content = {"details": create_plain_text_response_content(exc), "status_code": HTTP_500_INTERNAL_SERVER_ERROR} + media_type = MediaType.JSON + else: + content = create_plain_text_response_content(exc) + media_type = MediaType.TEXT + + return Response( + content=content, + media_type=media_type, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + type_encoders=_get_type_encoders_for_request(request), + ) + + +def _get_type_encoders_for_request(request: Request) -> TypeEncodersMap | None: + try: + return request.route_handler.resolve_type_encoders() + # we might be in a 404, or before we could resolve the handler, so this + # could potentially error out. In this case we fall back on the application + # type encoders + except (KeyError, AttributeError): + return request.app.type_encoders diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/middleware.py b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/middleware.py new file mode 100644 index 0000000..f3ff157 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/middleware.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import pdb # noqa: T100 +from dataclasses import asdict, dataclass, field +from inspect import getmro +from sys import exc_info +from traceback import format_exception +from typing import TYPE_CHECKING, Any, Type, cast + +from litestar.datastructures import Headers +from litestar.enums import MediaType, ScopeType +from litestar.exceptions import HTTPException, LitestarException, WebSocketException +from litestar.middleware.cors import CORSMiddleware +from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response +from litestar.serialization import encode_json +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.utils.deprecation import warn_deprecation + +__all__ = ("ExceptionHandlerMiddleware", "ExceptionResponseContent", "create_exception_response") + + +if TYPE_CHECKING: + from starlette.exceptions import HTTPException as StarletteHTTPException + + from litestar import Response + from litestar.app import Litestar + from litestar.connection import Request + from litestar.logging import BaseLoggingConfig + from litestar.types import ( + ASGIApp, + ExceptionHandler, + ExceptionHandlersMap, + Logger, + Receive, + Scope, + Send, + ) + from litestar.types.asgi_types import WebSocketCloseEvent + + +def get_exception_handler(exception_handlers: ExceptionHandlersMap, exc: Exception) -> ExceptionHandler | None: + """Given a dictionary that maps exceptions and status codes to handler functions, and an exception, returns the + appropriate handler if existing. + + Status codes are given preference over exception type. + + If no status code match exists, each class in the MRO of the exception type is checked and + the first matching handler is returned. + + Finally, if a ``500`` handler is registered, it will be returned for any exception that isn't a + subclass of :class:`HTTPException <litestar.exceptions.HTTPException>`. + + Args: + exception_handlers: Mapping of status codes and exception types to handlers. + exc: Exception Instance to be resolved to a handler. + + Returns: + Optional exception handler callable. + """ + if not exception_handlers: + return None + + default_handler: ExceptionHandler | None = None + if isinstance(exc, HTTPException): + if exception_handler := exception_handlers.get(exc.status_code): + return exception_handler + else: + default_handler = exception_handlers.get(HTTP_500_INTERNAL_SERVER_ERROR) + + return next( + (exception_handlers[cast("Type[Exception]", cls)] for cls in getmro(type(exc)) if cls in exception_handlers), + default_handler, + ) + + +@dataclass +class ExceptionResponseContent: + """Represent the contents of an exception-response.""" + + status_code: int + """Exception status code.""" + detail: str + """Exception details or message.""" + media_type: MediaType | str + """Media type of the response.""" + headers: dict[str, str] | None = field(default=None) + """Headers to attach to the response.""" + extra: dict[str, Any] | list[Any] | None = field(default=None) + """An extra mapping to attach to the exception.""" + + def to_response(self, request: Request | None = None) -> Response: + """Create a response from the model attributes. + + Returns: + A response instance. + """ + from litestar.response import Response + + content: Any = {k: v for k, v in asdict(self).items() if k not in ("headers", "media_type") and v is not None} + + if self.media_type != MediaType.JSON: + content = encode_json(content) + + return Response( + content=content, + headers=self.headers, + status_code=self.status_code, + media_type=self.media_type, + type_encoders=_get_type_encoders_for_request(request) if request is not None else None, + ) + + +def _starlette_exception_handler(request: Request[Any, Any, Any], exc: StarletteHTTPException) -> Response: + return create_exception_response( + request=request, + exc=HTTPException( + detail=exc.detail, + status_code=exc.status_code, + headers=exc.headers, + ), + ) + + +def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -> Response: + """Construct a response from an exception. + + Notes: + - For instances of :class:`HTTPException <litestar.exceptions.HTTPException>` or other exception classes that have a + ``status_code`` attribute (e.g. Starlette exceptions), the status code is drawn from the exception, otherwise + response status is ``HTTP_500_INTERNAL_SERVER_ERROR``. + + Args: + request: The request that triggered the exception. + exc: An exception. + + Returns: + Response: HTTP response constructed from exception details. + """ + headers: dict[str, Any] | None + extra: dict[str, Any] | list | None + + if isinstance(exc, HTTPException): + status_code = exc.status_code + headers = exc.headers + extra = exc.extra + else: + status_code = HTTP_500_INTERNAL_SERVER_ERROR + headers = None + extra = None + + detail = ( + exc.detail + if isinstance(exc, LitestarException) and status_code != HTTP_500_INTERNAL_SERVER_ERROR + else "Internal Server Error" + ) + + try: + media_type = request.route_handler.media_type + except (KeyError, AttributeError): + media_type = MediaType.JSON + + content = ExceptionResponseContent( + status_code=status_code, + detail=detail, + headers=headers, + extra=extra, + media_type=media_type, + ) + return content.to_response(request=request) + + +class ExceptionHandlerMiddleware: + """Middleware used to wrap an ASGIApp inside a try catch block and handle any exceptions raised. + + This used in multiple layers of Litestar. + """ + + def __init__(self, app: ASGIApp, debug: bool | None, exception_handlers: ExceptionHandlersMap) -> None: + """Initialize ``ExceptionHandlerMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + debug: Whether ``debug`` mode is enabled. Deprecated. Debug mode will be inferred from the request scope + exception_handlers: A dictionary mapping status codes and/or exception types to handler functions. + + .. deprecated:: 2.0.0 + The ``debug`` parameter is deprecated. It will be inferred from the request scope + """ + self.app = app + self.exception_handlers = exception_handlers + self.debug = debug + if debug is not None: + warn_deprecation( + "2.0.0", + deprecated_name="debug", + kind="parameter", + info="Debug mode will be inferred from the request scope", + ) + + self._get_debug = self._get_debug_scope if debug is None else lambda *a: debug + + @staticmethod + def _get_debug_scope(scope: Scope) -> bool: + return scope["app"].debug + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI-callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + try: + await self.app(scope, receive, send) + except Exception as e: # noqa: BLE001 + litestar_app = scope["app"] + + if litestar_app.logging_config and (logger := litestar_app.logger): + self.handle_exception_logging(logger=logger, logging_config=litestar_app.logging_config, scope=scope) + + for hook in litestar_app.after_exception: + await hook(e, scope) + + if litestar_app.pdb_on_exception: + pdb.post_mortem() + + if scope["type"] == ScopeType.HTTP: + await self.handle_request_exception( + litestar_app=litestar_app, scope=scope, receive=receive, send=send, exc=e + ) + else: + await self.handle_websocket_exception(send=send, exc=e) + + async def handle_request_exception( + self, litestar_app: Litestar, scope: Scope, receive: Receive, send: Send, exc: Exception + ) -> None: + """Handle exception raised inside 'http' scope routes. + + Args: + litestar_app: The litestar app instance. + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + exc: The caught exception. + + Returns: + None. + """ + + headers = Headers.from_scope(scope=scope) + if litestar_app.cors_config and (origin := headers.get("origin")): + cors_middleware = CORSMiddleware(app=self.app, config=litestar_app.cors_config) + send = cors_middleware.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers) + + exception_handler = get_exception_handler(self.exception_handlers, exc) or self.default_http_exception_handler + request: Request[Any, Any, Any] = litestar_app.request_class(scope=scope, receive=receive, send=send) + response = exception_handler(request, exc) + await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send) + + @staticmethod + async def handle_websocket_exception(send: Send, exc: Exception) -> None: + """Handle exception raised inside 'websocket' scope routes. + + Args: + send: The ASGI send function. + exc: The caught exception. + + Returns: + None. + """ + code = 4000 + HTTP_500_INTERNAL_SERVER_ERROR + reason = "Internal Server Error" + if isinstance(exc, WebSocketException): + code = exc.code + reason = exc.detail + elif isinstance(exc, LitestarException): + reason = exc.detail + + event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason} + await send(event) + + def default_http_exception_handler(self, request: Request, exc: Exception) -> Response[Any]: + """Handle an HTTP exception by returning the appropriate response. + + Args: + request: An HTTP Request instance. + exc: The caught exception. + + Returns: + An HTTP response. + """ + status_code = exc.status_code if isinstance(exc, HTTPException) else HTTP_500_INTERNAL_SERVER_ERROR + if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self._get_debug_scope(request.scope): + return create_debug_response(request=request, exc=exc) + return create_exception_response(request=request, exc=exc) + + def handle_exception_logging(self, logger: Logger, logging_config: BaseLoggingConfig, scope: Scope) -> None: + """Handle logging - if the litestar app has a logging config in place. + + Args: + logger: A logger instance. + logging_config: Logging Config instance. + scope: The ASGI connection scope. + + Returns: + None + """ + if ( + logging_config.log_exceptions == "always" + or (logging_config.log_exceptions == "debug" and self._get_debug_scope(scope)) + ) and logging_config.exception_logging_handler: + logging_config.exception_logging_handler(logger, scope, format_exception(*exc_info())) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/body.html b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/body.html new file mode 100644 index 0000000..1c6705c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/body.html @@ -0,0 +1,20 @@ +<!doctype html> + +<html lang="en"> + <head> + <meta charset="utf-8" /> + <style type="text/css"> + {styles} + </style> + <title>Litestar exception page</title> + </head> + <body> + <h4>{error}</h4> + <h3><span>{error_description}</span></h3> + {exception_data} + <script type="text/javascript"> + // prettier-ignore + {scripts} // NOSONAR + </script> + </body> +</html> diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/frame.html b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/frame.html new file mode 100644 index 0000000..2ead8dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/frame.html @@ -0,0 +1,12 @@ +<div class="frame {frame_class}"> + <div class="frame-name"> + <span class="expander">â–¼</span> + <span class="breakable">{file}</span> in <span>{symbol_name}</span> at line + <span>{line}</span> + </div> + <div class="code-snippet-wrapper"> + <table role="presentation" class="code-snippet"> + {code} + </table> + </div> +</div> diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/scripts.js b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/scripts.js new file mode 100644 index 0000000..014a256 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/scripts.js @@ -0,0 +1,27 @@ +const expanders = document.querySelectorAll(".frame .expander"); + +for (const expander of expanders) { + expander.addEventListener("click", (evt) => { + const currentSnippet = evt.currentTarget.closest(".frame"); + const snippetWrapper = currentSnippet.querySelector( + ".code-snippet-wrapper", + ); + if (currentSnippet.classList.contains("collapsed")) { + snippetWrapper.style.height = `${snippetWrapper.scrollHeight}px`; + currentSnippet.classList.remove("collapsed"); + } else { + currentSnippet.classList.add("collapsed"); + snippetWrapper.style.height = "0px"; + } + }); +} + +// init height for non-collapsed code snippets so animation will be show +// their first collapse +const nonCollapsedSnippets = document.querySelectorAll( + ".frame:not(.collapsed) .code-snippet-wrapper", +); + +for (const snippet of nonCollapsedSnippets) { + snippet.style.height = `${snippet.scrollHeight}px`; +} diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/styles.css b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/styles.css new file mode 100644 index 0000000..6b98b89 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/styles.css @@ -0,0 +1,121 @@ +:root { + --code-background-color: #f5f5f5; + --code-background-color-dark: #b8b8b8; + --code-color: #1d2534; + --code-color-light: #546996; + --code-font-family: Consolas, monospace; + --header-color: #303b55; + --warn-color: hsl(356, 92%, 60%); + --text-font-family: -apple-system, BlinkMacSystemFont, Helvetica, Arial, + sans-serif; +} + +html { + font-size: 20px; +} + +body { + font-family: var(--text-font-family); + font-size: 0.8rem; +} + +h1, +h2, +h3, +h4 { + color: var(--header-color); +} + +h4 { + font-size: 1rem; +} + +h3 { + font-size: 1.35rem; +} + +h2 { + font-size: 1.83rem; +} + +h3 span, +h4 span { + color: var(--warn-color); +} + +.frame { + background-color: var(--code-background-color); + border-radius: 0.2rem; + margin-bottom: 20px; +} + +.frame-name { + border-bottom: 1px solid var(--code-color-light); + padding: 10px 16px; +} + +.frame.collapsed .frame-name { + border-bottom: none; +} + +.frame-name span { + font-weight: 700; +} + +span.expander { + display: inline-block; + margin-right: 10px; + cursor: pointer; + transition: transform 0.33s ease-in-out; +} + +.frame.collapsed span.expander { + transform: rotate(-90deg); +} + +.frame-name span.breakable { + word-break: break-all; +} + +.code-snippet-wrapper { + height: auto; + overflow-y: hidden; + transition: height 0.33s ease-in-out; +} + +.frame.collapsed .code-snippet-wrapper { + height: 0; +} + +.code-snippet { + margin: 10px 16px; + border-spacing: 0 0; + color: var(--code-color); + font-family: var(--code-font-family); + font-size: 0.68rem; +} + +.code-snippet td { + padding: 0; + text-align: left; +} + +td.line_no { + color: var(--code-color-light); + min-width: 4ch; + padding-right: 20px; + text-align: right; + user-select: none; +} + +td.code_line { + width: 99%; +} + +tr.executed-line { + background-color: var(--code-background-color-dark); +} + +.cause-wrapper { + margin-top: 50px; +} diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/logging.py b/venv/lib/python3.11/site-packages/litestar/middleware/logging.py new file mode 100644 index 0000000..0094f10 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/logging.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Iterable + +from litestar.constants import ( + HTTP_RESPONSE_BODY, + HTTP_RESPONSE_START, +) +from litestar.data_extractors import ( + ConnectionDataExtractor, + RequestExtractorField, + ResponseDataExtractor, + ResponseExtractorField, +) +from litestar.enums import ScopeType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import encode_json +from litestar.utils.empty import value_or_default +from litestar.utils.scope import get_serializer_from_scope +from litestar.utils.scope.state import ScopeState + +__all__ = ("LoggingMiddleware", "LoggingMiddlewareConfig") + + +if TYPE_CHECKING: + from litestar.connection import Request + from litestar.types import ( + ASGIApp, + Logger, + Message, + Receive, + Scope, + Send, + Serializer, + ) + +try: + from structlog.types import BindableLogger + + structlog_installed = True +except ImportError: + BindableLogger = object # type: ignore[assignment, misc] + structlog_installed = False + + +class LoggingMiddleware(AbstractMiddleware): + """Logging middleware.""" + + __slots__ = ("config", "logger", "request_extractor", "response_extractor", "is_struct_logger") + + logger: Logger + + def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None: + """Initialize ``LoggingMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of LoggingMiddlewareConfig. + """ + super().__init__( + app=app, scopes={ScopeType.HTTP}, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key + ) + self.is_struct_logger = structlog_installed + self.config = config + + self.request_extractor = ConnectionDataExtractor( + extract_body="body" in self.config.request_log_fields, + extract_client="client" in self.config.request_log_fields, + extract_content_type="content_type" in self.config.request_log_fields, + extract_cookies="cookies" in self.config.request_log_fields, + extract_headers="headers" in self.config.request_log_fields, + extract_method="method" in self.config.request_log_fields, + extract_path="path" in self.config.request_log_fields, + extract_path_params="path_params" in self.config.request_log_fields, + extract_query="query" in self.config.request_log_fields, + extract_scheme="scheme" in self.config.request_log_fields, + obfuscate_cookies=self.config.request_cookies_to_obfuscate, + obfuscate_headers=self.config.request_headers_to_obfuscate, + parse_body=self.is_struct_logger, + parse_query=self.is_struct_logger, + skip_parse_malformed_body=True, + ) + self.response_extractor = ResponseDataExtractor( + extract_body="body" in self.config.response_log_fields, + extract_headers="headers" in self.config.response_log_fields, + extract_status_code="status_code" in self.config.response_log_fields, + obfuscate_cookies=self.config.response_cookies_to_obfuscate, + obfuscate_headers=self.config.response_headers_to_obfuscate, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if not hasattr(self, "logger"): + self.logger = scope["app"].get_logger(self.config.logger_name) + self.is_struct_logger = structlog_installed and repr(self.logger).startswith("<BoundLoggerLazyProxy") + + if self.config.response_log_fields: + send = self.create_send_wrapper(scope=scope, send=send) + + if self.config.request_log_fields: + await self.log_request(scope=scope, receive=receive) + + await self.app(scope, receive, send) + + async def log_request(self, scope: Scope, receive: Receive) -> None: + """Extract request data and log the message. + + Args: + scope: The ASGI connection scope. + receive: ASGI receive callable + + Returns: + None + """ + extracted_data = await self.extract_request_data(request=scope["app"].request_class(scope, receive)) + self.log_message(values=extracted_data) + + def log_response(self, scope: Scope) -> None: + """Extract the response data and log the message. + + Args: + scope: The ASGI connection scope. + + Returns: + None + """ + extracted_data = self.extract_response_data(scope=scope) + self.log_message(values=extracted_data) + + def log_message(self, values: dict[str, Any]) -> None: + """Log a message. + + Args: + values: Extract values to log. + + Returns: + None + """ + message = values.pop("message") + if self.is_struct_logger: + self.logger.info(message, **values) + else: + value_strings = [f"{key}={value}" for key, value in values.items()] + log_message = f"{message}: {', '.join(value_strings)}" + self.logger.info(log_message) + + def _serialize_value(self, serializer: Serializer | None, value: Any) -> Any: + if not self.is_struct_logger and isinstance(value, (dict, list, tuple, set)): + value = encode_json(value, serializer) + return value.decode("utf-8", errors="backslashreplace") if isinstance(value, bytes) else value + + async def extract_request_data(self, request: Request) -> dict[str, Any]: + """Create a dictionary of values for the message. + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + + Returns: + An dict. + """ + + data: dict[str, Any] = {"message": self.config.request_log_message} + serializer = get_serializer_from_scope(request.scope) + + extracted_data = await self.request_extractor.extract(connection=request, fields=self.config.request_log_fields) + + for key in self.config.request_log_fields: + data[key] = self._serialize_value(serializer, extracted_data.get(key)) + return data + + def extract_response_data(self, scope: Scope) -> dict[str, Any]: + """Extract data from the response. + + Args: + scope: The ASGI connection scope. + + Returns: + An dict. + """ + data: dict[str, Any] = {"message": self.config.response_log_message} + serializer = get_serializer_from_scope(scope) + connection_state = ScopeState.from_scope(scope) + extracted_data = self.response_extractor( + messages=( + connection_state.log_context.pop(HTTP_RESPONSE_START), + connection_state.log_context.pop(HTTP_RESPONSE_BODY), + ), + ) + response_body_compressed = value_or_default(connection_state.response_compressed, False) + for key in self.config.response_log_fields: + value: Any + value = extracted_data.get(key) + if key == "body" and response_body_compressed: + if self.config.include_compressed_body: + data[key] = value + continue + data[key] = self._serialize_value(serializer, value) + return data + + def create_send_wrapper(self, scope: Scope, send: Send) -> Send: + """Create a ``send`` wrapper, which handles logging response data. + + Args: + scope: The ASGI connection scope. + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + connection_state = ScopeState.from_scope(scope) + + async def send_wrapper(message: Message) -> None: + if message["type"] == HTTP_RESPONSE_START: + connection_state.log_context[HTTP_RESPONSE_START] = message + elif message["type"] == HTTP_RESPONSE_BODY: + connection_state.log_context[HTTP_RESPONSE_BODY] = message + self.log_response(scope=scope) + await send(message) + + return send_wrapper + + +@dataclass +class LoggingMiddlewareConfig: + """Configuration for ``LoggingMiddleware``""" + + exclude: str | list[str] | None = field(default=None) + """List of paths to exclude from logging.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable logging for a particular route.""" + include_compressed_body: bool = field(default=False) + """Include body of compressed response in middleware. If `"body"` not set in. + :attr:`response_log_fields <LoggingMiddlewareConfig.response_log_fields>` this config value is ignored. + """ + logger_name: str = field(default="litestar") + """Name of the logger to retrieve using `app.get_logger("<name>")`.""" + request_cookies_to_obfuscate: set[str] = field(default_factory=lambda: {"session"}) + """Request cookie keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + request_headers_to_obfuscate: set[str] = field(default_factory=lambda: {"Authorization", "X-API-KEY"}) + """Request header keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + response_cookies_to_obfuscate: set[str] = field(default_factory=lambda: {"session"}) + """Response cookie keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + response_headers_to_obfuscate: set[str] = field(default_factory=lambda: {"Authorization", "X-API-KEY"}) + """Response header keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + request_log_message: str = field(default="HTTP Request") + """Log message to prepend when logging a request.""" + response_log_message: str = field(default="HTTP Response") + """Log message to prepend when logging a response.""" + request_log_fields: Iterable[RequestExtractorField] = field( + default=( + "path", + "method", + "content_type", + "headers", + "cookies", + "query", + "path_params", + "body", + ) + ) + """Fields to extract and log from the request. + + Notes: + - The order of fields in the iterable determines the order of the log message logged out. + Thus, re-arranging the log-message is as simple as changing the iterable. + - To turn off logging of requests, use and empty iterable. + """ + response_log_fields: Iterable[ResponseExtractorField] = field( + default=( + "status_code", + "cookies", + "headers", + "body", + ) + ) + """Fields to extract and log from the response. The order of fields in the iterable determines the order of the log + message logged out. + + Notes: + - The order of fields in the iterable determines the order of the log message logged out. + Thus, re-arranging the log-message is as simple as changing the iterable. + - To turn off logging of responses, use and empty iterable. + """ + middleware_class: type[LoggingMiddleware] = field(default=LoggingMiddleware) + """Middleware class to use. + + Should be a subclass of [litestar.middleware.LoggingMiddleware]. + """ + + def __post_init__(self) -> None: + """Override default Pydantic type conversion for iterables. + + Args: + value: An iterable + + Returns: + The `value` argument cast as a tuple. + """ + if not isinstance(self.response_log_fields, Iterable): + raise ImproperlyConfiguredException("response_log_fields must be a valid Iterable") + + if not isinstance(self.request_log_fields, Iterable): + raise ImproperlyConfiguredException("request_log_fields must be a valid Iterable") + + self.response_log_fields = tuple(self.response_log_fields) + self.request_log_fields = tuple(self.request_log_fields) + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from litestar import Litestar, Request, get + from litestar.logging import LoggingConfig + from litestar.middleware.logging import LoggingMiddlewareConfig + + logging_config = LoggingConfig() + + logging_middleware_config = LoggingMiddlewareConfig() + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar( + route_handlers=[my_handler], + logging_config=logging_config, + middleware=[logging_middleware_config.middleware], + ) + + Returns: + An instance of DefineMiddleware including ``self`` as the config kwarg value. + """ + return DefineMiddleware(self.middleware_class, config=self) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py new file mode 100644 index 0000000..cd767ba --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from time import time +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +from litestar.datastructures import MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import TooManyRequestsException +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import decode_json, encode_json +from litestar.utils import ensure_async_callable + +__all__ = ("CacheObject", "RateLimitConfig", "RateLimitMiddleware") + + +if TYPE_CHECKING: + from typing import Awaitable + + from litestar import Litestar + from litestar.connection import Request + from litestar.stores.base import Store + from litestar.types import ASGIApp, Message, Receive, Scope, Send, SyncOrAsyncUnion + + +DurationUnit = Literal["second", "minute", "hour", "day"] + +DURATION_VALUES: dict[DurationUnit, int] = {"second": 1, "minute": 60, "hour": 3600, "day": 86400} + + +@dataclass +class CacheObject: + """Representation of a cached object's metadata.""" + + __slots__ = ("history", "reset") + + history: list[int] + reset: int + + +class RateLimitMiddleware(AbstractMiddleware): + """Rate-limiting middleware.""" + + __slots__ = ("app", "check_throttle_handler", "max_requests", "unit", "request_quota", "config") + + def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None: + """Initialize ``RateLimitMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of RateLimitConfig. + """ + super().__init__( + app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP} + ) + self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler) + self.config = config + self.max_requests: int = config.rate_limit[1] + self.unit: DurationUnit = config.rate_limit[0] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + app = scope["app"] + request: Request[Any, Any, Any] = app.request_class(scope) + store = self.config.get_store_from_app(app) + if await self.should_check_request(request=request): + key = self.cache_key_from_request(request=request) + cache_object = await self.retrieve_cached_history(key, store) + if len(cache_object.history) >= self.max_requests: + raise TooManyRequestsException( + headers=self.create_response_headers(cache_object=cache_object) + if self.config.set_rate_limit_headers + else None + ) + await self.set_cached_history(key=key, cache_object=cache_object, store=store) + if self.config.set_rate_limit_headers: + send = self.create_send_wrapper(send=send, cache_object=cache_object) + + await self.app(scope, receive, send) # pyright: ignore + + def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send: + """Create a ``send`` function that wraps the original send to inject response headers. + + Args: + send: The ASGI send function. + cache_object: A StorageObject instance. + + Returns: + Send wrapper callable. + """ + + async def send_wrapper(message: Message) -> None: + """Wrap the ASGI ``Send`` callable. + + Args: + message: An ASGI ``Message`` + + Returns: + None + """ + if message["type"] == "http.response.start": + message.setdefault("headers", []) + headers = MutableScopeHeaders(message) + for key, value in self.create_response_headers(cache_object=cache_object).items(): + headers.add(key, value) + await send(message) + + return send_wrapper + + def cache_key_from_request(self, request: Request[Any, Any, Any]) -> str: + """Get a cache-key from a ``Request`` + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + A cache key. + """ + host = request.client.host if request.client else "anonymous" + identifier = request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP") or host + route_handler = request.scope["route_handler"] + if getattr(route_handler, "is_mount", False): + identifier += "::mount" + + if getattr(route_handler, "is_static", False): + identifier += "::static" + + return f"{type(self).__name__}::{identifier}" + + async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject: + """Retrieve a list of time stamps for the given duration unit. + + Args: + key: Cache key. + store: A :class:`Store <.stores.base.Store>` + + Returns: + An :class:`CacheObject`. + """ + duration = DURATION_VALUES[self.unit] + now = int(time()) + cached_string = await store.get(key) + if cached_string: + cache_object = CacheObject(**decode_json(value=cached_string)) + if cache_object.reset <= now: + return CacheObject(history=[], reset=now + duration) + + while cache_object.history and cache_object.history[-1] <= now - duration: + cache_object.history.pop() + return cache_object + + return CacheObject(history=[], reset=now + duration) + + async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None: + """Store history extended with the current timestamp in cache. + + Args: + key: Cache key. + cache_object: A :class:`CacheObject`. + store: A :class:`Store <.stores.base.Store>` + + Returns: + None + """ + cache_object.history = [int(time()), *cache_object.history] + await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit]) + + async def should_check_request(self, request: Request[Any, Any, Any]) -> bool: + """Return a boolean indicating if a request should be checked for rate limiting. + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + Boolean dictating whether the request should be checked for rate-limiting. + """ + if self.check_throttle_handler: + return await self.check_throttle_handler(request) + return True + + def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]: + """Create ratelimit response headers. + + Notes: + * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>_` + + Args: + cache_object:A :class:`CacheObject`. + + Returns: + A dict of http headers. + """ + remaining_requests = str( + len(cache_object.history) - self.max_requests if len(cache_object.history) <= self.max_requests else 0 + ) + + return { + self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}", + self.config.rate_limit_limit_header_key: str(self.max_requests), + self.config.rate_limit_remaining_header_key: remaining_requests, + self.config.rate_limit_reset_header_key: str(int(time()) - cache_object.reset), + } + + +@dataclass +class RateLimitConfig: + """Configuration for ``RateLimitMiddleware``""" + + rate_limit: tuple[DurationUnit, int] + """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5).""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the rate limiting middleware.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable rate limiting for a particular route.""" + check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None) + """Handler callable that receives the request instance, returning a boolean dictating whether or not the request + should be checked for rate limiting. + """ + middleware_class: type[RateLimitMiddleware] = field(default=RateLimitMiddleware) + """The middleware class to use.""" + set_rate_limit_headers: bool = field(default=True) + """Boolean dictating whether to set the rate limit headers on the response.""" + rate_limit_policy_header_key: str = field(default="RateLimit-Policy") + """Key to use for the rate limit policy header.""" + rate_limit_remaining_header_key: str = field(default="RateLimit-Remaining") + """Key to use for the rate limit remaining header.""" + rate_limit_reset_header_key: str = field(default="RateLimit-Reset") + """Key to use for the rate limit reset header.""" + rate_limit_limit_header_key: str = field(default="RateLimit-Limit") + """Key to use for the rate limit limit header.""" + store: str = "rate_limit" + """Name of the :class:`Store <.stores.base.Store>` to use""" + + def __post_init__(self) -> None: + if self.check_throttle_handler: + self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore[arg-type] + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from litestar import Litestar, Request, get + from litestar.middleware.rate_limit import RateLimitConfig + + # limit to 10 requests per minute, excluding the schema path + throttle_config = RateLimitConfig(rate_limit=("minute", 10), exclude=["/schema"]) + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware]) + + Returns: + An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the + config kwarg value. + """ + return DefineMiddleware(self.middleware_class, config=self) + + def get_store_from_app(self, app: Litestar) -> Store: + """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance.""" + return app.stores.get(self.store) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/response_cache.py b/venv/lib/python3.11/site-packages/litestar/middleware/response_cache.py new file mode 100644 index 0000000..62dcde6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/response_cache.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from msgspec.msgpack import encode as encode_msgpack + +from litestar import Request +from litestar.constants import HTTP_RESPONSE_BODY, HTTP_RESPONSE_START +from litestar.enums import ScopeType +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +from .base import AbstractMiddleware + +if TYPE_CHECKING: + from litestar.config.response_cache import ResponseCacheConfig + from litestar.handlers import HTTPRouteHandler + from litestar.types import ASGIApp, HTTPScope, Message, Receive, Scope, Send + +__all__ = ["ResponseCacheMiddleware"] + + +class ResponseCacheMiddleware(AbstractMiddleware): + def __init__(self, app: ASGIApp, config: ResponseCacheConfig) -> None: + self.config = config + super().__init__(app=app, scopes={ScopeType.HTTP}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + route_handler = cast("HTTPRouteHandler", scope["route_handler"]) + + expires_in: int | None = None + if route_handler.cache is True: + expires_in = self.config.default_expiration + elif route_handler.cache is not False and isinstance(route_handler.cache, int): + expires_in = route_handler.cache + + connection_state = ScopeState.from_scope(scope) + + messages: list[Message] = [] + + async def wrapped_send(message: Message) -> None: + if not value_or_default(connection_state.is_cached, False): + if message["type"] == HTTP_RESPONSE_START: + do_cache = connection_state.do_cache = self.config.cache_response_filter( + cast("HTTPScope", scope), message["status"] + ) + if do_cache: + messages.append(message) + elif value_or_default(connection_state.do_cache, False): + messages.append(message) + + if messages and message["type"] == HTTP_RESPONSE_BODY and not message["more_body"]: + key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope)) + store = self.config.get_store_from_app(scope["app"]) + await store.set(key, encode_msgpack(messages), expires_in=expires_in) + await send(message) + + await self.app(scope, receive, wrapped_send) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/__init__.py new file mode 100644 index 0000000..1ca9c17 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__init__.py @@ -0,0 +1,3 @@ +from .base import SessionMiddleware + +__all__ = ("SessionMiddleware",) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8748ce3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..68a8b9c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/client_side.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/client_side.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..692f54c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/client_side.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/server_side.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/server_side.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bd2373c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/server_side.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/base.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/base.py new file mode 100644 index 0000000..a823848 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/base.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Literal, + TypeVar, + cast, +) + +from litestar.connection import ASGIConnection +from litestar.enums import ScopeType +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import decode_json, encode_json +from litestar.utils import get_serializer_from_scope + +__all__ = ("BaseBackendConfig", "BaseSessionBackend", "SessionMiddleware") + + +if TYPE_CHECKING: + from litestar.types import ASGIApp, Message, Receive, Scope, Scopes, ScopeSession, Send + +ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +ConfigT = TypeVar("ConfigT", bound="BaseBackendConfig") +BaseSessionBackendT = TypeVar("BaseSessionBackendT", bound="BaseSessionBackend") + + +class BaseBackendConfig(ABC, Generic[BaseSessionBackendT]): # pyright: ignore + """Configuration for Session middleware backends.""" + + _backend_class: type[BaseSessionBackendT] # pyright: ignore + + key: str + """Key to use for the cookie inside the header, e.g. ``session=<data>`` where ``session`` is the cookie key and + ``<data>`` is the session data. + + Notes: + - If a session cookie exceeds 4KB in size it is split. In this case the key will be of the format + ``session-{segment number}``. + + """ + max_age: int + """Maximal age of the cookie before its invalidated.""" + scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET} + """Scopes for the middleware - options are ``http`` and ``websocket`` with the default being both""" + path: str + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``'/'``. + """ + domain: str | None + """Domain for which the cookie is valid.""" + secure: bool + """Https is required for the cookie.""" + httponly: bool + """Forbids javascript to access the cookie via 'Document.cookie'.""" + samesite: Literal["lax", "strict", "none"] + """Controls whether or not a cookie is sent with cross-site requests. + + Defaults to ``lax``. + """ + exclude: str | list[str] | None + """A pattern or list of patterns to skip in the session middleware.""" + exclude_opt_key: str + """An identifier to use on routes to disable the session middleware for a particular route.""" + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from os import urandom + + from litestar import Litestar, Request, get + from litestar.middleware.sessions.cookie_backend import CookieBackendConfig + + session_config = CookieBackendConfig(secret=urandom(16)) + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar(route_handlers=[my_handler], middleware=[session_config.middleware]) + + + Returns: + An instance of DefineMiddleware including ``self`` as the config kwarg value. + """ + return DefineMiddleware(SessionMiddleware, backend=self._backend_class(config=self)) + + +class BaseSessionBackend(ABC, Generic[ConfigT]): + """Abstract session backend defining the interface between a storage mechanism and the application + :class:`SessionMiddleware`. + + This serves as the base class for all client- and server-side backends + """ + + __slots__ = ("config",) + + def __init__(self, config: ConfigT) -> None: + """Initialize ``BaseSessionBackend`` + + Args: + config: A instance of a subclass of ``BaseBackendConfig`` + """ + self.config = config + + @staticmethod + def serialize_data(data: ScopeSession, scope: Scope | None = None) -> bytes: + """Serialize data into bytes for storage in the backend. + + Args: + data: Session data of the current scope. + scope: A scope, if applicable, from which to extract a serializer. + + Notes: + - The serializer will be extracted from ``scope`` or fall back to + :func:`default_serializer <.serialization.default_serializer>` + + Returns: + ``data`` serialized as bytes. + """ + serializer = get_serializer_from_scope(scope) if scope else None + return encode_json(data, serializer) + + @staticmethod + def deserialize_data(data: Any) -> dict[str, Any]: + """Deserialize data into a dictionary for use in the application scope. + + Args: + data: Data to be deserialized + + Returns: + Deserialized data as a dictionary + """ + return cast("dict[str, Any]", decode_json(value=data)) + + @abstractmethod + def get_session_id(self, connection: ASGIConnection) -> str | None: + """Try to fetch session id from connection ScopeState. If one does not exist, generate one. + + Args: + connection: Originating ASGIConnection containing the scope + + Returns: + Session id str or None if the concept of a session id does not apply. + """ + + @abstractmethod + async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: + """Store the necessary information in the outgoing ``Message`` + + Args: + scope_session: Current session to store + message: Outgoing send-message + connection: Originating ASGIConnection containing the scope + + Returns: + None + """ + + @abstractmethod + async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: + """Load session data from a connection and return it as a dictionary to be used in the current application + scope. + + Args: + connection: An ASGIConnection instance + + Returns: + The session data + + Notes: + - This should not modify the connection's scope. The data returned by this + method will be stored in the application scope by the middleware + + """ + + +class SessionMiddleware(AbstractMiddleware, Generic[BaseSessionBackendT]): + """Litestar session middleware for storing session data.""" + + def __init__(self, app: ASGIApp, backend: BaseSessionBackendT) -> None: + """Initialize ``SessionMiddleware`` + + Args: + app: An ASGI application + backend: A :class:`BaseSessionBackend` instance used to store and retrieve session data + """ + + super().__init__( + app=app, + exclude=backend.config.exclude, + exclude_opt_key=backend.config.exclude_opt_key, + scopes=backend.config.scopes, + ) + self.backend = backend + + def create_send_wrapper(self, connection: ASGIConnection) -> Callable[[Message], Awaitable[None]]: + """Create a wrapper for the ASGI send function, which handles setting the cookies on the outgoing response. + + Args: + connection: ASGIConnection + + Returns: + None + """ + + async def wrapped_send(message: Message) -> None: + """Wrap the ``send`` function. + + Declared in local scope to make use of closure values. + + Args: + message: An ASGI message. + + Returns: + None + """ + if message["type"] != "http.response.start": + await connection.send(message) + return + + scope_session = connection.scope.get("session") + + await self.backend.store_in_message(scope_session, message, connection) + await connection.send(message) + + return wrapped_send + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI-callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + connection = ASGIConnection[Any, Any, Any, Any](scope, receive=receive, send=send) + scope["session"] = await self.backend.load_from_connection(connection) + connection._connection_state.session_id = self.backend.get_session_id(connection) # pyright: ignore [reportGeneralTypeIssues] + + await self.app(scope, receive, self.create_send_wrapper(connection)) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/client_side.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/client_side.py new file mode 100644 index 0000000..f709410 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/client_side.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import binascii +import contextlib +import re +import time +from base64 import b64decode, b64encode +from dataclasses import dataclass, field +from os import urandom +from typing import TYPE_CHECKING, Any, Literal + +from litestar.datastructures import MutableScopeHeaders +from litestar.datastructures.cookie import Cookie +from litestar.enums import ScopeType +from litestar.exceptions import ( + ImproperlyConfiguredException, + MissingDependencyException, +) +from litestar.serialization import decode_json, encode_json +from litestar.types import Empty, Scopes +from litestar.utils.dataclass import extract_dataclass_items + +from .base import ONE_DAY_IN_SECONDS, BaseBackendConfig, BaseSessionBackend + +__all__ = ("ClientSideSessionBackend", "CookieBackendConfig") + + +try: + from cryptography.exceptions import InvalidTag + from cryptography.hazmat.primitives.ciphers.aead import AESGCM +except ImportError as e: + raise MissingDependencyException("cryptography") from e + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.types import Message, Scope, ScopeSession + +NONCE_SIZE = 12 +CHUNK_SIZE = 4096 - 64 +AAD = b"additional_authenticated_data=" + + +class ClientSideSessionBackend(BaseSessionBackend["CookieBackendConfig"]): + """Cookie backend for SessionMiddleware.""" + + __slots__ = ("aesgcm", "cookie_re") + + def __init__(self, config: CookieBackendConfig) -> None: + """Initialize ``ClientSideSessionBackend``. + + Args: + config: SessionCookieConfig instance. + """ + super().__init__(config) + self.aesgcm = AESGCM(config.secret) + self.cookie_re = re.compile(rf"{self.config.key}(?:-\d+)?") + + def dump_data(self, data: Any, scope: Scope | None = None) -> list[bytes]: + """Given serializable data, including pydantic models and numpy types, dump it into a bytes string, encrypt, + encode and split it into chunks of the desirable size. + + Args: + data: Data to serialize, encrypt, encode and chunk. + scope: The ASGI connection scope. + + Notes: + - The returned list is composed of a chunks of a single base64 encoded + string that is encrypted using AES-CGM. + + Returns: + List of encoded bytes string of a maximum length equal to the ``CHUNK_SIZE`` constant. + """ + serialized = self.serialize_data(data, scope) + associated_data = encode_json({"expires_at": round(time.time()) + self.config.max_age}) + nonce = urandom(NONCE_SIZE) + encrypted = self.aesgcm.encrypt(nonce, serialized, associated_data=associated_data) + encoded = b64encode(nonce + encrypted + AAD + associated_data) + return [encoded[i : i + CHUNK_SIZE] for i in range(0, len(encoded), CHUNK_SIZE)] + + def load_data(self, data: list[bytes]) -> dict[str, Any]: + """Given a list of strings, decodes them into the session object. + + Args: + data: A list of strings derived from the request's session cookie(s). + + Returns: + A deserialized session value. + """ + decoded = b64decode(b"".join(data)) + nonce = decoded[:NONCE_SIZE] + aad_starts_from = decoded.find(AAD) + associated_data = decoded[aad_starts_from:].replace(AAD, b"") if aad_starts_from != -1 else None + if associated_data and decode_json(value=associated_data)["expires_at"] > round(time.time()): + encrypted_session = decoded[NONCE_SIZE:aad_starts_from] + decrypted = self.aesgcm.decrypt(nonce, encrypted_session, associated_data=associated_data) + return self.deserialize_data(decrypted) + return {} + + def get_cookie_keys(self, connection: ASGIConnection) -> list[str]: + """Return a list of cookie-keys from the connection if they match the session-cookie pattern. + + Args: + connection: An ASGIConnection instance + + Returns: + A list of session-cookie keys + """ + return sorted(key for key in connection.cookies if self.cookie_re.fullmatch(key)) + + def _create_session_cookies(self, data: list[bytes], cookie_params: dict[str, Any] | None = None) -> list[Cookie]: + """Create a list of cookies containing the session data. + If the data is split into multiple cookies, the key will be of the format ``session-{segment number}``, + however if only one cookie is needed, the key will be ``session``. + """ + if cookie_params is None: + cookie_params = dict( + extract_dataclass_items( + self.config, + exclude_none=True, + include={f for f in Cookie.__dict__ if f not in ("key", "secret")}, + ) + ) + + if len(data) == 1: + return [ + Cookie( + value=data[0].decode("utf-8"), + key=self.config.key, + **cookie_params, + ) + ] + + return [ + Cookie( + value=datum.decode("utf-8"), + key=f"{self.config.key}-{i}", + **cookie_params, + ) + for i, datum in enumerate(data) + ] + + async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: + """Store data from ``scope_session`` in ``Message`` in the form of cookies. If the contents of ``scope_session`` + are too large to fit a single cookie, it will be split across several cookies, following the naming scheme of + ``<cookie key>-<n>``. If the session is empty or shrinks, cookies will be cleared by setting their value to + ``"null"`` + + Args: + scope_session: Current session to store + message: Outgoing send-message + connection: Originating ASGIConnection containing the scope + + Returns: + None + """ + + scope = connection.scope + headers = MutableScopeHeaders.from_message(message) + cookie_keys = self.get_cookie_keys(connection) + + if scope_session and scope_session is not Empty: + data = self.dump_data(scope_session, scope=scope) + cookie_params = dict( + extract_dataclass_items( + self.config, + exclude_none=True, + include={f for f in Cookie.__dict__ if f not in ("key", "secret")}, + ) + ) + for cookie in self._create_session_cookies(data, cookie_params): + headers.add("Set-Cookie", cookie.to_header(header="")) + # Cookies with the same key overwrite the earlier cookie with that key. To expire earlier session + # cookies, first check how many session cookies will not be overwritten in this upcoming response. + # If leftover cookies are greater than or equal to 1, that means older session cookies have to be + # expired and their names are in cookie_keys. + cookies_to_clear = cookie_keys[len(data) :] if len(cookie_keys) - len(data) > 0 else [] + else: + cookies_to_clear = cookie_keys + + for cookie_key in cookies_to_clear: + cookie_params = dict( + extract_dataclass_items( + self.config, + exclude_none=True, + include={f for f in Cookie.__dict__ if f not in ("key", "secret", "max_age")}, + ) + ) + headers.add( + "Set-Cookie", + Cookie(value="null", key=cookie_key, expires=0, **cookie_params).to_header(header=""), + ) + + async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: + """Load session data from a connection's session-cookies and return it as a dictionary. + + Args: + connection: Originating ASGIConnection + + Returns: + The session data + """ + if cookie_keys := self.get_cookie_keys(connection): + data = [connection.cookies[key].encode("utf-8") for key in cookie_keys] + # If these exceptions occur, the session must remain empty so do nothing. + with contextlib.suppress(InvalidTag, binascii.Error): + return self.load_data(data) + return {} + + def get_session_id(self, connection: ASGIConnection) -> str | None: + return None + + +@dataclass +class CookieBackendConfig(BaseBackendConfig[ClientSideSessionBackend]): # pyright: ignore + """Configuration for [SessionMiddleware] middleware.""" + + _backend_class = ClientSideSessionBackend + + secret: bytes + """A secret key to use for generating an encryption key. + + Must have a length of 16 (128 bits), 24 (192 bits) or 32 (256 bits) characters. + """ + key: str = field(default="session") + """Key to use for the cookie inside the header, e.g. ``session=<data>`` where ``session`` is the cookie key and + ``<data>`` is the session data. + + Notes: + - If a session cookie exceeds 4KB in size it is split. In this case the key will be of the format + ``session-{segment number}``. + + """ + max_age: int = field(default=ONE_DAY_IN_SECONDS * 14) + """Maximal age of the cookie before its invalidated.""" + scopes: Scopes = field(default_factory=lambda: {ScopeType.HTTP, ScopeType.WEBSOCKET}) + """Scopes for the middleware - options are ``http`` and ``websocket`` with the default being both""" + path: str = field(default="/") + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``'/'``. + """ + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool = field(default=False) + """Https is required for the cookie.""" + httponly: bool = field(default=True) + """Forbids javascript to access the cookie via 'Document.cookie'.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. + + Defaults to ``lax``. + """ + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the session middleware.""" + exclude_opt_key: str = field(default="skip_session") + """An identifier to use on routes to disable the session middleware for a particular route.""" + + def __post_init__(self) -> None: + if len(self.key) < 1 or len(self.key) > 256: + raise ImproperlyConfiguredException("key must be a string with a length between 1-256") + if self.max_age < 1: + raise ImproperlyConfiguredException("max_age must be greater than 0") + if len(self.secret) not in {16, 24, 32}: + raise ImproperlyConfiguredException("secret length must be 16 (128 bit), 24 (192 bit) or 32 (256 bit)") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/server_side.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/server_side.py new file mode 100644 index 0000000..91708ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/server_side.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import secrets +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from litestar.datastructures import Cookie, MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.middleware.session.base import ONE_DAY_IN_SECONDS, BaseBackendConfig, BaseSessionBackend +from litestar.types import Empty, Message, Scopes, ScopeSession +from litestar.utils.dataclass import extract_dataclass_items + +__all__ = ("ServerSideSessionBackend", "ServerSideSessionConfig") + + +if TYPE_CHECKING: + from litestar import Litestar + from litestar.connection import ASGIConnection + from litestar.stores.base import Store + + +class ServerSideSessionBackend(BaseSessionBackend["ServerSideSessionConfig"]): + """Base class for server-side backends. + + Implements :class:`BaseSessionBackend` and defines and interface which subclasses can + implement to facilitate the storage of session data. + """ + + def __init__(self, config: ServerSideSessionConfig) -> None: + """Initialize ``ServerSideSessionBackend`` + + Args: + config: A subclass of ``ServerSideSessionConfig`` + """ + super().__init__(config=config) + + async def get(self, session_id: str, store: Store) -> bytes | None: + """Retrieve data associated with ``session_id``. + + Args: + session_id: The session-ID + store: Store to retrieve the session data from + + Returns: + The session data, if existing, otherwise ``None``. + """ + max_age = int(self.config.max_age) if self.config.max_age is not None else None + return await store.get(session_id, renew_for=max_age if self.config.renew_on_access else None) + + async def set(self, session_id: str, data: bytes, store: Store) -> None: + """Store ``data`` under the ``session_id`` for later retrieval. + + If there is already data associated with ``session_id``, replace + it with ``data`` and reset its expiry time + + Args: + session_id: The session-ID + data: Serialized session data + store: Store to save the session data in + + Returns: + None + """ + expires_in = int(self.config.max_age) if self.config.max_age is not None else None + await store.set(session_id, data, expires_in=expires_in) + + async def delete(self, session_id: str, store: Store) -> None: + """Delete the data associated with ``session_id``. Fails silently if no such session-ID exists. + + Args: + session_id: The session-ID + store: Store to delete the session data from + + Returns: + None + """ + await store.delete(session_id) + + def get_session_id(self, connection: ASGIConnection) -> str: + """Try to fetch session id from the connection. If one does not exist, generate one. + + If a session ID already exists in the cookies, it is returned. + If there is no ID in the cookies but one in the connection state, then the session exists but has not yet + been returned to the user. + Otherwise, a new session must be created. + + Args: + connection: Originating ASGIConnection containing the scope + Returns: + Session id str or None if the concept of a session id does not apply. + """ + session_id = connection.cookies.get(self.config.key) + if not session_id or session_id == "null": + session_id = connection.get_session_id() + if not session_id: + session_id = self.generate_session_id() + return session_id + + def generate_session_id(self) -> str: + """Generate a new session-ID, with + n=:attr:`session_id_bytes <ServerSideSessionConfig.session_id_bytes>` random bytes. + + Returns: + A session-ID + """ + return secrets.token_hex(self.config.session_id_bytes) + + async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: + """Store the necessary information in the outgoing ``Message`` by setting a cookie containing the session-ID. + + If the session is empty, a null-cookie will be set. Otherwise, the serialised + data will be stored using :meth:`set <ServerSideSessionBackend.set>`, under the current session-id. If no session-ID + exists, a new ID will be generated using :meth:`generate_session_id <ServerSideSessionBackend.generate_session_id>`. + + Args: + scope_session: Current session to store + message: Outgoing send-message + connection: Originating ASGIConnection containing the scope + + Returns: + None + """ + scope = connection.scope + store = self.config.get_store_from_app(scope["app"]) + headers = MutableScopeHeaders.from_message(message) + session_id = self.get_session_id(connection) + + cookie_params = dict(extract_dataclass_items(self.config, exclude_none=True, include=Cookie.__dict__.keys())) + + if scope_session is Empty: + await self.delete(session_id, store=store) + headers.add( + "Set-Cookie", + Cookie(value="null", key=self.config.key, expires=0, **cookie_params).to_header(header=""), + ) + else: + serialised_data = self.serialize_data(scope_session, scope) + await self.set(session_id=session_id, data=serialised_data, store=store) + headers.add( + "Set-Cookie", Cookie(value=session_id, key=self.config.key, **cookie_params).to_header(header="") + ) + + async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: + """Load session data from a connection and return it as a dictionary to be used in the current application + scope. + + The session-ID will be gathered from a cookie with the key set in + :attr:`BaseBackendConfig.key`. If a cookie is found, its value will be used as the session-ID and data associated + with this ID will be loaded using :meth:`get <ServerSideSessionBackend.get>`. + If no cookie was found or no data was loaded from the store, this will return an + empty dictionary. + + Args: + connection: An ASGIConnection instance + + Returns: + The current session data + """ + if session_id := connection.cookies.get(self.config.key): + store = self.config.get_store_from_app(connection.scope["app"]) + data = await self.get(session_id, store=store) + if data is not None: + return self.deserialize_data(data) + return {} + + +@dataclass +class ServerSideSessionConfig(BaseBackendConfig[ServerSideSessionBackend]): # pyright: ignore + """Base configuration for server side backends.""" + + _backend_class = ServerSideSessionBackend + + session_id_bytes: int = field(default=32) + """Number of bytes used to generate a random session-ID.""" + renew_on_access: bool = field(default=False) + """Renew expiry times of sessions when they're being accessed""" + key: str = field(default="session") + """Key to use for the cookie inside the header, e.g. ``session=<data>`` where ``session`` is the cookie key and + ``<data>`` is the session data. + + Notes: + - If a session cookie exceeds 4KB in size it is split. In this case the key will be of the format + ``session-{segment number}``. + + """ + max_age: int = field(default=ONE_DAY_IN_SECONDS * 14) + """Maximal age of the cookie before its invalidated.""" + scopes: Scopes = field(default_factory=lambda: {ScopeType.HTTP, ScopeType.WEBSOCKET}) + """Scopes for the middleware - options are ``http`` and ``websocket`` with the default being both""" + path: str = field(default="/") + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``'/'``. + """ + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool = field(default=False) + """Https is required for the cookie.""" + httponly: bool = field(default=True) + """Forbids javascript to access the cookie via 'Document.cookie'.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. Defaults to ``lax``.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the session middleware.""" + exclude_opt_key: str = field(default="skip_session") + """An identifier to use on routes to disable the session middleware for a particular route.""" + store: str = "sessions" + """Name of the :class:`Store <.stores.base.Store>` to use""" + + def __post_init__(self) -> None: + if len(self.key) < 1 or len(self.key) > 256: + raise ImproperlyConfiguredException("key must be a string with a length between 1-256") + if self.max_age < 1: + raise ImproperlyConfiguredException("max_age must be greater than 0") + + def get_store_from_app(self, app: Litestar) -> Store: + """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance""" + return app.stores.get(self.store) |