summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py275
1 files changed, 0 insertions, 275 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py
deleted file mode 100644
index cd767ba..0000000
--- a/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py
+++ /dev/null
@@ -1,275 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-from time import time
-from typing import TYPE_CHECKING, Any, Callable, Literal, cast
-
-from litestar.datastructures import MutableScopeHeaders
-from litestar.enums import ScopeType
-from litestar.exceptions import TooManyRequestsException
-from litestar.middleware.base import AbstractMiddleware, DefineMiddleware
-from litestar.serialization import decode_json, encode_json
-from litestar.utils import ensure_async_callable
-
-__all__ = ("CacheObject", "RateLimitConfig", "RateLimitMiddleware")
-
-
-if TYPE_CHECKING:
- from typing import Awaitable
-
- from litestar import Litestar
- from litestar.connection import Request
- from litestar.stores.base import Store
- from litestar.types import ASGIApp, Message, Receive, Scope, Send, SyncOrAsyncUnion
-
-
-DurationUnit = Literal["second", "minute", "hour", "day"]
-
-DURATION_VALUES: dict[DurationUnit, int] = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}
-
-
-@dataclass
-class CacheObject:
- """Representation of a cached object's metadata."""
-
- __slots__ = ("history", "reset")
-
- history: list[int]
- reset: int
-
-
-class RateLimitMiddleware(AbstractMiddleware):
- """Rate-limiting middleware."""
-
- __slots__ = ("app", "check_throttle_handler", "max_requests", "unit", "request_quota", "config")
-
- def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
- """Initialize ``RateLimitMiddleware``.
-
- Args:
- app: The ``next`` ASGI app to call.
- config: An instance of RateLimitConfig.
- """
- super().__init__(
- app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP}
- )
- self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler)
- self.config = config
- self.max_requests: int = config.rate_limit[1]
- self.unit: DurationUnit = config.rate_limit[0]
-
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- """ASGI callable.
-
- Args:
- scope: The ASGI connection scope.
- receive: The ASGI receive function.
- send: The ASGI send function.
-
- Returns:
- None
- """
- app = scope["app"]
- request: Request[Any, Any, Any] = app.request_class(scope)
- store = self.config.get_store_from_app(app)
- if await self.should_check_request(request=request):
- key = self.cache_key_from_request(request=request)
- cache_object = await self.retrieve_cached_history(key, store)
- if len(cache_object.history) >= self.max_requests:
- raise TooManyRequestsException(
- headers=self.create_response_headers(cache_object=cache_object)
- if self.config.set_rate_limit_headers
- else None
- )
- await self.set_cached_history(key=key, cache_object=cache_object, store=store)
- if self.config.set_rate_limit_headers:
- send = self.create_send_wrapper(send=send, cache_object=cache_object)
-
- await self.app(scope, receive, send) # pyright: ignore
-
- def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send:
- """Create a ``send`` function that wraps the original send to inject response headers.
-
- Args:
- send: The ASGI send function.
- cache_object: A StorageObject instance.
-
- Returns:
- Send wrapper callable.
- """
-
- async def send_wrapper(message: Message) -> None:
- """Wrap the ASGI ``Send`` callable.
-
- Args:
- message: An ASGI ``Message``
-
- Returns:
- None
- """
- if message["type"] == "http.response.start":
- message.setdefault("headers", [])
- headers = MutableScopeHeaders(message)
- for key, value in self.create_response_headers(cache_object=cache_object).items():
- headers.add(key, value)
- await send(message)
-
- return send_wrapper
-
- def cache_key_from_request(self, request: Request[Any, Any, Any]) -> str:
- """Get a cache-key from a ``Request``
-
- Args:
- request: A :class:`Request <.connection.Request>` instance.
-
- Returns:
- A cache key.
- """
- host = request.client.host if request.client else "anonymous"
- identifier = request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP") or host
- route_handler = request.scope["route_handler"]
- if getattr(route_handler, "is_mount", False):
- identifier += "::mount"
-
- if getattr(route_handler, "is_static", False):
- identifier += "::static"
-
- return f"{type(self).__name__}::{identifier}"
-
- async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
- """Retrieve a list of time stamps for the given duration unit.
-
- Args:
- key: Cache key.
- store: A :class:`Store <.stores.base.Store>`
-
- Returns:
- An :class:`CacheObject`.
- """
- duration = DURATION_VALUES[self.unit]
- now = int(time())
- cached_string = await store.get(key)
- if cached_string:
- cache_object = CacheObject(**decode_json(value=cached_string))
- if cache_object.reset <= now:
- return CacheObject(history=[], reset=now + duration)
-
- while cache_object.history and cache_object.history[-1] <= now - duration:
- cache_object.history.pop()
- return cache_object
-
- return CacheObject(history=[], reset=now + duration)
-
- async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None:
- """Store history extended with the current timestamp in cache.
-
- Args:
- key: Cache key.
- cache_object: A :class:`CacheObject`.
- store: A :class:`Store <.stores.base.Store>`
-
- Returns:
- None
- """
- cache_object.history = [int(time()), *cache_object.history]
- await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit])
-
- async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
- """Return a boolean indicating if a request should be checked for rate limiting.
-
- Args:
- request: A :class:`Request <.connection.Request>` instance.
-
- Returns:
- Boolean dictating whether the request should be checked for rate-limiting.
- """
- if self.check_throttle_handler:
- return await self.check_throttle_handler(request)
- return True
-
- def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]:
- """Create ratelimit response headers.
-
- Notes:
- * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>_`
-
- Args:
- cache_object:A :class:`CacheObject`.
-
- Returns:
- A dict of http headers.
- """
- remaining_requests = str(
- len(cache_object.history) - self.max_requests if len(cache_object.history) <= self.max_requests else 0
- )
-
- return {
- self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}",
- self.config.rate_limit_limit_header_key: str(self.max_requests),
- self.config.rate_limit_remaining_header_key: remaining_requests,
- self.config.rate_limit_reset_header_key: str(int(time()) - cache_object.reset),
- }
-
-
-@dataclass
-class RateLimitConfig:
- """Configuration for ``RateLimitMiddleware``"""
-
- rate_limit: tuple[DurationUnit, int]
- """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5)."""
- exclude: str | list[str] | None = field(default=None)
- """A pattern or list of patterns to skip in the rate limiting middleware."""
- exclude_opt_key: str | None = field(default=None)
- """An identifier to use on routes to disable rate limiting for a particular route."""
- check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None)
- """Handler callable that receives the request instance, returning a boolean dictating whether or not the request
- should be checked for rate limiting.
- """
- middleware_class: type[RateLimitMiddleware] = field(default=RateLimitMiddleware)
- """The middleware class to use."""
- set_rate_limit_headers: bool = field(default=True)
- """Boolean dictating whether to set the rate limit headers on the response."""
- rate_limit_policy_header_key: str = field(default="RateLimit-Policy")
- """Key to use for the rate limit policy header."""
- rate_limit_remaining_header_key: str = field(default="RateLimit-Remaining")
- """Key to use for the rate limit remaining header."""
- rate_limit_reset_header_key: str = field(default="RateLimit-Reset")
- """Key to use for the rate limit reset header."""
- rate_limit_limit_header_key: str = field(default="RateLimit-Limit")
- """Key to use for the rate limit limit header."""
- store: str = "rate_limit"
- """Name of the :class:`Store <.stores.base.Store>` to use"""
-
- def __post_init__(self) -> None:
- if self.check_throttle_handler:
- self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore[arg-type]
-
- @property
- def middleware(self) -> DefineMiddleware:
- """Use this property to insert the config into a middleware list on one of the application layers.
-
- Examples:
- .. code-block:: python
-
- from litestar import Litestar, Request, get
- from litestar.middleware.rate_limit import RateLimitConfig
-
- # limit to 10 requests per minute, excluding the schema path
- throttle_config = RateLimitConfig(rate_limit=("minute", 10), exclude=["/schema"])
-
-
- @get("/")
- def my_handler(request: Request) -> None: ...
-
-
- app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware])
-
- Returns:
- An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the
- config kwarg value.
- """
- return DefineMiddleware(self.middleware_class, config=self)
-
- def get_store_from_app(self, app: Litestar) -> Store:
- """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance."""
- return app.stores.get(self.store)