From 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:10:44 -0400 Subject: venv --- .../litestar/middleware/rate_limit.py | 275 +++++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py (limited to 'venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py') diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py new file mode 100644 index 0000000..cd767ba --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from time import time +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +from litestar.datastructures import MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import TooManyRequestsException +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import decode_json, encode_json +from litestar.utils import ensure_async_callable + +__all__ = ("CacheObject", "RateLimitConfig", "RateLimitMiddleware") + + +if TYPE_CHECKING: + from typing import Awaitable + + from litestar import Litestar + from litestar.connection import Request + from litestar.stores.base import Store + from litestar.types import ASGIApp, Message, Receive, Scope, Send, SyncOrAsyncUnion + + +DurationUnit = Literal["second", "minute", "hour", "day"] + +DURATION_VALUES: dict[DurationUnit, int] = {"second": 1, "minute": 60, "hour": 3600, "day": 86400} + + +@dataclass +class CacheObject: + """Representation of a cached object's metadata.""" + + __slots__ = ("history", "reset") + + history: list[int] + reset: int + + +class RateLimitMiddleware(AbstractMiddleware): + """Rate-limiting middleware.""" + + __slots__ = ("app", "check_throttle_handler", "max_requests", "unit", "request_quota", "config") + + def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None: + """Initialize ``RateLimitMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of RateLimitConfig. + """ + super().__init__( + app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP} + ) + self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler) + self.config = config + self.max_requests: int = config.rate_limit[1] + self.unit: DurationUnit = config.rate_limit[0] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + app = scope["app"] + request: Request[Any, Any, Any] = app.request_class(scope) + store = self.config.get_store_from_app(app) + if await self.should_check_request(request=request): + key = self.cache_key_from_request(request=request) + cache_object = await self.retrieve_cached_history(key, store) + if len(cache_object.history) >= self.max_requests: + raise TooManyRequestsException( + headers=self.create_response_headers(cache_object=cache_object) + if self.config.set_rate_limit_headers + else None + ) + await self.set_cached_history(key=key, cache_object=cache_object, store=store) + if self.config.set_rate_limit_headers: + send = self.create_send_wrapper(send=send, cache_object=cache_object) + + await self.app(scope, receive, send) # pyright: ignore + + def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send: + """Create a ``send`` function that wraps the original send to inject response headers. + + Args: + send: The ASGI send function. + cache_object: A StorageObject instance. + + Returns: + Send wrapper callable. + """ + + async def send_wrapper(message: Message) -> None: + """Wrap the ASGI ``Send`` callable. + + Args: + message: An ASGI ``Message`` + + Returns: + None + """ + if message["type"] == "http.response.start": + message.setdefault("headers", []) + headers = MutableScopeHeaders(message) + for key, value in self.create_response_headers(cache_object=cache_object).items(): + headers.add(key, value) + await send(message) + + return send_wrapper + + def cache_key_from_request(self, request: Request[Any, Any, Any]) -> str: + """Get a cache-key from a ``Request`` + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + A cache key. + """ + host = request.client.host if request.client else "anonymous" + identifier = request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP") or host + route_handler = request.scope["route_handler"] + if getattr(route_handler, "is_mount", False): + identifier += "::mount" + + if getattr(route_handler, "is_static", False): + identifier += "::static" + + return f"{type(self).__name__}::{identifier}" + + async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject: + """Retrieve a list of time stamps for the given duration unit. + + Args: + key: Cache key. + store: A :class:`Store <.stores.base.Store>` + + Returns: + An :class:`CacheObject`. + """ + duration = DURATION_VALUES[self.unit] + now = int(time()) + cached_string = await store.get(key) + if cached_string: + cache_object = CacheObject(**decode_json(value=cached_string)) + if cache_object.reset <= now: + return CacheObject(history=[], reset=now + duration) + + while cache_object.history and cache_object.history[-1] <= now - duration: + cache_object.history.pop() + return cache_object + + return CacheObject(history=[], reset=now + duration) + + async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None: + """Store history extended with the current timestamp in cache. + + Args: + key: Cache key. + cache_object: A :class:`CacheObject`. + store: A :class:`Store <.stores.base.Store>` + + Returns: + None + """ + cache_object.history = [int(time()), *cache_object.history] + await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit]) + + async def should_check_request(self, request: Request[Any, Any, Any]) -> bool: + """Return a boolean indicating if a request should be checked for rate limiting. + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + Boolean dictating whether the request should be checked for rate-limiting. + """ + if self.check_throttle_handler: + return await self.check_throttle_handler(request) + return True + + def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]: + """Create ratelimit response headers. + + Notes: + * see the `IETF RateLimit draft _` + + Args: + cache_object:A :class:`CacheObject`. + + Returns: + A dict of http headers. + """ + remaining_requests = str( + len(cache_object.history) - self.max_requests if len(cache_object.history) <= self.max_requests else 0 + ) + + return { + self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}", + self.config.rate_limit_limit_header_key: str(self.max_requests), + self.config.rate_limit_remaining_header_key: remaining_requests, + self.config.rate_limit_reset_header_key: str(int(time()) - cache_object.reset), + } + + +@dataclass +class RateLimitConfig: + """Configuration for ``RateLimitMiddleware``""" + + rate_limit: tuple[DurationUnit, int] + """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5).""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the rate limiting middleware.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable rate limiting for a particular route.""" + check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None) + """Handler callable that receives the request instance, returning a boolean dictating whether or not the request + should be checked for rate limiting. + """ + middleware_class: type[RateLimitMiddleware] = field(default=RateLimitMiddleware) + """The middleware class to use.""" + set_rate_limit_headers: bool = field(default=True) + """Boolean dictating whether to set the rate limit headers on the response.""" + rate_limit_policy_header_key: str = field(default="RateLimit-Policy") + """Key to use for the rate limit policy header.""" + rate_limit_remaining_header_key: str = field(default="RateLimit-Remaining") + """Key to use for the rate limit remaining header.""" + rate_limit_reset_header_key: str = field(default="RateLimit-Reset") + """Key to use for the rate limit reset header.""" + rate_limit_limit_header_key: str = field(default="RateLimit-Limit") + """Key to use for the rate limit limit header.""" + store: str = "rate_limit" + """Name of the :class:`Store <.stores.base.Store>` to use""" + + def __post_init__(self) -> None: + if self.check_throttle_handler: + self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore[arg-type] + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from litestar import Litestar, Request, get + from litestar.middleware.rate_limit import RateLimitConfig + + # limit to 10 requests per minute, excluding the schema path + throttle_config = RateLimitConfig(rate_limit=("minute", 10), exclude=["/schema"]) + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware]) + + Returns: + An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the + config kwarg value. + """ + return DefineMiddleware(self.middleware_class, config=self) + + def get_store_from_app(self, app: Litestar) -> Store: + """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance.""" + return app.stores.get(self.store) -- cgit v1.2.3