author     cyfraeviolae <cyfraeviolae>  2024-04-03 03:10:44 -0400
committer  cyfraeviolae <cyfraeviolae>  2024-04-03 03:10:44 -0400
commit     6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
tree       b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py
parent     4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py')
-rw-r--r--  venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py  275
1 file changed, 275 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py
new file mode 100644
index 0000000..cd767ba
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from time import time
+from typing import TYPE_CHECKING, Any, Callable, Literal, cast
+
+from litestar.datastructures import MutableScopeHeaders
+from litestar.enums import ScopeType
+from litestar.exceptions import TooManyRequestsException
+from litestar.middleware.base import AbstractMiddleware, DefineMiddleware
+from litestar.serialization import decode_json, encode_json
+from litestar.utils import ensure_async_callable
+
+__all__ = ("CacheObject", "RateLimitConfig", "RateLimitMiddleware")
+
+
+if TYPE_CHECKING:
+ from typing import Awaitable
+
+ from litestar import Litestar
+ from litestar.connection import Request
+ from litestar.stores.base import Store
+ from litestar.types import ASGIApp, Message, Receive, Scope, Send, SyncOrAsyncUnion
+
+
+DurationUnit = Literal["second", "minute", "hour", "day"]
+
+DURATION_VALUES: dict[DurationUnit, int] = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}
+
+
+@dataclass
+class CacheObject:
+ """Representation of a cached object's metadata."""
+
+ __slots__ = ("history", "reset")
+
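+    # ``history`` holds the timestamps (epoch seconds, newest first) of recent requests;
+    # ``reset`` is the epoch second at which the current window expires.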
+ history: list[int]
+ reset: int
+
+
+class RateLimitMiddleware(AbstractMiddleware):
+ """Rate-limiting middleware."""
+
+ __slots__ = ("app", "check_throttle_handler", "max_requests", "unit", "request_quota", "config")
+
+ def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
+ """Initialize ``RateLimitMiddleware``.
+
+ Args:
+ app: The ``next`` ASGI app to call.
+ config: An instance of RateLimitConfig.
+ """
+ super().__init__(
+ app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP}
+ )
+ self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler)
+ self.config = config
+ self.max_requests: int = config.rate_limit[1]
+ self.unit: DurationUnit = config.rate_limit[0]
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ """ASGI callable.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+
+ Returns:
+ None
+ """
+ app = scope["app"]
+ request: Request[Any, Any, Any] = app.request_class(scope)
+ store = self.config.get_store_from_app(app)
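+        # When the request is subject to throttling: reject it if the window is already
+        # full, otherwise record it and, if configured, attach rate-limit headers to the
+        # response via a wrapped ``send``.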
+ if await self.should_check_request(request=request):
+ key = self.cache_key_from_request(request=request)
+ cache_object = await self.retrieve_cached_history(key, store)
+ if len(cache_object.history) >= self.max_requests:
+ raise TooManyRequestsException(
+ headers=self.create_response_headers(cache_object=cache_object)
+ if self.config.set_rate_limit_headers
+ else None
+ )
+ await self.set_cached_history(key=key, cache_object=cache_object, store=store)
+ if self.config.set_rate_limit_headers:
+ send = self.create_send_wrapper(send=send, cache_object=cache_object)
+
+ await self.app(scope, receive, send) # pyright: ignore
+
+ def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send:
+ """Create a ``send`` function that wraps the original send to inject response headers.
+
+ Args:
+ send: The ASGI send function.
+            cache_object: A :class:`CacheObject` instance.
+
+ Returns:
+ Send wrapper callable.
+ """
+
+ async def send_wrapper(message: Message) -> None:
+ """Wrap the ASGI ``Send`` callable.
+
+ Args:
+ message: An ASGI ``Message``
+
+ Returns:
+ None
+ """
+ if message["type"] == "http.response.start":
+ message.setdefault("headers", [])
+ headers = MutableScopeHeaders(message)
+ for key, value in self.create_response_headers(cache_object=cache_object).items():
+ headers.add(key, value)
+ await send(message)
+
+ return send_wrapper
+
+ def cache_key_from_request(self, request: Request[Any, Any, Any]) -> str:
+ """Get a cache-key from a ``Request``
+
+ Args:
+ request: A :class:`Request <.connection.Request>` instance.
+
+ Returns:
+ A cache key.
+ """
+ host = request.client.host if request.client else "anonymous"
+ identifier = request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP") or host
+ route_handler = request.scope["route_handler"]
+ if getattr(route_handler, "is_mount", False):
+ identifier += "::mount"
+
+ if getattr(route_handler, "is_static", False):
+ identifier += "::static"
+
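+        # For example, a request from the (hypothetical) client 203.0.113.7 to a mounted
+        # handler yields the key "RateLimitMiddleware::203.0.113.7::mount".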
+ return f"{type(self).__name__}::{identifier}"
+
+ async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
+        """Retrieve the cached request history for the given duration unit.
+
+ Args:
+ key: Cache key.
+ store: A :class:`Store <.stores.base.Store>`
+
+ Returns:
+            A :class:`CacheObject`.
+ """
+ duration = DURATION_VALUES[self.unit]
+ now = int(time())
+ cached_string = await store.get(key)
+ if cached_string:
+ cache_object = CacheObject(**decode_json(value=cached_string))
+ if cache_object.reset <= now:
+ return CacheObject(history=[], reset=now + duration)
+
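+            # Drop timestamps that have aged out of the sliding window; history is kept
+            # newest-first, so stale entries sit at the tail.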
+ while cache_object.history and cache_object.history[-1] <= now - duration:
+ cache_object.history.pop()
+ return cache_object
+
+ return CacheObject(history=[], reset=now + duration)
+
+ async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None:
+        """Store the request history, extended with the current timestamp, in the cache.
+
+ Args:
+ key: Cache key.
+ cache_object: A :class:`CacheObject`.
+ store: A :class:`Store <.stores.base.Store>`
+
+ Returns:
+ None
+ """
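+        # Prepend the current timestamp (history is kept newest-first) and refresh the
+        # store entry so it expires after one full window.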
+ cache_object.history = [int(time()), *cache_object.history]
+ await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit])
+
+ async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
+ """Return a boolean indicating if a request should be checked for rate limiting.
+
+ Args:
+ request: A :class:`Request <.connection.Request>` instance.
+
+ Returns:
+ Boolean dictating whether the request should be checked for rate-limiting.
+ """
+ if self.check_throttle_handler:
+ return await self.check_throttle_handler(request)
+ return True
+
+ def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]:
+ """Create ratelimit response headers.
+
+ Notes:
+            * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>`_
+
+ Args:
+            cache_object: A :class:`CacheObject`.
+
+ Returns:
+ A dict of http headers.
+ """
+        remaining_requests = str(
+            self.max_requests - len(cache_object.history) if len(cache_object.history) <= self.max_requests else 0
+        )
+
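+        # Example (hypothetical values): for a limit of 10 requests per minute with 3
+        # requests already recorded, the headers would be roughly:
+        #   RateLimit-Policy: 10; w=60
+        #   RateLimit-Limit: 10
+        #   RateLimit-Remaining: 7
+        #   RateLimit-Reset: <seconds remaining until the window resets>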
+ return {
+ self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}",
+ self.config.rate_limit_limit_header_key: str(self.max_requests),
+ self.config.rate_limit_remaining_header_key: remaining_requests,
+            self.config.rate_limit_reset_header_key: str(cache_object.reset - int(time())),
+ }
+
+
+@dataclass
+class RateLimitConfig:
+ """Configuration for ``RateLimitMiddleware``"""
+
+ rate_limit: tuple[DurationUnit, int]
+ """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5)."""
+ exclude: str | list[str] | None = field(default=None)
+ """A pattern or list of patterns to skip in the rate limiting middleware."""
+ exclude_opt_key: str | None = field(default=None)
+ """An identifier to use on routes to disable rate limiting for a particular route."""
+ check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None)
+    """Handler callable that receives the request instance and returns a boolean dictating whether the request
+    should be checked for rate limiting (see the commented sketch after these fields).
+    """
+ middleware_class: type[RateLimitMiddleware] = field(default=RateLimitMiddleware)
+ """The middleware class to use."""
+ set_rate_limit_headers: bool = field(default=True)
+ """Boolean dictating whether to set the rate limit headers on the response."""
+ rate_limit_policy_header_key: str = field(default="RateLimit-Policy")
+ """Key to use for the rate limit policy header."""
+ rate_limit_remaining_header_key: str = field(default="RateLimit-Remaining")
+ """Key to use for the rate limit remaining header."""
+ rate_limit_reset_header_key: str = field(default="RateLimit-Reset")
+ """Key to use for the rate limit reset header."""
+ rate_limit_limit_header_key: str = field(default="RateLimit-Limit")
+ """Key to use for the rate limit limit header."""
+ store: str = "rate_limit"
+ """Name of the :class:`Store <.stores.base.Store>` to use"""
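+
+    # A minimal sketch (hypothetical handler name): a ``check_throttle_handler`` that
+    # exempts authenticated requests from rate limiting could look like
+    #
+    #     def check_unauthenticated(request: Request) -> bool:
+    #         return "Authorization" not in request.headers
+    #
+    # and would be passed as ``RateLimitConfig(..., check_throttle_handler=check_unauthenticated)``.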
+
+ def __post_init__(self) -> None:
+ if self.check_throttle_handler:
+ self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore[arg-type]
+
+ @property
+ def middleware(self) -> DefineMiddleware:
+ """Use this property to insert the config into a middleware list on one of the application layers.
+
+ Examples:
+ .. code-block:: python
+
+ from litestar import Litestar, Request, get
+ from litestar.middleware.rate_limit import RateLimitConfig
+
+ # limit to 10 requests per minute, excluding the schema path
+ throttle_config = RateLimitConfig(rate_limit=("minute", 10), exclude=["/schema"])
+
+
+ @get("/")
+ def my_handler(request: Request) -> None: ...
+
+
+ app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware])
+
+ Returns:
+ An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the
+ config kwarg value.
+ """
+ return DefineMiddleware(self.middleware_class, config=self)
+
+ def get_store_from_app(self, app: Litestar) -> Store:
+ """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance."""
+ return app.stores.get(self.store)