diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar')
678 files changed, 44681 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/__init__.py b/venv/lib/python3.11/site-packages/litestar/__init__.py new file mode 100644 index 0000000..3235113 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__init__.py @@ -0,0 +1,33 @@ +from litestar.app import Litestar +from litestar.connection import Request, WebSocket +from litestar.controller import Controller +from litestar.enums import HttpMethod, MediaType +from litestar.handlers import asgi, delete, get, head, patch, post, put, route, websocket, websocket_listener +from litestar.response import Response +from litestar.router import Router +from litestar.utils.version import get_version + +__version__ = get_version() + + +__all__ = ( + "Controller", + "HttpMethod", + "Litestar", + "MediaType", + "Request", + "Response", + "Router", + "WebSocket", + "__version__", + "asgi", + "delete", + "get", + "head", + "patch", + "post", + "put", + "route", + "websocket", + "websocket_listener", +) diff --git a/venv/lib/python3.11/site-packages/litestar/__main__.py b/venv/lib/python3.11/site-packages/litestar/__main__.py new file mode 100644 index 0000000..92f2d86 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__main__.py @@ -0,0 +1,10 @@ +from litestar.cli.main import litestar_group + + +def run_cli() -> None: + """Application Entrypoint.""" + litestar_group() + + +if __name__ == "__main__": + run_cli() diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5a8d532 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/__main__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/__main__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..874c61f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/__main__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/_multipart.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/_multipart.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..62b625c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/_multipart.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/_parsers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/_parsers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0ea2061 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/_parsers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/app.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/app.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..833a1a8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/app.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/background_tasks.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/background_tasks.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..750190c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/background_tasks.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/concurrency.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/concurrency.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e3f82cf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/concurrency.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/constants.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/constants.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a95a8b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/constants.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/controller.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/controller.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4c7556c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/controller.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/data_extractors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/data_extractors.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..379a991 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/data_extractors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/di.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/di.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..224cd6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/di.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/enums.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/enums.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..de5d781 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/enums.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/file_system.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/file_system.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..fef4f7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/file_system.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/pagination.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/pagination.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..99feb13 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/pagination.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/params.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/params.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6fb5e43 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/params.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/router.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/router.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..26d45d8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/router.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/status_codes.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/status_codes.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b18ebee --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/status_codes.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/__pycache__/typing.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/__pycache__/typing.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..98e4dd1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/__pycache__/typing.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/__init__.py b/venv/lib/python3.11/site-packages/litestar/_asgi/__init__.py new file mode 100644 index 0000000..4cb42ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/__init__.py @@ -0,0 +1,3 @@ +from litestar._asgi.asgi_router import ASGIRouter + +__all__ = ("ASGIRouter",) diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9b12b52 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/asgi_router.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/asgi_router.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..75c2cdb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/asgi_router.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2e9f267 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py b/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py new file mode 100644 index 0000000..ebafaf0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import re +from collections import defaultdict +from functools import lru_cache +from traceback import format_exc +from typing import TYPE_CHECKING, Any, Pattern + +from litestar._asgi.routing_trie import validate_node +from litestar._asgi.routing_trie.mapping import add_route_to_trie +from litestar._asgi.routing_trie.traversal import parse_path_to_route +from litestar._asgi.routing_trie.types import create_node +from litestar._asgi.utils import get_route_handlers +from litestar.exceptions import ImproperlyConfiguredException +from litestar.utils import normalize_path + +__all__ = ("ASGIRouter",) + + +if TYPE_CHECKING: + from litestar._asgi.routing_trie.types import RouteTrieNode + from litestar.app import Litestar + from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from litestar.routes.base import BaseRoute + from litestar.types import ( + ASGIApp, + LifeSpanReceive, + LifeSpanSend, + LifeSpanShutdownCompleteEvent, + LifeSpanShutdownFailedEvent, + LifeSpanStartupCompleteEvent, + LifeSpanStartupFailedEvent, + Method, + Receive, + RouteHandlerType, + Scope, + Send, + ) + + +class ASGIRouter: + """Litestar ASGI router. + + Handling both the ASGI lifespan events and routing of connection requests. + """ + + __slots__ = ( + "_mount_paths_regex", + "_mount_routes", + "_plain_routes", + "_registered_routes", + "_static_routes", + "app", + "root_route_map_node", + "route_handler_index", + "route_mapping", + ) + + def __init__(self, app: Litestar) -> None: + """Initialize ``ASGIRouter``. + + Args: + app: The Litestar app instance + """ + self._mount_paths_regex: Pattern | None = None + self._mount_routes: dict[str, RouteTrieNode] = {} + self._plain_routes: set[str] = set() + self._registered_routes: set[HTTPRoute | WebSocketRoute | ASGIRoute] = set() + self.app = app + self.root_route_map_node: RouteTrieNode = create_node() + self.route_handler_index: dict[str, RouteHandlerType] = {} + self.route_mapping: dict[str, list[BaseRoute]] = defaultdict(list) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + The main entry point to the Router class. + """ + scope.setdefault("path_params", {}) + + path = scope["path"] + if root_path := scope.get("root_path", ""): + path = path.split(root_path, maxsplit=1)[-1] + normalized_path = normalize_path(path) + + asgi_app, scope["route_handler"], scope["path"], scope["path_params"] = self.handle_routing( + path=normalized_path, method=scope.get("method") + ) + await asgi_app(scope, receive, send) + + @lru_cache(1024) # noqa: B019 + def handle_routing(self, path: str, method: Method | None) -> tuple[ASGIApp, RouteHandlerType, str, dict[str, Any]]: + """Handle routing for a given path / method combo. This method is meant to allow easy caching. + + Args: + path: The path of the request. + method: The scope's method, if any. + + Returns: + A tuple composed of the ASGIApp of the route, the route handler instance, the resolved and normalized path and any parsed path params. + """ + return parse_path_to_route( + mount_paths_regex=self._mount_paths_regex, + mount_routes=self._mount_routes, + path=path, + plain_routes=self._plain_routes, + root_node=self.root_route_map_node, + method=method, + ) + + def _store_handler_to_route_mapping(self, route: BaseRoute) -> None: + """Store the mapping of route handlers to routes and to route handler names. + + Args: + route: A Route instance. + + Returns: + None + """ + + for handler in get_route_handlers(route): + if handler.name in self.route_handler_index and str(self.route_handler_index[handler.name]) != str(handler): + raise ImproperlyConfiguredException( + f"route handler names must be unique - {handler.name} is not unique." + ) + identifier = handler.name or str(handler) + self.route_mapping[identifier].append(route) + self.route_handler_index[identifier] = handler + + def construct_routing_trie(self) -> None: + """Create a map of the app's routes. + + This map is used in the asgi router to route requests. + """ + new_routes = [route for route in self.app.routes if route not in self._registered_routes] + for route in new_routes: + add_route_to_trie( + app=self.app, + mount_routes=self._mount_routes, + plain_routes=self._plain_routes, + root_node=self.root_route_map_node, + route=route, + ) + self._store_handler_to_route_mapping(route) + self._registered_routes.add(route) + + validate_node(node=self.root_route_map_node) + if self._mount_routes: + self._mount_paths_regex = re.compile("|".join(sorted(set(self._mount_routes)))) # pyright: ignore + + async def lifespan(self, receive: LifeSpanReceive, send: LifeSpanSend) -> None: + """Handle the ASGI "lifespan" event on application startup and shutdown. + + Args: + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None. + """ + + message = await receive() + shutdown_event: LifeSpanShutdownCompleteEvent = {"type": "lifespan.shutdown.complete"} + startup_event: LifeSpanStartupCompleteEvent = {"type": "lifespan.startup.complete"} + + try: + async with self.app.lifespan(): + await send(startup_event) + message = await receive() + + except BaseException as e: + formatted_exception = format_exc() + failure_message: LifeSpanStartupFailedEvent | LifeSpanShutdownFailedEvent + + if message["type"] == "lifespan.startup": + failure_message = {"type": "lifespan.startup.failed", "message": formatted_exception} + else: + failure_message = {"type": "lifespan.shutdown.failed", "message": formatted_exception} + + await send(failure_message) + + raise e + + await send(shutdown_event) diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__init__.py b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__init__.py new file mode 100644 index 0000000..948e394 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__init__.py @@ -0,0 +1,6 @@ +from litestar._asgi.routing_trie.mapping import add_route_to_trie +from litestar._asgi.routing_trie.traversal import parse_path_to_route +from litestar._asgi.routing_trie.types import RouteTrieNode +from litestar._asgi.routing_trie.validate import validate_node + +__all__ = ("RouteTrieNode", "add_route_to_trie", "parse_path_to_route", "validate_node") diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f8658f9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/mapping.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/mapping.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..73323bf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/mapping.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/traversal.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/traversal.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9ed2983 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/traversal.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5247c5b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/validate.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/validate.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..57fdcdd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/__pycache__/validate.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/mapping.py b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/mapping.py new file mode 100644 index 0000000..7a56b97 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/mapping.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast + +from litestar._asgi.routing_trie.types import ( + ASGIHandlerTuple, + PathParameterSentinel, + create_node, +) +from litestar._asgi.utils import wrap_in_exception_handler +from litestar.types.internal_types import PathParameterDefinition + +__all__ = ("add_mount_route", "add_route_to_trie", "build_route_middleware_stack", "configure_node") + + +if TYPE_CHECKING: + from litestar._asgi.routing_trie.types import RouteTrieNode + from litestar.app import Litestar + from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from litestar.types import ASGIApp, RouteHandlerType + + +def add_mount_route( + current_node: RouteTrieNode, + mount_routes: dict[str, RouteTrieNode], + root_node: RouteTrieNode, + route: ASGIRoute, +) -> RouteTrieNode: + """Add a node for a mount route. + + Args: + current_node: The current trie node that is being mapped. + mount_routes: A dictionary mapping static routes to trie nodes. + root_node: The root trie node. + route: The route that is being added. + + Returns: + A trie node. + """ + + # we need to ensure that we can traverse the map both through the full path key, e.g. "/my-route/sub-path" and + # via the components keys ["my-route, "sub-path"] + if route.path not in current_node.children: + root_node = current_node + for component in route.path_components: + if component not in current_node.children: + current_node.children[component] = create_node() # type: ignore[index] + current_node = current_node.children[component] # type: ignore[index] + + current_node.is_mount = True + current_node.is_static = route.route_handler.is_static + + if route.path != "/": + mount_routes[route.path] = root_node.children[route.path] = current_node + else: + mount_routes[route.path] = current_node + + return current_node + + +def add_route_to_trie( + app: Litestar, + mount_routes: dict[str, RouteTrieNode], + plain_routes: set[str], + root_node: RouteTrieNode, + route: HTTPRoute | WebSocketRoute | ASGIRoute, +) -> RouteTrieNode: + """Add a new route path (e.g. '/foo/bar/{param:int}') into the route_map tree. + + Inserts non-parameter paths ('plain routes') off the tree's root + node. For paths containing parameters, splits the path on '/' and + nests each path segment under the previous segment's node (see + prefix tree / trie). + + Args: + app: The Litestar app instance. + mount_routes: A dictionary mapping static routes to trie nodes. + plain_routes: A set of routes that do not have path parameters. + root_node: The root trie node. + route: The route that is being added. + + Returns: + A RouteTrieNode instance. + """ + current_node = root_node + + has_path_parameters = bool(route.path_parameters) + + if (route_handler := getattr(route, "route_handler", None)) and getattr(route_handler, "is_mount", False): + current_node = add_mount_route( + current_node=current_node, + mount_routes=mount_routes, + root_node=root_node, + route=cast("ASGIRoute", route), + ) + + elif not has_path_parameters: + plain_routes.add(route.path) + if route.path not in root_node.children: + current_node.children[route.path] = create_node() + current_node = root_node.children[route.path] + + else: + for component in route.path_components: + if isinstance(component, PathParameterDefinition): + current_node.is_path_param_node = True + next_node_key: type[PathParameterSentinel] | str = PathParameterSentinel + + else: + next_node_key = component + + if next_node_key not in current_node.children: + current_node.children[next_node_key] = create_node() + + current_node.child_keys = set(current_node.children.keys()) + current_node = current_node.children[next_node_key] + + if isinstance(component, PathParameterDefinition) and component.type is Path: + current_node.is_path_type = True + + configure_node(route=route, app=app, node=current_node) + return current_node + + +def configure_node( + app: Litestar, + route: HTTPRoute | WebSocketRoute | ASGIRoute, + node: RouteTrieNode, +) -> None: + """Set required attributes and route handlers on route_map tree node. + + Args: + app: The Litestar app instance. + route: The route that is being added. + node: The trie node being configured. + + Returns: + None + """ + from litestar.routes import HTTPRoute, WebSocketRoute + + if not node.path_parameters: + node.path_parameters = {} + + if isinstance(route, HTTPRoute): + for method, handler_mapping in route.route_handler_map.items(): + handler, _ = handler_mapping + node.asgi_handlers[method] = ASGIHandlerTuple( + asgi_app=build_route_middleware_stack(app=app, route=route, route_handler=handler), + handler=handler, + ) + node.path_parameters[method] = route.path_parameters + + elif isinstance(route, WebSocketRoute): + node.asgi_handlers["websocket"] = ASGIHandlerTuple( + asgi_app=build_route_middleware_stack(app=app, route=route, route_handler=route.route_handler), + handler=route.route_handler, + ) + node.path_parameters["websocket"] = route.path_parameters + + else: + node.asgi_handlers["asgi"] = ASGIHandlerTuple( + asgi_app=build_route_middleware_stack(app=app, route=route, route_handler=route.route_handler), + handler=route.route_handler, + ) + node.path_parameters["asgi"] = route.path_parameters + node.is_asgi = True + + +def build_route_middleware_stack( + app: Litestar, + route: HTTPRoute | WebSocketRoute | ASGIRoute, + route_handler: RouteHandlerType, +) -> ASGIApp: + """Construct a middleware stack that serves as the point of entry for each route. + + Args: + app: The Litestar app instance. + route: The route that is being added. + route_handler: The route handler that is being wrapped. + + Returns: + An ASGIApp that is composed of a "stack" of middlewares. + """ + from litestar.middleware.allowed_hosts import AllowedHostsMiddleware + from litestar.middleware.compression import CompressionMiddleware + from litestar.middleware.csrf import CSRFMiddleware + from litestar.middleware.response_cache import ResponseCacheMiddleware + from litestar.routes import HTTPRoute + + # we wrap the route.handle method in the ExceptionHandlerMiddleware + asgi_handler = wrap_in_exception_handler( + app=route.handle, # type: ignore[arg-type] + exception_handlers=route_handler.resolve_exception_handlers(), + ) + + if app.csrf_config: + asgi_handler = CSRFMiddleware(app=asgi_handler, config=app.csrf_config) + + if app.compression_config: + asgi_handler = CompressionMiddleware(app=asgi_handler, config=app.compression_config) + + if isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers): + asgi_handler = ResponseCacheMiddleware(app=asgi_handler, config=app.response_cache_config) + + if app.allowed_hosts: + asgi_handler = AllowedHostsMiddleware(app=asgi_handler, config=app.allowed_hosts) + + for middleware in route_handler.resolve_middleware(): + if hasattr(middleware, "__iter__"): + handler, kwargs = cast("tuple[Any, dict[str, Any]]", middleware) + asgi_handler = handler(app=asgi_handler, **kwargs) + else: + asgi_handler = middleware(app=asgi_handler) # type: ignore[call-arg] + + # we wrap the entire stack again in ExceptionHandlerMiddleware + return wrap_in_exception_handler( + app=cast("ASGIApp", asgi_handler), + exception_handlers=route_handler.resolve_exception_handlers(), + ) # pyright: ignore diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/traversal.py b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/traversal.py new file mode 100644 index 0000000..b7788bd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/traversal.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Pattern + +from litestar._asgi.routing_trie.types import PathParameterSentinel +from litestar.exceptions import MethodNotAllowedException, NotFoundException +from litestar.utils import normalize_path + +__all__ = ("parse_node_handlers", "parse_path_params", "parse_path_to_route", "traverse_route_map") + + +if TYPE_CHECKING: + from litestar._asgi.routing_trie.types import ASGIHandlerTuple, RouteTrieNode + from litestar.types import ASGIApp, Method, RouteHandlerType + from litestar.types.internal_types import PathParameterDefinition + + +def traverse_route_map( + root_node: RouteTrieNode, + path: str, +) -> tuple[RouteTrieNode, list[str], str]: + """Traverses the application route mapping and retrieves the correct node for the request url. + + Args: + root_node: The root trie node. + path: The request's path. + + Raises: + NotFoundException: If no correlating node is found. + + Returns: + A tuple containing the target RouteMapNode and a list containing all path parameter values. + """ + current_node = root_node + path_params: list[str] = [] + path_components = [p for p in path.split("/") if p] + + for i, component in enumerate(path_components): + if component in current_node.child_keys: + current_node = current_node.children[component] + continue + + if current_node.is_path_param_node: + current_node = current_node.children[PathParameterSentinel] + + if current_node.is_path_type: + path_params.append(normalize_path("/".join(path_components[i:]))) + break + + path_params.append(component) + continue + + raise NotFoundException() + + if not current_node.asgi_handlers: + raise NotFoundException() + + return current_node, path_params, path + + +def parse_node_handlers( + node: RouteTrieNode, + method: Method | None, +) -> ASGIHandlerTuple: + """Retrieve the handler tuple from the node. + + Args: + node: The trie node to parse. + method: The scope's method. + + Raises: + KeyError: If no matching method is found. + + Returns: + An ASGI Handler tuple. + """ + + if node.is_asgi: + return node.asgi_handlers["asgi"] + if method: + return node.asgi_handlers[method] + return node.asgi_handlers["websocket"] + + +@lru_cache(1024) +def parse_path_params( + parameter_definitions: tuple[PathParameterDefinition, ...], path_param_values: tuple[str, ...] +) -> dict[str, Any]: + """Parse path parameters into a dictionary of values. + + Args: + parameter_definitions: The parameter definitions tuple from the route. + path_param_values: The string values extracted from the url + + Raises: + ValueError: If any of path parameters can not be parsed into a value. + + Returns: + A dictionary of parsed path parameters. + """ + return { + param_definition.name: param_definition.parser(value) if param_definition.parser else value + for param_definition, value in zip(parameter_definitions, path_param_values) + } + + +def parse_path_to_route( + method: Method | None, + mount_paths_regex: Pattern | None, + mount_routes: dict[str, RouteTrieNode], + path: str, + plain_routes: set[str], + root_node: RouteTrieNode, +) -> tuple[ASGIApp, RouteHandlerType, str, dict[str, Any]]: + """Given a scope object, retrieve the asgi_handlers and is_mount boolean values from correct trie node. + + Args: + method: The scope's method, if any. + root_node: The root trie node. + path: The path to resolve scope instance. + plain_routes: The set of plain routes. + mount_routes: Mapping of mount routes to trie nodes. + mount_paths_regex: A compiled regex to match the mount routes. + + Raises: + MethodNotAllowedException: if no matching method is found. + NotFoundException: If no correlating node is found or if path params can not be parsed into values according to the node definition. + + Returns: + A tuple containing the stack of middlewares and the route handler that is wrapped by it. + """ + + try: + if path in plain_routes: + asgi_app, handler = parse_node_handlers(node=root_node.children[path], method=method) + return asgi_app, handler, path, {} + + if mount_paths_regex and (match := mount_paths_regex.search(path)): + mount_path = path[match.start() : match.end()] + mount_node = mount_routes[mount_path] + remaining_path = path[match.end() :] + # since we allow regular handlers under static paths, we must validate that the request does not match + # any such handler. + children = [sub_route for sub_route in mount_node.children or [] if sub_route != mount_path] + if not children or all(sub_route not in path for sub_route in children): # type: ignore[operator] + asgi_app, handler = parse_node_handlers(node=mount_node, method=method) + remaining_path = remaining_path or "/" + if not mount_node.is_static: + remaining_path = remaining_path if remaining_path.endswith("/") else f"{remaining_path}/" + return asgi_app, handler, remaining_path, {} + + node, path_parameters, path = traverse_route_map( + root_node=root_node, + path=path, + ) + asgi_app, handler = parse_node_handlers(node=node, method=method) + key = method or ("asgi" if node.is_asgi else "websocket") + parsed_path_parameters = parse_path_params(node.path_parameters[key], tuple(path_parameters)) + + return ( + asgi_app, + handler, + path, + parsed_path_parameters, + ) + except KeyError as e: + raise MethodNotAllowedException() from e + except ValueError as e: + raise NotFoundException() from e diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/types.py b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/types.py new file mode 100644 index 0000000..d1fc368 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/types.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, NamedTuple + +__all__ = ("ASGIHandlerTuple", "PathParameterSentinel", "RouteTrieNode", "create_node") + + +if TYPE_CHECKING: + from litestar.types import ASGIApp, Method, RouteHandlerType + from litestar.types.internal_types import PathParameterDefinition + + +class PathParameterSentinel: + """Sentinel class designating a path parameter.""" + + +class ASGIHandlerTuple(NamedTuple): + """Encapsulation of a route handler node.""" + + asgi_app: ASGIApp + """An ASGI stack, composed of a handler function and layers of middleware that wrap it.""" + handler: RouteHandlerType + """The route handler instance.""" + + +@dataclass(unsafe_hash=True) +class RouteTrieNode: + """A radix trie node.""" + + __slots__ = ( + "asgi_handlers", + "child_keys", + "children", + "is_asgi", + "is_mount", + "is_static", + "is_path_param_node", + "is_path_type", + "path_parameters", + ) + + asgi_handlers: dict[Method | Literal["websocket", "asgi"], ASGIHandlerTuple] + """A mapping of ASGI handlers stored on the node.""" + child_keys: set[str | type[PathParameterSentinel]] + """ + A set containing the child keys, same as the children dictionary - but as a set, which offers faster lookup. + """ + children: dict[str | type[PathParameterSentinel], RouteTrieNode] + """A dictionary mapping path components or using the PathParameterSentinel class to child nodes.""" + is_path_param_node: bool + """Designates the node as having a path parameter.""" + is_path_type: bool + """Designates the node as having a 'path' type path parameter.""" + is_asgi: bool + """Designate the node as having an `asgi` type handler.""" + is_mount: bool + """Designate the node as being a mount route.""" + is_static: bool + """Designate the node as being a static mount route.""" + path_parameters: dict[Method | Literal["websocket"] | Literal["asgi"], tuple[PathParameterDefinition, ...]] + """A list of tuples containing path parameter definitions. + + This is used for parsing extracted path parameter values. + """ + + +def create_node() -> RouteTrieNode: + """Create a RouteMapNode instance. + + Returns: + A route map node instance. + """ + + return RouteTrieNode( + asgi_handlers={}, + child_keys=set(), + children={}, + is_path_param_node=False, + is_asgi=False, + is_mount=False, + is_static=False, + is_path_type=False, + path_parameters={}, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/validate.py b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/validate.py new file mode 100644 index 0000000..5c29fac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/routing_trie/validate.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from itertools import chain +from typing import TYPE_CHECKING + +from litestar.exceptions import ImproperlyConfiguredException + +__all__ = ("validate_node",) + + +if TYPE_CHECKING: + from litestar._asgi.routing_trie.types import RouteTrieNode + + +def validate_node(node: RouteTrieNode) -> None: + """Recursively traverses the trie from the given node upwards. + + Args: + node: A trie node. + + Raises: + ImproperlyConfiguredException + + Returns: + None + """ + if node.is_asgi and bool(set(node.asgi_handlers).difference({"asgi"})): + raise ImproperlyConfiguredException("ASGI handlers must have a unique path not shared by other route handlers.") + + if ( + node.is_mount + and node.children + and any( + chain.from_iterable( + list(child.path_parameters.values()) + if isinstance(child.path_parameters, dict) + else child.path_parameters + for child in node.children.values() + ) + ) + ): + raise ImproperlyConfiguredException("Path parameters are not allowed under a static or mount route.") + + for child in node.children.values(): + if child is node: + continue + validate_node(node=child) diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/utils.py b/venv/lib/python3.11/site-packages/litestar/_asgi/utils.py new file mode 100644 index 0000000..c4111e0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_asgi/utils.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +__all__ = ("get_route_handlers", "wrap_in_exception_handler") + + +if TYPE_CHECKING: + from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from litestar.routes.base import BaseRoute + from litestar.types import ASGIApp, ExceptionHandlersMap, RouteHandlerType + + +def wrap_in_exception_handler(app: ASGIApp, exception_handlers: ExceptionHandlersMap) -> ASGIApp: + """Wrap the given ASGIApp in an instance of ExceptionHandlerMiddleware. + + Args: + app: The ASGI app that is being wrapped. + exception_handlers: A mapping of exceptions to handler functions. + + Returns: + A wrapped ASGIApp. + """ + from litestar.middleware.exceptions import ExceptionHandlerMiddleware + + return ExceptionHandlerMiddleware(app=app, exception_handlers=exception_handlers, debug=None) + + +def get_route_handlers(route: BaseRoute) -> list[RouteHandlerType]: + """Retrieve handler(s) as a list for given route. + + Args: + route: The route from which the route handlers are extracted. + + Returns: + The route handlers defined on the route. + """ + route_handlers: list[RouteHandlerType] = [] + if hasattr(route, "route_handlers"): + route_handlers.extend(cast("HTTPRoute", route).route_handlers) + else: + route_handlers.append(cast("WebSocketRoute | ASGIRoute", route).route_handler) + + return route_handlers diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py new file mode 100644 index 0000000..af8ad36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py @@ -0,0 +1,3 @@ +from .kwargs_model import KwargsModel + +__all__ = ("KwargsModel",) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b61cc1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a084eb5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b0f49a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..35a6e40 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..17029de --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..10753dc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py new file mode 100644 index 0000000..8839d36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from inspect import Traceback, isasyncgen +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator + +from anyio import create_task_group + +from litestar.utils import ensure_async_callable +from litestar.utils.compat import async_next + +__all__ = ("DependencyCleanupGroup",) + + +if TYPE_CHECKING: + from litestar.types import AnyGenerator + + +class DependencyCleanupGroup: + """Wrapper for generator based dependencies. + + Simplify cleanup by wrapping :func:`next` / :func:`anext` calls and providing facilities to + :meth:`throw <generator.throw>` / :meth:`athrow <agen.athrow>` into all generators consecutively. An instance of + this class can be used as a contextmanager, which will automatically throw any exceptions into its generators. All + exceptions caught in this manner will be re-raised after they have been thrown in the generators. + """ + + __slots__ = ("_generators", "_closed") + + def __init__(self, generators: list[AnyGenerator] | None = None) -> None: + """Initialize ``DependencyCleanupGroup``. + + Args: + generators: An optional list of generators to be called at cleanup + """ + self._generators = generators or [] + self._closed = False + + def add(self, generator: Generator[Any, None, None] | AsyncGenerator[Any, None]) -> None: + """Add a new generator to the group. + + Args: + generator: The generator to add + + Returns: + None + """ + if self._closed: + raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup") + self._generators.append(generator) + + @staticmethod + def _wrap_next(generator: AnyGenerator) -> Callable[[], Awaitable[None]]: + if isasyncgen(generator): + + async def wrapped_async() -> None: + await async_next(generator, None) + + return wrapped_async + + def wrapped() -> None: + next(generator, None) # type: ignore[arg-type] + + return ensure_async_callable(wrapped) + + async def cleanup(self) -> None: + """Execute cleanup by calling :func:`next` / :func:`anext` on all generators. + + If there are multiple generators to be called, they will be executed in a :class:`anyio.TaskGroup`. + + Returns: + None + """ + if self._closed: + raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup") + + self._closed = True + + if not self._generators: + return + + if len(self._generators) == 1: + await self._wrap_next(self._generators[0])() + return + + async with create_task_group() as task_group: + for generator in self._generators: + task_group.start_soon(self._wrap_next(generator)) + + async def __aenter__(self) -> None: + """Support the async contextmanager protocol to allow for easier catching and throwing of exceptions into the + generators. + """ + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Traceback | None, + ) -> None: + """If an exception was raised within the contextmanager block, throw it into all generators.""" + if exc_val: + await self.throw(exc_val) + + async def throw(self, exc: BaseException) -> None: + """Throw an exception in all generators sequentially. + + Args: + exc: Exception to throw + """ + for gen in self._generators: + try: + if isasyncgen(gen): + await gen.athrow(exc) + else: + gen.throw(exc) # type: ignore[union-attr] + except (StopIteration, StopAsyncIteration): + continue diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py new file mode 100644 index 0000000..88ffb07 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.utils.compat import async_next + +__all__ = ("Dependency", "create_dependency_batches", "map_dependencies_recursively", "resolve_dependency") + + +if TYPE_CHECKING: + from litestar._kwargs.cleanup import DependencyCleanupGroup + from litestar.connection import ASGIConnection + from litestar.di import Provide + + +class Dependency: + """Dependency graph of a given combination of ``Route`` + ``RouteHandler``""" + + __slots__ = ("key", "provide", "dependencies") + + def __init__(self, key: str, provide: Provide, dependencies: list[Dependency]) -> None: + """Initialize a dependency. + + Args: + key: The dependency key + provide: Provider + dependencies: List of child nodes + """ + self.key = key + self.provide = provide + self.dependencies = dependencies + + def __eq__(self, other: Any) -> bool: + # check if memory address is identical, otherwise compare attributes + return other is self or (isinstance(other, self.__class__) and other.key == self.key) + + def __hash__(self) -> int: + return hash(self.key) + + +async def resolve_dependency( + dependency: Dependency, + connection: ASGIConnection, + kwargs: dict[str, Any], + cleanup_group: DependencyCleanupGroup, +) -> None: + """Resolve a given instance of :class:`Dependency <litestar._kwargs.Dependency>`. + + All required sub dependencies must already + be resolved into the kwargs. The result of the dependency will be stored in the kwargs. + + Args: + dependency: An instance of :class:`Dependency <litestar._kwargs.Dependency>` + connection: An instance of :class:`Request <litestar.connection.Request>` or + :class:`WebSocket <litestar.connection.WebSocket>`. + kwargs: Any kwargs to pass to the dependency, the result will be stored here as well. + cleanup_group: DependencyCleanupGroup to which generators returned by ``dependency`` will be added + """ + signature_model = dependency.provide.signature_model + dependency_kwargs = ( + signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs) + if signature_model._fields + else {} + ) + value = await dependency.provide(**dependency_kwargs) + + if dependency.provide.has_sync_generator_dependency: + cleanup_group.add(value) + value = next(value) + elif dependency.provide.has_async_generator_dependency: + cleanup_group.add(value) + value = await async_next(value) + + kwargs[dependency.key] = value + + +def create_dependency_batches(expected_dependencies: set[Dependency]) -> list[set[Dependency]]: + """Calculate batches for all dependencies, recursively. + + Args: + expected_dependencies: A set of all direct :class:`Dependencies <litestar._kwargs.Dependency>`. + + Returns: + A list of batches. + """ + dependencies_to: dict[Dependency, set[Dependency]] = {} + for dependency in expected_dependencies: + if dependency not in dependencies_to: + map_dependencies_recursively(dependency, dependencies_to) + + batches = [] + while dependencies_to: + current_batch = { + dependency + for dependency, remaining_sub_dependencies in dependencies_to.items() + if not remaining_sub_dependencies + } + + for dependency in current_batch: + del dependencies_to[dependency] + for others_dependencies in dependencies_to.values(): + others_dependencies.discard(dependency) + + batches.append(current_batch) + + return batches + + +def map_dependencies_recursively(dependency: Dependency, dependencies_to: dict[Dependency, set[Dependency]]) -> None: + """Recursively map dependencies to their sub dependencies. + + Args: + dependency: The current dependency to map. + dependencies_to: A map of dependency to its sub dependencies. + """ + dependencies_to[dependency] = set(dependency.dependencies) + for sub in dependency.dependencies: + if sub not in dependencies_to: + map_dependencies_recursively(sub, dependencies_to) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py new file mode 100644 index 0000000..e3b347e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache, partial +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast + +from litestar._multipart import parse_multipart_form +from litestar._parsers import ( + parse_query_string, + parse_url_encoded_form_data, +) +from litestar.datastructures import Headers +from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.url import URL +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ValidationException +from litestar.params import BodyKwarg +from litestar.types import Empty +from litestar.utils.predicates import is_non_string_sequence +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar._kwargs import KwargsModel + from litestar._kwargs.parameter_definition import ParameterDefinition + from litestar.connection import ASGIConnection, Request + from litestar.dto import AbstractDTO + from litestar.typing import FieldDefinition + + +__all__ = ( + "body_extractor", + "cookies_extractor", + "create_connection_value_extractor", + "create_data_extractor", + "create_multipart_extractor", + "create_query_default_dict", + "create_url_encoded_data_extractor", + "headers_extractor", + "json_extractor", + "msgpack_extractor", + "parse_connection_headers", + "parse_connection_query_params", + "query_extractor", + "request_extractor", + "scope_extractor", + "socket_extractor", + "state_extractor", +) + + +class ParamMappings(NamedTuple): + alias_and_key_tuples: list[tuple[str, str]] + alias_defaults: dict[str, Any] + alias_to_param: dict[str, ParameterDefinition] + + +def _create_param_mappings(expected_params: set[ParameterDefinition]) -> ParamMappings: + alias_and_key_tuples = [] + alias_defaults = {} + alias_to_params: dict[str, ParameterDefinition] = {} + for param in expected_params: + alias = param.field_alias + if param.param_type == ParamType.HEADER: + alias = alias.lower() + + alias_and_key_tuples.append((alias, param.field_name)) + + if not (param.is_required or param.default is Ellipsis): + alias_defaults[alias] = param.default + + alias_to_params[alias] = param + + return ParamMappings( + alias_and_key_tuples=alias_and_key_tuples, + alias_defaults=alias_defaults, + alias_to_param=alias_to_params, + ) + + +def create_connection_value_extractor( + kwargs_model: KwargsModel, + connection_key: str, + expected_params: set[ParameterDefinition], + parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None, +) -> Callable[[dict[str, Any], ASGIConnection], None]: + """Create a kwargs extractor function. + + Args: + kwargs_model: The KwargsModel instance. + connection_key: The attribute key to use. + expected_params: The set of expected params. + parser: An optional parser function. + + Returns: + An extractor function. + """ + + alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params) + + def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {}) + + try: + connection_mapping: dict[str, Any] = { + key: data[alias] if alias in data else alias_defaults[alias] for alias, key in alias_and_key_tuples + } + values.update(connection_mapping) + except KeyError as e: + param = alias_to_params[e.args[0]] + path = URL.from_components( + path=connection.url.path, + query=connection.url.query, + ) + raise ValidationException( + f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}" + ) from e + + return extractor + + +@lru_cache(1024) +def create_query_default_dict( + parsed_query: tuple[tuple[str, str], ...], sequence_query_parameter_names: tuple[str, ...] +) -> defaultdict[str, list[str] | str]: + """Transform a list of tuples into a default dict. Ensures non-list values are not wrapped in a list. + + Args: + parsed_query: The parsed query list of tuples. + sequence_query_parameter_names: A set of query parameters that should be wrapped in list. + + Returns: + A default dict + """ + output: defaultdict[str, list[str] | str] = defaultdict(list) + + for k, v in parsed_query: + if k in sequence_query_parameter_names: + output[k].append(v) # type: ignore[union-attr] + else: + output[k] = v + + return output + + +def parse_connection_query_params(connection: ASGIConnection, kwargs_model: KwargsModel) -> dict[str, Any]: + """Parse query params and cache the result in scope. + + Args: + connection: The ASGI connection instance. + kwargs_model: The KwargsModel instance. + + Returns: + A dictionary of parsed values. + """ + parsed_query = ( + connection._parsed_query + if connection._parsed_query is not Empty + else parse_query_string(connection.scope.get("query_string", b"")) + ) + ScopeState.from_scope(connection.scope).parsed_query = parsed_query + return create_query_default_dict( + parsed_query=parsed_query, + sequence_query_parameter_names=kwargs_model.sequence_query_parameter_names, + ) + + +def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers: + """Parse header parameters and cache the result in scope. + + Args: + connection: The ASGI connection instance. + _: The KwargsModel instance. + + Returns: + A Headers instance + """ + return Headers.from_scope(connection.scope) + + +def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the app state from the connection and insert it to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["state"] = connection.app.state._state + + +def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the headers from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + # TODO: This should be removed in 3.0 and instead Headers should be injected + # directly. We are only keeping this one around to not break things + values["headers"] = dict(connection.headers.items()) + + +def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the cookies from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["cookies"] = connection.cookies + + +def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the query params from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["query"] = connection.query_params + + +def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the scope from the connection and insert it into the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["scope"] = connection.scope + + +def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Set the connection instance as the 'request' value in the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["request"] = connection + + +def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Set the connection instance as the 'socket' value in the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["socket"] = connection + + +def body_extractor( + values: dict[str, Any], + connection: Request[Any, Any, Any], +) -> None: + """Extract the body from the request instance. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + The Body value. + """ + values["body"] = connection.body() + + +async def json_extractor(connection: Request[Any, Any, Any]) -> Any: + """Extract the data from request and insert it into the kwargs injected to the handler. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + + Returns: + The JSON value. + """ + if not await connection.body(): + return Empty + return await connection.json() + + +async def msgpack_extractor(connection: Request[Any, Any, Any]) -> Any: + """Extract the data from request and insert it into the kwargs injected to the handler. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + + Returns: + The MessagePack value. + """ + if not await connection.body(): + return Empty + return await connection.msgpack() + + +async def _extract_multipart( + connection: Request[Any, Any, Any], + body_kwarg_multipart_form_part_limit: int | None, + field_definition: FieldDefinition, + is_data_optional: bool, + data_dto: type[AbstractDTO] | None, +) -> Any: + multipart_form_part_limit = ( + body_kwarg_multipart_form_part_limit + if body_kwarg_multipart_form_part_limit is not None + else connection.app.multipart_form_part_limit + ) + connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] + connection.scope["_form"] # type: ignore[typeddict-item] + if "_form" in connection.scope + else parse_multipart_form( + body=await connection.body(), + boundary=connection.content_type[-1].get("boundary", "").encode(), + multipart_form_part_limit=multipart_form_part_limit, + type_decoders=connection.route_handler.resolve_type_decoders(), + ) + ) + + if field_definition.is_non_string_sequence: + values = list(form_values.values()) + if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list): + return values[0] + + return values + + if field_definition.is_simple_type and field_definition.annotation is UploadFile and form_values: + return next(v for v in form_values.values() if isinstance(v, UploadFile)) + + if not form_values and is_data_optional: + return None + + if data_dto: + return data_dto(connection).decode_builtins(form_values) + + for name, tp in field_definition.get_type_hints().items(): + value = form_values.get(name) + if value is not None and is_non_string_sequence(tp) and not isinstance(value, list): + form_values[name] = [value] + + return form_values + + +def create_multipart_extractor( + field_definition: FieldDefinition, is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create a multipart form-data extractor. + + Args: + field_definition: A FieldDefinition instance. + is_data_optional: Boolean dictating whether the field is optional. + data_dto: A data DTO type, if configured for handler. + + Returns: + An extractor function. + """ + body_kwarg_multipart_form_part_limit: int | None = None + if field_definition.kwarg_definition and isinstance(field_definition.kwarg_definition, BodyKwarg): + body_kwarg_multipart_form_part_limit = field_definition.kwarg_definition.multipart_form_part_limit + + extract_multipart = partial( + _extract_multipart, + body_kwarg_multipart_form_part_limit=body_kwarg_multipart_form_part_limit, + is_data_optional=is_data_optional, + data_dto=data_dto, + field_definition=field_definition, + ) + + return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart) + + +def create_url_encoded_data_extractor( + is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create extractor for url encoded form-data. + + Args: + is_data_optional: Boolean dictating whether the field is optional. + data_dto: A data DTO type, if configured for handler. + + Returns: + An extractor function. + """ + + async def extract_url_encoded_extractor( + connection: Request[Any, Any, Any], + ) -> Any: + connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] + connection.scope["_form"] # type: ignore[typeddict-item] + if "_form" in connection.scope + else parse_url_encoded_form_data(await connection.body()) + ) + + if not form_values and is_data_optional: + return None + + return data_dto(connection).decode_builtins(form_values) if data_dto else form_values + + return cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_url_encoded_extractor + ) + + +def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]: + """Create an extractor for a request's body. + + Args: + kwargs_model: The KwargsModel instance. + + Returns: + An extractor for the request's body. + """ + + if kwargs_model.expected_form_data: + media_type, field_definition = kwargs_model.expected_form_data + + if media_type == RequestEncodingType.MULTI_PART: + data_extractor = create_multipart_extractor( + field_definition=field_definition, + is_data_optional=kwargs_model.is_data_optional, + data_dto=kwargs_model.expected_data_dto, + ) + else: + data_extractor = create_url_encoded_data_extractor( + is_data_optional=kwargs_model.is_data_optional, + data_dto=kwargs_model.expected_data_dto, + ) + elif kwargs_model.expected_msgpack_data: + data_extractor = cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", msgpack_extractor + ) + elif kwargs_model.expected_data_dto: + data_extractor = create_dto_extractor(data_dto=kwargs_model.expected_data_dto) + else: + data_extractor = cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor + ) + + def extractor( + values: dict[str, Any], + connection: ASGIConnection[Any, Any, Any, Any], + ) -> None: + values["data"] = data_extractor(connection) + + return extractor + + +def create_dto_extractor( + data_dto: type[AbstractDTO], +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create a DTO data extractor. + + + Returns: + An extractor function. + """ + + async def dto_extractor(connection: Request[Any, Any, Any]) -> Any: + if not (body := await connection.body()): + return Empty + return data_dto(connection).decode_bytes(body) + + return dto_extractor # type:ignore[return-value] diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py new file mode 100644 index 0000000..01ed2e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from anyio import create_task_group + +from litestar._kwargs.cleanup import DependencyCleanupGroup +from litestar._kwargs.dependencies import ( + Dependency, + create_dependency_batches, + resolve_dependency, +) +from litestar._kwargs.extractors import ( + body_extractor, + cookies_extractor, + create_connection_value_extractor, + create_data_extractor, + headers_extractor, + parse_connection_headers, + parse_connection_query_params, + query_extractor, + request_extractor, + scope_extractor, + socket_extractor, + state_extractor, +) +from litestar._kwargs.parameter_definition import ( + ParameterDefinition, + create_parameter_definition, + merge_parameter_sets, +) +from litestar.constants import RESERVED_KWARGS +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.params import BodyKwarg, ParameterKwarg +from litestar.typing import FieldDefinition +from litestar.utils.helpers import get_exception_group + +__all__ = ("KwargsModel",) + + +if TYPE_CHECKING: + from litestar._signature import SignatureModel + from litestar.connection import ASGIConnection + from litestar.di import Provide + from litestar.dto import AbstractDTO + from litestar.utils.signature import ParsedSignature + +_ExceptionGroup = get_exception_group() + + +class KwargsModel: + """Model required kwargs for a given RouteHandler and its dependencies. + + This is done once and is memoized during application bootstrap, ensuring minimal runtime overhead. + """ + + __slots__ = ( + "dependency_batches", + "expected_cookie_params", + "expected_data_dto", + "expected_form_data", + "expected_header_params", + "expected_msgpack_data", + "expected_path_params", + "expected_query_params", + "expected_reserved_kwargs", + "extractors", + "has_kwargs", + "is_data_optional", + "sequence_query_parameter_names", + ) + + def __init__( + self, + *, + expected_cookie_params: set[ParameterDefinition], + expected_data_dto: type[AbstractDTO] | None, + expected_dependencies: set[Dependency], + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, + expected_header_params: set[ParameterDefinition], + expected_msgpack_data: FieldDefinition | None, + expected_path_params: set[ParameterDefinition], + expected_query_params: set[ParameterDefinition], + expected_reserved_kwargs: set[str], + is_data_optional: bool, + sequence_query_parameter_names: set[str], + ) -> None: + """Initialize ``KwargsModel``. + + Args: + expected_cookie_params: Any expected cookie parameter kwargs + expected_dependencies: Any expected dependency kwargs + expected_form_data: Any expected form data kwargs + expected_header_params: Any expected header parameter kwargs + expected_msgpack_data: Any expected MessagePack data kwargs + expected_path_params: Any expected path parameter kwargs + expected_query_params: Any expected query parameter kwargs + expected_reserved_kwargs: Any expected reserved kwargs, e.g. 'state' + expected_data_dto: A data DTO, if defined + is_data_optional: Treat data as optional + sequence_query_parameter_names: Any query parameters that are sequences + """ + self.expected_cookie_params = expected_cookie_params + self.expected_form_data = expected_form_data + self.expected_header_params = expected_header_params + self.expected_msgpack_data = expected_msgpack_data + self.expected_path_params = expected_path_params + self.expected_query_params = expected_query_params + self.expected_reserved_kwargs = expected_reserved_kwargs + self.expected_data_dto = expected_data_dto + self.sequence_query_parameter_names = tuple(sequence_query_parameter_names) + + self.has_kwargs = ( + expected_cookie_params + or expected_dependencies + or expected_form_data + or expected_msgpack_data + or expected_header_params + or expected_path_params + or expected_query_params + or expected_reserved_kwargs + or expected_data_dto + ) + + self.is_data_optional = is_data_optional + self.extractors = self._create_extractors() + self.dependency_batches = create_dependency_batches(expected_dependencies) + + def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]: + reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = { + "data": create_data_extractor(self), + "state": state_extractor, + "scope": scope_extractor, + "request": request_extractor, + "socket": socket_extractor, + "headers": headers_extractor, + "cookies": cookies_extractor, + "query": query_extractor, + "body": body_extractor, # type: ignore[dict-item] + } + + extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [ + reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs + ] + + if self.expected_header_params: + extractors.append( + create_connection_value_extractor( + connection_key="headers", + expected_params=self.expected_header_params, + kwargs_model=self, + parser=parse_connection_headers, + ), + ) + + if self.expected_path_params: + extractors.append( + create_connection_value_extractor( + connection_key="path_params", + expected_params=self.expected_path_params, + kwargs_model=self, + ), + ) + + if self.expected_cookie_params: + extractors.append( + create_connection_value_extractor( + connection_key="cookies", + expected_params=self.expected_cookie_params, + kwargs_model=self, + ), + ) + + if self.expected_query_params: + extractors.append( + create_connection_value_extractor( + connection_key="query_params", + expected_params=self.expected_query_params, + kwargs_model=self, + parser=parse_connection_query_params, + ), + ) + return extractors + + @classmethod + def _get_param_definitions( + cls, + path_parameters: set[str], + layered_parameters: dict[str, FieldDefinition], + dependencies: dict[str, Provide], + field_definitions: dict[str, FieldDefinition], + ) -> tuple[set[ParameterDefinition], set[Dependency]]: + """Get parameter_definitions for the construction of KwargsModel instance. + + Args: + path_parameters: Any expected path parameters. + layered_parameters: A string keyed dictionary of layered parameters. + dependencies: A string keyed dictionary mapping dependency providers. + field_definitions: The SignatureModel fields. + + Returns: + A Tuple of sets + """ + expected_dependencies = { + cls._create_dependency_graph(key=key, dependencies=dependencies) + for key in dependencies + if key in field_definitions + } + ignored_keys = {*RESERVED_KWARGS, *(dependency.key for dependency in expected_dependencies)} + + param_definitions = { + *( + create_parameter_definition( + field_definition=field_definition, + field_name=field_name, + path_parameters=path_parameters, + ) + for field_name, field_definition in layered_parameters.items() + if field_name not in ignored_keys and field_name not in field_definitions + ), + *( + create_parameter_definition( + field_definition=field_definition, + field_name=field_name, + path_parameters=path_parameters, + ) + for field_name, field_definition in field_definitions.items() + if field_name not in ignored_keys and field_name not in layered_parameters + ), + } + + for field_name, field_definition in ( + (k, v) for k, v in field_definitions.items() if k not in ignored_keys and k in layered_parameters + ): + layered_parameter = layered_parameters[field_name] + field = field_definition if field_definition.is_parameter_field else layered_parameter + default = field_definition.default if field_definition.has_default else layered_parameter.default + + param_definitions.add( + create_parameter_definition( + field_definition=FieldDefinition.from_kwarg( + name=field.name, + default=default, + inner_types=field.inner_types, + annotation=field.annotation, + kwarg_definition=field.kwarg_definition, + extra=field.extra, + ), + field_name=field_name, + path_parameters=path_parameters, + ) + ) + + return param_definitions, expected_dependencies + + @classmethod + def create_for_signature_model( + cls, + signature_model: type[SignatureModel], + parsed_signature: ParsedSignature, + dependencies: dict[str, Provide], + path_parameters: set[str], + layered_parameters: dict[str, FieldDefinition], + ) -> KwargsModel: + """Pre-determine what parameters are required for a given combination of route + route handler. It is executed + during the application bootstrap process. + + Args: + signature_model: A :class:`SignatureModel <litestar._signature.SignatureModel>` subclass. + parsed_signature: A :class:`ParsedSignature <litestar._signature.ParsedSignature>` instance. + dependencies: A string keyed dictionary mapping dependency providers. + path_parameters: Any expected path parameters. + layered_parameters: A string keyed dictionary of layered parameters. + + Returns: + An instance of KwargsModel + """ + + field_definitions = signature_model._fields + + cls._validate_raw_kwargs( + path_parameters=path_parameters, + dependencies=dependencies, + field_definitions=field_definitions, + layered_parameters=layered_parameters, + ) + + param_definitions, expected_dependencies = cls._get_param_definitions( + path_parameters=path_parameters, + layered_parameters=layered_parameters, + dependencies=dependencies, + field_definitions=field_definitions, + ) + + expected_reserved_kwargs = {field_name for field_name in field_definitions if field_name in RESERVED_KWARGS} + expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH} + expected_header_parameters = {p for p in param_definitions if p.param_type == ParamType.HEADER} + expected_cookie_parameters = {p for p in param_definitions if p.param_type == ParamType.COOKIE} + expected_query_parameters = {p for p in param_definitions if p.param_type == ParamType.QUERY} + sequence_query_parameter_names = {p.field_alias for p in expected_query_parameters if p.is_sequence} + + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None = None + expected_msgpack_data: FieldDefinition | None = None + expected_data_dto: type[AbstractDTO] | None = None + data_field_definition = field_definitions.get("data") + + media_type: RequestEncodingType | str | None = None + if data_field_definition: + if isinstance(data_field_definition.kwarg_definition, BodyKwarg): + media_type = data_field_definition.kwarg_definition.media_type + + if media_type in (RequestEncodingType.MULTI_PART, RequestEncodingType.URL_ENCODED): + expected_form_data = (media_type, data_field_definition) + expected_data_dto = signature_model._data_dto + elif signature_model._data_dto: + expected_data_dto = signature_model._data_dto + elif media_type == RequestEncodingType.MESSAGEPACK: + expected_msgpack_data = data_field_definition + + for dependency in expected_dependencies: + dependency_kwargs_model = cls.create_for_signature_model( + signature_model=dependency.provide.signature_model, + parsed_signature=parsed_signature, + dependencies=dependencies, + path_parameters=path_parameters, + layered_parameters=layered_parameters, + ) + expected_path_parameters = merge_parameter_sets( + expected_path_parameters, dependency_kwargs_model.expected_path_params + ) + expected_query_parameters = merge_parameter_sets( + expected_query_parameters, dependency_kwargs_model.expected_query_params + ) + expected_cookie_parameters = merge_parameter_sets( + expected_cookie_parameters, dependency_kwargs_model.expected_cookie_params + ) + expected_header_parameters = merge_parameter_sets( + expected_header_parameters, dependency_kwargs_model.expected_header_params + ) + + if "data" in expected_reserved_kwargs and "data" in dependency_kwargs_model.expected_reserved_kwargs: + cls._validate_dependency_data( + expected_form_data=expected_form_data, + dependency_kwargs_model=dependency_kwargs_model, + ) + + expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs) + sequence_query_parameter_names.update(dependency_kwargs_model.sequence_query_parameter_names) + + return KwargsModel( + expected_cookie_params=expected_cookie_parameters, + expected_dependencies=expected_dependencies, + expected_data_dto=expected_data_dto, + expected_form_data=expected_form_data, + expected_header_params=expected_header_parameters, + expected_msgpack_data=expected_msgpack_data, + expected_path_params=expected_path_parameters, + expected_query_params=expected_query_parameters, + expected_reserved_kwargs=expected_reserved_kwargs, + is_data_optional=field_definitions["data"].is_optional if "data" in expected_reserved_kwargs else False, + sequence_query_parameter_names=sequence_query_parameter_names, + ) + + def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: + """Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is + sync. + + Args: + connection: An instance of :class:`Request <litestar.connection.Request>` or + :class:`WebSocket <litestar.connection.WebSocket>`. + + Returns: + A string keyed dictionary of kwargs expected by the handler function and its dependencies. + """ + output: dict[str, Any] = {} + + for extractor in self.extractors: + extractor(output, connection) + + return output + + async def resolve_dependencies(self, connection: ASGIConnection, kwargs: dict[str, Any]) -> DependencyCleanupGroup: + """Resolve all dependencies into the kwargs, recursively. + + Args: + connection: An instance of :class:`Request <litestar.connection.Request>` or + :class:`WebSocket <litestar.connection.WebSocket>`. + kwargs: Kwargs to pass to dependencies. + """ + cleanup_group = DependencyCleanupGroup() + for batch in self.dependency_batches: + if len(batch) == 1: + await resolve_dependency(next(iter(batch)), connection, kwargs, cleanup_group) + else: + try: + async with create_task_group() as task_group: + for dependency in batch: + task_group.start_soon(resolve_dependency, dependency, connection, kwargs, cleanup_group) + except _ExceptionGroup as excgroup: + raise excgroup.exceptions[0] from excgroup # type: ignore[attr-defined] + + return cleanup_group + + @classmethod + def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) -> Dependency: + """Create a graph like structure of dependencies, with each dependency including its own dependencies as a + list. + """ + provide = dependencies[key] + sub_dependency_keys = [k for k in provide.signature_model._fields if k in dependencies] + return Dependency( + key=key, + provide=provide, + dependencies=[cls._create_dependency_graph(key=k, dependencies=dependencies) for k in sub_dependency_keys], + ) + + @classmethod + def _validate_dependency_data( + cls, + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, + dependency_kwargs_model: KwargsModel, + ) -> None: + """Validate that the 'data' kwarg is compatible across dependencies.""" + if bool(expected_form_data) != bool(dependency_kwargs_model.expected_form_data): + raise ImproperlyConfiguredException( + "Dependencies have incompatible 'data' kwarg types: one expects JSON and the other expects form-data" + ) + if expected_form_data and dependency_kwargs_model.expected_form_data: + local_media_type = expected_form_data[0] + dependency_media_type = dependency_kwargs_model.expected_form_data[0] + if local_media_type != dependency_media_type: + raise ImproperlyConfiguredException( + "Dependencies have incompatible form-data encoding: one expects url-encoded and the other expects multi-part" + ) + + @classmethod + def _validate_raw_kwargs( + cls, + path_parameters: set[str], + dependencies: dict[str, Provide], + field_definitions: dict[str, FieldDefinition], + layered_parameters: dict[str, FieldDefinition], + ) -> None: + """Validate that there are no ambiguous kwargs, that is, kwargs declared using the same key in different + places. + """ + dependency_keys = set(dependencies.keys()) + + parameter_names = { + *( + k + for k, f in field_definitions.items() + if isinstance(f.kwarg_definition, ParameterKwarg) + and (f.kwarg_definition.header or f.kwarg_definition.query or f.kwarg_definition.cookie) + ), + *list(layered_parameters.keys()), + } + + intersection = ( + path_parameters.intersection(dependency_keys) + or path_parameters.intersection(parameter_names) + or dependency_keys.intersection(parameter_names) + ) + if intersection: + raise ImproperlyConfiguredException( + f"Kwarg resolution ambiguity detected for the following keys: {', '.join(intersection)}. " + f"Make sure to use distinct keys for your dependencies, path parameters, and aliased parameters." + ) + + if used_reserved_kwargs := { + *parameter_names, + *path_parameters, + *dependency_keys, + }.intersection(RESERVED_KWARGS): + raise ImproperlyConfiguredException( + f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. " + f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}" + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py new file mode 100644 index 0000000..02b09fc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple + +from litestar.enums import ParamType +from litestar.params import ParameterKwarg + +if TYPE_CHECKING: + from litestar.typing import FieldDefinition + +__all__ = ("ParameterDefinition", "create_parameter_definition", "merge_parameter_sets") + + +class ParameterDefinition(NamedTuple): + """Tuple defining a kwarg representing a request parameter.""" + + default: Any + field_alias: str + field_name: str + is_required: bool + is_sequence: bool + param_type: ParamType + + +def create_parameter_definition( + field_definition: FieldDefinition, + field_name: str, + path_parameters: set[str], +) -> ParameterDefinition: + """Create a ParameterDefinition for the given FieldDefinition. + + Args: + field_definition: FieldDefinition instance. + field_name: The field's name. + path_parameters: A set of path parameter names. + + Returns: + A ParameterDefinition tuple. + """ + default = field_definition.default if field_definition.has_default else None + kwarg_definition = ( + field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None + ) + + field_alias = kwarg_definition.query if kwarg_definition and kwarg_definition.query else field_name + param_type = ParamType.QUERY + + if field_name in path_parameters: + field_alias = field_name + param_type = ParamType.PATH + elif kwarg_definition and kwarg_definition.header: + field_alias = kwarg_definition.header + param_type = ParamType.HEADER + elif kwarg_definition and kwarg_definition.cookie: + field_alias = kwarg_definition.cookie + param_type = ParamType.COOKIE + + return ParameterDefinition( + param_type=param_type, + field_name=field_name, + field_alias=field_alias, + default=default, + is_required=field_definition.is_required + and default is None + and not field_definition.is_optional + and not field_definition.is_any, + is_sequence=field_definition.is_non_string_sequence, + ) + + +def merge_parameter_sets(first: set[ParameterDefinition], second: set[ParameterDefinition]) -> set[ParameterDefinition]: + """Given two sets of parameter definitions, coming from different dependencies for example, merge them into a single + set. + """ + result: set[ParameterDefinition] = first.intersection(second) + difference = first.symmetric_difference(second) + for param in difference: + # add the param if it's either required or no-other param in difference is the same but required + if param.is_required or not any(p.field_alias == param.field_alias and p.is_required for p in difference): + result.add(param) + return result diff --git a/venv/lib/python3.11/site-packages/litestar/_layers/__init__.py b/venv/lib/python3.11/site-packages/litestar/_layers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_layers/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/_layers/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_layers/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7b7425e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_layers/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_layers/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_layers/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d772e10 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_layers/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_layers/utils.py b/venv/lib/python3.11/site-packages/litestar/_layers/utils.py new file mode 100644 index 0000000..61afd61 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_layers/utils.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, Sequence + +from litestar.datastructures.cookie import Cookie +from litestar.datastructures.response_header import ResponseHeader + +__all__ = ("narrow_response_cookies", "narrow_response_headers") + + +if TYPE_CHECKING: + from litestar.types.composite_types import ResponseCookies, ResponseHeaders + + +def narrow_response_headers(headers: ResponseHeaders | None) -> Sequence[ResponseHeader] | None: + """Given :class:`.types.ResponseHeaders` as a :class:`typing.Mapping`, create a list of + :class:`.datastructures.response_header.ResponseHeader` from it, otherwise return ``headers`` unchanged + """ + return ( + tuple(ResponseHeader(name=name, value=value) for name, value in headers.items()) + if isinstance(headers, Mapping) + else headers + ) + + +def narrow_response_cookies(cookies: ResponseCookies | None) -> Sequence[Cookie] | None: + """Given :class:`.types.ResponseCookies` as a :class:`typing.Mapping`, create a list of + :class:`.datastructures.cookie.Cookie` from it, otherwise return ``cookies`` unchanged + """ + return ( + tuple(Cookie(key=key, value=value) for key, value in cookies.items()) + if isinstance(cookies, Mapping) + else cookies + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_multipart.py b/venv/lib/python3.11/site-packages/litestar/_multipart.py new file mode 100644 index 0000000..55b3620 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_multipart.py @@ -0,0 +1,163 @@ +"""The contents of this file were adapted from sanic. + +MIT License + +Copyright (c) 2016-present Sanic Community + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from __future__ import annotations + +import re +from collections import defaultdict +from email.utils import decode_rfc2231 +from typing import TYPE_CHECKING, Any +from urllib.parse import unquote + +from litestar.datastructures.upload_file import UploadFile +from litestar.exceptions import ValidationException + +__all__ = ("parse_body", "parse_content_header", "parse_multipart_form") + + +if TYPE_CHECKING: + from litestar.types import TypeDecodersSequence + +_token = r"([\w!#$%&'*+\-.^_`|~]+)" # noqa: S105 +_quoted = r'"([^"]*)"' +_param = re.compile(rf";\s*{_token}=(?:{_token}|{_quoted})", re.ASCII) +_firefox_quote_escape = re.compile(r'\\"(?!; |\s*$)') + + +def parse_content_header(value: str) -> tuple[str, dict[str, str]]: + """Parse content-type and content-disposition header values. + + Args: + value: A header string value to parse. + + Returns: + A tuple containing the normalized header string and a dictionary of parameters. + """ + value = _firefox_quote_escape.sub("%22", value) + pos = value.find(";") + if pos == -1: + options: dict[str, str] = {} + else: + options = { + m.group(1).lower(): m.group(2) or m.group(3).replace("%22", '"') for m in _param.finditer(value[pos:]) + } + value = value[:pos] + return value.strip().lower(), options + + +def parse_body(body: bytes, boundary: bytes, multipart_form_part_limit: int) -> list[bytes]: + """Split the body using the boundary + and validate the number of form parts is within the allowed limit. + + Args: + body: The form body. + boundary: The boundary used to separate form components. + multipart_form_part_limit: The limit of allowed form components + + Returns: + A list of form components. + """ + if not (body and boundary): + return [] + + form_parts = body.split(boundary, multipart_form_part_limit + 3)[1:-1] + + if len(form_parts) > multipart_form_part_limit: + raise ValidationException( + f"number of multipart components exceeds the allowed limit of {multipart_form_part_limit}, " + f"this potentially indicates a DoS attack" + ) + + return form_parts + + +def parse_multipart_form( + body: bytes, + boundary: bytes, + multipart_form_part_limit: int = 1000, + type_decoders: TypeDecodersSequence | None = None, +) -> dict[str, Any]: + """Parse multipart form data. + + Args: + body: Body of the request. + boundary: Boundary of the multipart message. + multipart_form_part_limit: Limit of the number of parts allowed. + type_decoders: A sequence of type decoders to use. + + Returns: + A dictionary of parsed results. + """ + + fields: defaultdict[str, list[Any]] = defaultdict(list) + + for form_part in parse_body(body=body, boundary=boundary, multipart_form_part_limit=multipart_form_part_limit): + file_name = None + content_type = "text/plain" + content_charset = "utf-8" + field_name = None + line_index = 2 + line_end_index = 0 + headers: list[tuple[str, str]] = [] + + while line_end_index != -1: + line_end_index = form_part.find(b"\r\n", line_index) + form_line = form_part[line_index:line_end_index].decode("utf-8") + + if not form_line: + break + + line_index = line_end_index + 2 + colon_index = form_line.index(":") + current_idx = colon_index + 2 + form_header_field = form_line[:colon_index].lower() + form_header_value, form_parameters = parse_content_header(form_line[current_idx:]) + + if form_header_field == "content-disposition": + field_name = form_parameters.get("name") + file_name = form_parameters.get("filename") + + if file_name is None and (filename_with_asterisk := form_parameters.get("filename*")): + encoding, _, value = decode_rfc2231(filename_with_asterisk) + file_name = unquote(value, encoding=encoding or content_charset) + + elif form_header_field == "content-type": + content_type = form_header_value + content_charset = form_parameters.get("charset", "utf-8") + headers.append((form_header_field, form_header_value)) + + if field_name: + post_data = form_part[line_index:-4].lstrip(b"\r\n") + if file_name: + form_file = UploadFile( + content_type=content_type, filename=file_name, file_data=post_data, headers=dict(headers) + ) + fields[field_name].append(form_file) + elif post_data: + fields[field_name].append(post_data.decode(content_charset)) + else: + fields[field_name].append(None) + + return {k: v if len(v) > 1 else v[0] for k, v in fields.items()} diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__init__.py b/venv/lib/python3.11/site-packages/litestar/_openapi/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e090039 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/datastructures.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/datastructures.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9dd5b10 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/datastructures.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/parameters.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/parameters.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1a17187 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/parameters.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/path_item.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/path_item.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a440e7d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/path_item.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..13b3ab3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/request_body.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/request_body.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2f093bf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/request_body.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/responses.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/responses.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..10d877b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/responses.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..97c7c13 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/datastructures.py b/venv/lib/python3.11/site-packages/litestar/_openapi/datastructures.py new file mode 100644 index 0000000..d97c8db --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/datastructures.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING, Iterator, Sequence + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.openapi.spec import Reference, Schema + +if TYPE_CHECKING: + from litestar.openapi import OpenAPIConfig + from litestar.plugins import OpenAPISchemaPluginProtocol + + +class RegisteredSchema: + """Object to store a schema and any references to it.""" + + def __init__(self, key: tuple[str, ...], schema: Schema, references: list[Reference]) -> None: + """Create a new RegisteredSchema object. + + Args: + key: The key used to register the schema. + schema: The schema object. + references: A list of references to the schema. + """ + self.key = key + self.schema = schema + self.references = references + + +class SchemaRegistry: + """A registry for object schemas. + + This class is used to store schemas that we reference from other parts of the spec. + + Its main purpose is to allow us to generate the components/schemas section of the spec once we have + collected all the schemas that should be included. + + This allows us to determine a path to the schema in the components/schemas section of the spec that + is unique and as short as possible. + """ + + def __init__(self) -> None: + self._schema_key_map: dict[tuple[str, ...], RegisteredSchema] = {} + self._schema_reference_map: dict[int, RegisteredSchema] = {} + self._model_name_groups: defaultdict[str, list[RegisteredSchema]] = defaultdict(list) + + def get_schema_for_key(self, key: tuple[str, ...]) -> Schema: + """Get a registered schema by its key. + + Args: + key: The key to the schema to get. + + Returns: + A RegisteredSchema object. + """ + if key not in self._schema_key_map: + self._schema_key_map[key] = registered_schema = RegisteredSchema(key, Schema(), []) + self._model_name_groups[key[-1]].append(registered_schema) + return self._schema_key_map[key].schema + + def get_reference_for_key(self, key: tuple[str, ...]) -> Reference | None: + """Get a reference to a registered schema by its key. + + Args: + key: The key to the schema to get. + + Returns: + A Reference object. + """ + if key not in self._schema_key_map: + return None + registered_schema = self._schema_key_map[key] + reference = Reference(f"#/components/schemas/{'_'.join(key)}") + registered_schema.references.append(reference) + self._schema_reference_map[id(reference)] = registered_schema + return reference + + def from_reference(self, reference: Reference) -> RegisteredSchema: + """Get a registered schema by its reference. + + Args: + reference: The reference to the schema to get. + + Returns: + A RegisteredSchema object. + """ + return self._schema_reference_map[id(reference)] + + def __iter__(self) -> Iterator[RegisteredSchema]: + """Iterate over the registered schemas.""" + return iter(self._schema_key_map.values()) + + @staticmethod + def set_reference_paths(name: str, registered_schema: RegisteredSchema) -> None: + """Set the reference paths for a registered schema.""" + for reference in registered_schema.references: + reference.ref = f"#/components/schemas/{name}" + + @staticmethod + def remove_common_prefix(tuples: list[tuple[str, ...]]) -> list[tuple[str, ...]]: + """Remove the common prefix from a list of tuples. + + Args: + tuples: A list of tuples to remove the common prefix from. + + Returns: + A list of tuples with the common prefix removed. + """ + + def longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]: + """Find the longest common prefix of a list of tuples. + + Args: + tuples_: A list of tuples to find the longest common prefix of. + + Returns: + The longest common prefix of the tuples. + """ + prefix_ = tuples_[0] + for t in tuples_: + # Compare the current prefix with each tuple and shorten it + prefix_ = prefix_[: min(len(prefix_), len(t))] + for i in range(len(prefix_)): + if prefix_[i] != t[i]: + prefix_ = prefix_[:i] + break + return prefix_ + + prefix = longest_common_prefix(tuples) + prefix_length = len(prefix) + return [t[prefix_length:] for t in tuples] + + def generate_components_schemas(self) -> dict[str, Schema]: + """Generate the components/schemas section of the spec. + + Returns: + A dictionary of schemas. + """ + components_schemas: dict[str, Schema] = {} + + for name, name_group in self._model_name_groups.items(): + if len(name_group) == 1: + self.set_reference_paths(name, name_group[0]) + components_schemas[name] = name_group[0].schema + continue + + full_keys = [registered_schema.key for registered_schema in name_group] + names = ["_".join(k) for k in self.remove_common_prefix(full_keys)] + for name_, registered_schema in zip(names, name_group): + self.set_reference_paths(name_, registered_schema) + components_schemas[name_] = registered_schema.schema + + # Sort them by name to ensure they're always generated in the same order. + return {name: components_schemas[name] for name in sorted(components_schemas.keys())} + + +class OpenAPIContext: + def __init__( + self, + openapi_config: OpenAPIConfig, + plugins: Sequence[OpenAPISchemaPluginProtocol], + ) -> None: + self.openapi_config = openapi_config + self.plugins = plugins + self.operation_ids: set[str] = set() + self.schema_registry = SchemaRegistry() + + def add_operation_id(self, operation_id: str) -> None: + """Add an operation ID to the context. + + Args: + operation_id: Operation ID to add. + """ + if operation_id in self.operation_ids: + raise ImproperlyConfiguredException( + "operation_ids must be unique, " + f"please ensure the value of 'operation_id' is either not set or unique for {operation_id}" + ) + self.operation_ids.add(operation_id) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py b/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py new file mode 100644 index 0000000..c3da5c4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar._openapi.schema_generation import SchemaCreator +from litestar._openapi.schema_generation.utils import get_formatted_examples +from litestar.constants import RESERVED_KWARGS +from litestar.enums import ParamType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.openapi.spec.parameter import Parameter +from litestar.openapi.spec.schema import Schema +from litestar.params import DependencyKwarg, ParameterKwarg +from litestar.types import Empty +from litestar.typing import FieldDefinition + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.handlers.base import BaseRouteHandler + from litestar.openapi.spec import Reference + from litestar.types.internal_types import PathParameterDefinition + +__all__ = ("create_parameters_for_handler",) + + +class ParameterCollection: + """Facilitates conditional deduplication of parameters. + + If multiple parameters with the same name are produced for a handler, the condition is ignored if the two + ``Parameter`` instances are the same (the first is retained and any duplicates are ignored). If the ``Parameter`` + instances are not the same, an exception is raised. + """ + + def __init__(self, route_handler: BaseRouteHandler) -> None: + """Initialize ``ParameterCollection``. + + Args: + route_handler: Associated route handler + """ + self.route_handler = route_handler + self._parameters: dict[tuple[str, str], Parameter] = {} + + def add(self, parameter: Parameter) -> None: + """Add a ``Parameter`` to the collection. + + If an existing parameter with the same name and type already exists, the + parameter is ignored. + + If an existing parameter with the same name but different type exists, raises + ``ImproperlyConfiguredException``. + """ + + if (parameter.name, parameter.param_in) not in self._parameters: + # because we are defining routes as unique per path, we have to handle here a situation when there is an optional + # path parameter. e.g. get(path=["/", "/{param:str}"]). When parsing the parameter for path, the route handler + # would still have a kwarg called param: + # def handler(param: str | None) -> ... + if parameter.param_in != ParamType.QUERY or all( + f"{{{parameter.name}:" not in path for path in self.route_handler.paths + ): + self._parameters[(parameter.name, parameter.param_in)] = parameter + return + + pre_existing = self._parameters[(parameter.name, parameter.param_in)] + if parameter == pre_existing: + return + + raise ImproperlyConfiguredException( + f"OpenAPI schema generation for handler `{self.route_handler}` detected multiple parameters named " + f"'{parameter.name}' with different types." + ) + + def list(self) -> list[Parameter]: + """Return a list of all ``Parameter``'s in the collection.""" + return list(self._parameters.values()) + + +class ParameterFactory: + """Factory for creating OpenAPI Parameters for a given route handler.""" + + def __init__( + self, + context: OpenAPIContext, + route_handler: BaseRouteHandler, + path_parameters: tuple[PathParameterDefinition, ...], + ) -> None: + """Initialize ParameterFactory. + + Args: + context: The OpenAPI context. + route_handler: The route handler. + path_parameters: The path parameters for the route. + """ + self.context = context + self.schema_creator = SchemaCreator.from_openapi_context(self.context, prefer_alias=True) + self.route_handler = route_handler + self.parameters = ParameterCollection(route_handler) + self.dependency_providers = route_handler.resolve_dependencies() + self.layered_parameters = route_handler.resolve_layered_parameters() + self.path_parameters_names = {p.name for p in path_parameters} + + def create_parameter(self, field_definition: FieldDefinition, parameter_name: str) -> Parameter: + """Create an OpenAPI Parameter instance for a field definition. + + Args: + field_definition: The field definition. + parameter_name: The name of the parameter. + """ + + result: Schema | Reference | None = None + kwarg_definition = ( + field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None + ) + + if parameter_name in self.path_parameters_names: + param_in = ParamType.PATH + is_required = True + result = self.schema_creator.for_field_definition(field_definition) + elif kwarg_definition and kwarg_definition.header: + parameter_name = kwarg_definition.header + param_in = ParamType.HEADER + is_required = field_definition.is_required + elif kwarg_definition and kwarg_definition.cookie: + parameter_name = kwarg_definition.cookie + param_in = ParamType.COOKIE + is_required = field_definition.is_required + else: + is_required = field_definition.is_required + param_in = ParamType.QUERY + parameter_name = kwarg_definition.query if kwarg_definition and kwarg_definition.query else parameter_name + + if not result: + result = self.schema_creator.for_field_definition(field_definition) + + schema = result if isinstance(result, Schema) else self.context.schema_registry.from_reference(result).schema + + examples_list = kwarg_definition.examples or [] if kwarg_definition else [] + examples = get_formatted_examples(field_definition, examples_list) + + return Parameter( + description=schema.description, + name=parameter_name, + param_in=param_in, + required=is_required, + schema=result, + examples=examples or None, + ) + + def get_layered_parameter(self, field_name: str, field_definition: FieldDefinition) -> Parameter: + """Create a parameter for a field definition that has a KwargDefinition defined on the layers. + + Args: + field_name: The name of the field. + field_definition: The field definition. + """ + layer_field = self.layered_parameters[field_name] + + field = field_definition if field_definition.is_parameter_field else layer_field + default = layer_field.default if field_definition.has_default else field_definition.default + annotation = field_definition.annotation if field_definition is not Empty else layer_field.annotation + + parameter_name = field_name + if isinstance(field.kwarg_definition, ParameterKwarg): + parameter_name = ( + field.kwarg_definition.query + or field.kwarg_definition.header + or field.kwarg_definition.cookie + or field_name + ) + + field_definition = FieldDefinition.from_kwarg( + inner_types=field.inner_types, + default=default, + extra=field.extra, + annotation=annotation, + kwarg_definition=field.kwarg_definition, + name=field_name, + ) + return self.create_parameter(field_definition=field_definition, parameter_name=parameter_name) + + def create_parameters_for_field_definitions(self, fields: dict[str, FieldDefinition]) -> None: + """Add Parameter models to the handler's collection for the given field definitions. + + Args: + fields: The field definitions. + """ + unique_handler_fields = ( + (k, v) for k, v in fields.items() if k not in RESERVED_KWARGS and k not in self.layered_parameters + ) + unique_layered_fields = ( + (k, v) for k, v in self.layered_parameters.items() if k not in RESERVED_KWARGS and k not in fields + ) + intersection_fields = ( + (k, v) for k, v in fields.items() if k not in RESERVED_KWARGS and k in self.layered_parameters + ) + + for field_name, field_definition in unique_handler_fields: + if ( + isinstance(field_definition.kwarg_definition, DependencyKwarg) + and field_name not in self.dependency_providers + ): + # never document explicit dependencies + continue + + if provider := self.dependency_providers.get(field_name): + self.create_parameters_for_field_definitions(fields=provider.parsed_fn_signature.parameters) + else: + self.parameters.add(self.create_parameter(field_definition=field_definition, parameter_name=field_name)) + + for field_name, field_definition in unique_layered_fields: + self.parameters.add(self.create_parameter(field_definition=field_definition, parameter_name=field_name)) + + for field_name, field_definition in intersection_fields: + self.parameters.add(self.get_layered_parameter(field_name=field_name, field_definition=field_definition)) + + def create_parameters_for_handler(self) -> list[Parameter]: + """Create a list of path/query/header Parameter models for the given PathHandler.""" + handler_fields = self.route_handler.parsed_fn_signature.parameters + self.create_parameters_for_field_definitions(handler_fields) + return self.parameters.list() + + +def create_parameters_for_handler( + context: OpenAPIContext, + route_handler: BaseRouteHandler, + path_parameters: tuple[PathParameterDefinition, ...], +) -> list[Parameter]: + """Create a list of path/query/header Parameter models for the given PathHandler.""" + factory = ParameterFactory( + context=context, + route_handler=route_handler, + path_parameters=path_parameters, + ) + return factory.create_parameters_for_handler() diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/path_item.py b/venv/lib/python3.11/site-packages/litestar/_openapi/path_item.py new file mode 100644 index 0000000..74a04ce --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/path_item.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from inspect import cleandoc +from typing import TYPE_CHECKING + +from litestar._openapi.parameters import create_parameters_for_handler +from litestar._openapi.request_body import create_request_body +from litestar._openapi.responses import create_responses_for_handler +from litestar._openapi.utils import SEPARATORS_CLEANUP_PATTERN +from litestar.enums import HttpMethod +from litestar.openapi.spec import Operation, PathItem +from litestar.utils.helpers import unwrap_partial + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.routes import HTTPRoute + +__all__ = ("create_path_item_for_route",) + + +class PathItemFactory: + """Factory for creating a PathItem instance for a given route.""" + + def __init__(self, openapi_context: OpenAPIContext, route: HTTPRoute) -> None: + self.context = openapi_context + self.route = route + self._path_item = PathItem() + + def create_path_item(self) -> PathItem: + """Create a PathItem for the given route parsing all http_methods into Operation Models. + + Returns: + A PathItem instance. + """ + for http_method, handler_tuple in self.route.route_handler_map.items(): + route_handler, _ = handler_tuple + + if not route_handler.resolve_include_in_schema(): + continue + + operation = self.create_operation_for_handler_method(route_handler, HttpMethod(http_method)) + + setattr(self._path_item, http_method.lower(), operation) + + return self._path_item + + def create_operation_for_handler_method( + self, route_handler: HTTPRouteHandler, http_method: HttpMethod + ) -> Operation: + """Create an Operation instance for a given route handler and http method. + + Args: + route_handler: A route handler instance. + http_method: An HttpMethod enum value. + + Returns: + An Operation instance. + """ + operation_id = self.create_operation_id(route_handler, http_method) + parameters = create_parameters_for_handler(self.context, route_handler, self.route.path_parameters) + signature_fields = route_handler.parsed_fn_signature.parameters + + request_body = None + if data_field := signature_fields.get("data"): + request_body = create_request_body( + self.context, route_handler.handler_id, route_handler.resolve_data_dto(), data_field + ) + + raises_validation_error = bool(data_field or self._path_item.parameters or parameters) + responses = create_responses_for_handler( + self.context, route_handler, raises_validation_error=raises_validation_error + ) + + return route_handler.operation_class( + operation_id=operation_id, + tags=route_handler.resolve_tags() or None, + summary=route_handler.summary or SEPARATORS_CLEANUP_PATTERN.sub("", route_handler.handler_name.title()), + description=self.create_description_for_handler(route_handler), + deprecated=route_handler.deprecated, + responses=responses, + request_body=request_body, + parameters=parameters or None, # type: ignore[arg-type] + security=route_handler.resolve_security() or None, + ) + + def create_operation_id(self, route_handler: HTTPRouteHandler, http_method: HttpMethod) -> str: + """Create an operation id for a given route handler and http method. + + Adds the operation id to the context's operation id set, where it is checked for uniqueness. + + Args: + route_handler: A route handler instance. + http_method: An HttpMethod enum value. + + Returns: + An operation id string. + """ + if isinstance(route_handler.operation_id, str): + operation_id = route_handler.operation_id + elif callable(route_handler.operation_id): + operation_id = route_handler.operation_id(route_handler, http_method, self.route.path_components) + else: + operation_id = self.context.openapi_config.operation_id_creator( + route_handler, http_method, self.route.path_components + ) + self.context.add_operation_id(operation_id) + return operation_id + + def create_description_for_handler(self, route_handler: HTTPRouteHandler) -> str | None: + """Produce the operation description for a route handler. + + Args: + route_handler: A route handler instance. + + Returns: + An optional description string + """ + handler_description = route_handler.description + if handler_description is None and self.context.openapi_config.use_handler_docstrings: + fn = unwrap_partial(route_handler.fn) + return cleandoc(fn.__doc__) if fn.__doc__ else None + return handler_description + + +def create_path_item_for_route(openapi_context: OpenAPIContext, route: HTTPRoute) -> PathItem: + """Create a PathItem for the given route parsing all http_methods into Operation Models. + + Args: + openapi_context: The OpenAPIContext instance. + route: The route to create a PathItem for. + + Returns: + A PathItem instance. + """ + path_item_factory = PathItemFactory(openapi_context, route) + return path_item_factory.create_path_item() diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/plugin.py b/venv/lib/python3.11/site-packages/litestar/_openapi/plugin.py new file mode 100644 index 0000000..9bdbdec --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/plugin.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar._openapi.datastructures import OpenAPIContext +from litestar._openapi.path_item import create_path_item_for_route +from litestar.exceptions import ImproperlyConfiguredException +from litestar.plugins import InitPluginProtocol +from litestar.plugins.base import ReceiveRoutePlugin +from litestar.routes import HTTPRoute + +if TYPE_CHECKING: + from litestar.app import Litestar + from litestar.config.app import AppConfig + from litestar.openapi.config import OpenAPIConfig + from litestar.openapi.spec import OpenAPI + from litestar.routes import BaseRoute + + +class OpenAPIPlugin(InitPluginProtocol, ReceiveRoutePlugin): + __slots__ = ( + "app", + "included_routes", + "_openapi_config", + "_openapi_schema", + ) + + def __init__(self, app: Litestar) -> None: + self.app = app + self.included_routes: dict[str, HTTPRoute] = {} + self._openapi_config: OpenAPIConfig | None = None + self._openapi_schema: OpenAPI | None = None + + def _build_openapi_schema(self) -> OpenAPI: + openapi_config = self.openapi_config + + if openapi_config.create_examples: + from litestar._openapi.schema_generation.examples import ExampleFactory + + ExampleFactory.seed_random(openapi_config.random_seed) + + openapi = openapi_config.to_openapi_schema() + context = OpenAPIContext(openapi_config=openapi_config, plugins=self.app.plugins.openapi) + openapi.paths = { + route.path_format or "/": create_path_item_for_route(context, route) + for route in self.included_routes.values() + } + openapi.components.schemas = context.schema_registry.generate_components_schemas() + return openapi + + def provide_openapi(self) -> OpenAPI: + if not self._openapi_schema: + self._openapi_schema = self._build_openapi_schema() + return self._openapi_schema + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + if app_config.openapi_config: + self._openapi_config = app_config.openapi_config + app_config.route_handlers.append(self.openapi_config.openapi_controller) + return app_config + + @property + def openapi_config(self) -> OpenAPIConfig: + if not self._openapi_config: + raise ImproperlyConfiguredException("OpenAPIConfig not initialized") + return self._openapi_config + + def receive_route(self, route: BaseRoute) -> None: + if not isinstance(route, HTTPRoute): + return + + if any(route_handler.resolve_include_in_schema() for route_handler, _ in route.route_handler_map.values()): + # Force recompute the schema if a new route is added + self._openapi_schema = None + self.included_routes[route.path] = route diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/request_body.py b/venv/lib/python3.11/site-packages/litestar/_openapi/request_body.py new file mode 100644 index 0000000..7a5cf37 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/request_body.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar._openapi.schema_generation import SchemaCreator +from litestar.enums import RequestEncodingType +from litestar.openapi.spec.media_type import OpenAPIMediaType +from litestar.openapi.spec.request_body import RequestBody +from litestar.params import BodyKwarg + +__all__ = ("create_request_body",) + + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.dto import AbstractDTO + from litestar.typing import FieldDefinition + + +def create_request_body( + context: OpenAPIContext, + handler_id: str, + resolved_data_dto: type[AbstractDTO] | None, + data_field: FieldDefinition, +) -> RequestBody: + """Create a RequestBody instance for the given route handler's data field. + + Args: + context: The OpenAPIContext instance. + handler_id: The handler id. + resolved_data_dto: The resolved data dto. + data_field: The data field. + + Returns: + A RequestBody instance. + """ + media_type: RequestEncodingType | str = RequestEncodingType.JSON + schema_creator = SchemaCreator.from_openapi_context(context, prefer_alias=True) + if isinstance(data_field.kwarg_definition, BodyKwarg) and data_field.kwarg_definition.media_type: + media_type = data_field.kwarg_definition.media_type + + if resolved_data_dto: + schema = resolved_data_dto.create_openapi_schema( + field_definition=data_field, + handler_id=handler_id, + schema_creator=schema_creator, + ) + else: + schema = schema_creator.for_field_definition(data_field) + + return RequestBody(required=True, content={media_type: OpenAPIMediaType(schema=schema)}) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py b/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py new file mode 100644 index 0000000..6b0f312 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import contextlib +import re +from copy import copy +from dataclasses import asdict +from http import HTTPStatus +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Iterator + +from litestar._openapi.schema_generation import SchemaCreator +from litestar._openapi.schema_generation.utils import get_formatted_examples +from litestar.enums import MediaType +from litestar.exceptions import HTTPException, ValidationException +from litestar.openapi.spec import Example, OpenAPIResponse, Reference +from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType +from litestar.openapi.spec.header import OpenAPIHeader +from litestar.openapi.spec.media_type import OpenAPIMediaType +from litestar.openapi.spec.schema import Schema +from litestar.response import ( + File, + Redirect, + Stream, + Template, +) +from litestar.response import ( + Response as LitestarResponse, +) +from litestar.response.base import ASGIResponse +from litestar.types.builtin_types import NoneType +from litestar.typing import FieldDefinition +from litestar.utils import get_enum_string_value, get_name + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.datastructures.cookie import Cookie + from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.openapi.spec.responses import Responses + + +__all__ = ("create_responses_for_handler",) + +CAPITAL_LETTERS_PATTERN = re.compile(r"(?=[A-Z])") + + +def pascal_case_to_text(string: str) -> str: + """Given a 'PascalCased' string, return its split form- 'Pascal Cased'.""" + return " ".join(re.split(CAPITAL_LETTERS_PATTERN, string)).strip() + + +def create_cookie_schema(cookie: Cookie) -> Schema: + """Given a Cookie instance, return its corresponding OpenAPI schema. + + Args: + cookie: Cookie + + Returns: + Schema + """ + cookie_copy = copy(cookie) + cookie_copy.value = "<string>" + value = cookie_copy.to_header(header="") + return Schema(description=cookie.description or "", example=value) + + +class ResponseFactory: + """Factory for creating a Response instance for a given route handler.""" + + def __init__(self, context: OpenAPIContext, route_handler: HTTPRouteHandler) -> None: + """Initialize the factory. + + Args: + context: An OpenAPIContext instance. + route_handler: An HTTPRouteHandler instance. + """ + self.context = context + self.route_handler = route_handler + self.field_definition = route_handler.parsed_fn_signature.return_type + self.schema_creator = SchemaCreator.from_openapi_context(context, prefer_alias=False) + + def create_responses(self, raises_validation_error: bool) -> Responses | None: + """Create the schema for responses, if any. + + Args: + raises_validation_error: Boolean flag indicating whether the handler raises a ValidationException. + + Returns: + Responses + """ + responses: Responses = { + str(self.route_handler.status_code): self.create_success_response(), + } + + exceptions = list(self.route_handler.raises or []) + if raises_validation_error and ValidationException not in exceptions: + exceptions.append(ValidationException) + + for status_code, response in create_error_responses(exceptions=exceptions): + responses[status_code] = response + + for status_code, response in self.create_additional_responses(): + responses[status_code] = response + + return responses or None + + def create_description(self) -> str: + """Create the description for a success response.""" + default_descriptions: dict[Any, str] = { + Stream: "Stream Response", + Redirect: "Redirect Response", + File: "File Download", + } + return ( + self.route_handler.response_description + or default_descriptions.get(self.field_definition.annotation) + or HTTPStatus(self.route_handler.status_code).description + ) + + def create_success_response(self) -> OpenAPIResponse: + """Create the schema for a success response.""" + if self.field_definition.is_subclass_of((NoneType, ASGIResponse)): + response = OpenAPIResponse(content=None, description=self.create_description()) + elif self.field_definition.is_subclass_of(Redirect): + response = self.create_redirect_response() + elif self.field_definition.is_subclass_of((File, Stream)): + response = self.create_file_response() + else: + media_type = self.route_handler.media_type + + if dto := self.route_handler.resolve_return_dto(): + result = dto.create_openapi_schema( + field_definition=self.field_definition, + handler_id=self.route_handler.handler_id, + schema_creator=self.schema_creator, + ) + else: + if self.field_definition.is_subclass_of(Template): + field_def = FieldDefinition.from_annotation(str) + media_type = media_type or MediaType.HTML + elif self.field_definition.is_subclass_of(LitestarResponse): + field_def = ( + self.field_definition.inner_types[0] + if self.field_definition.inner_types + else FieldDefinition.from_annotation(Any) + ) + media_type = media_type or MediaType.JSON + else: + field_def = self.field_definition + + result = self.schema_creator.for_field_definition(field_def) + + schema = ( + result if isinstance(result, Schema) else self.context.schema_registry.from_reference(result).schema + ) + schema.content_encoding = self.route_handler.content_encoding + schema.content_media_type = self.route_handler.content_media_type + response = OpenAPIResponse( + content={get_enum_string_value(media_type): OpenAPIMediaType(schema=result)}, + description=self.create_description(), + ) + self.set_success_response_headers(response) + return response + + def create_redirect_response(self) -> OpenAPIResponse: + """Create the schema for a redirect response.""" + return OpenAPIResponse( + content=None, + description=self.create_description(), + headers={ + "location": OpenAPIHeader( + schema=Schema(type=OpenAPIType.STRING), description="target path for the redirect" + ) + }, + ) + + def create_file_response(self) -> OpenAPIResponse: + """Create the schema for a file/stream response.""" + return OpenAPIResponse( + content={ + self.route_handler.media_type: OpenAPIMediaType( + schema=Schema( + type=OpenAPIType.STRING, + content_encoding=self.route_handler.content_encoding, + content_media_type=self.route_handler.content_media_type or "application/octet-stream", + ), + ) + }, + description=self.create_description(), + headers={ + "content-length": OpenAPIHeader( + schema=Schema(type=OpenAPIType.STRING), description="File size in bytes" + ), + "last-modified": OpenAPIHeader( + schema=Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DATE_TIME), + description="Last modified data-time in RFC 2822 format", + ), + "etag": OpenAPIHeader(schema=Schema(type=OpenAPIType.STRING), description="Entity tag"), + }, + ) + + def set_success_response_headers(self, response: OpenAPIResponse) -> None: + """Set the schema for success response headers, if any.""" + + if response.headers is None: + response.headers = {} + + if not self.schema_creator.generate_examples: + schema_creator = self.schema_creator + else: + schema_creator = SchemaCreator.from_openapi_context(self.context, generate_examples=False) + + for response_header in self.route_handler.resolve_response_headers(): + header = OpenAPIHeader() + for attribute_name, attribute_value in ( + (k, v) for k, v in asdict(response_header).items() if v is not None + ): + if attribute_name == "value": + header.schema = schema_creator.for_field_definition( + FieldDefinition.from_annotation(type(attribute_value)) + ) + elif attribute_name != "documentation_only": + setattr(header, attribute_name, attribute_value) + + response.headers[response_header.name] = header + + if cookies := self.route_handler.resolve_response_cookies(): + response.headers["Set-Cookie"] = OpenAPIHeader( + schema=Schema( + all_of=[create_cookie_schema(cookie=cookie) for cookie in sorted(cookies, key=attrgetter("key"))] + ) + ) + + def create_additional_responses(self) -> Iterator[tuple[str, OpenAPIResponse]]: + """Create the schema for additional responses, if any.""" + if not self.route_handler.responses: + return + + for status_code, additional_response in self.route_handler.responses.items(): + schema_creator = SchemaCreator.from_openapi_context( + self.context, + prefer_alias=False, + generate_examples=additional_response.generate_examples, + ) + field_def = FieldDefinition.from_annotation(additional_response.data_container) + + examples: dict[str, Example | Reference] | None = ( + dict(get_formatted_examples(field_def, additional_response.examples)) + if additional_response.examples + else None + ) + + content: dict[str, OpenAPIMediaType] | None + if additional_response.data_container is not None: + schema = schema_creator.for_field_definition(field_def) + content = {additional_response.media_type: OpenAPIMediaType(schema=schema, examples=examples)} + else: + content = None + + yield ( + str(status_code), + OpenAPIResponse( + description=additional_response.description, + content=content, + ), + ) + + +def create_error_responses(exceptions: list[type[HTTPException]]) -> Iterator[tuple[str, OpenAPIResponse]]: + """Create the schema for error responses, if any.""" + grouped_exceptions: dict[int, list[type[HTTPException]]] = {} + for exc in exceptions: + if not grouped_exceptions.get(exc.status_code): + grouped_exceptions[exc.status_code] = [] + grouped_exceptions[exc.status_code].append(exc) + for status_code, exception_group in grouped_exceptions.items(): + exceptions_schemas = [] + group_description: str = "" + for exc in exception_group: + example_detail = "" + if hasattr(exc, "detail") and exc.detail: + group_description = exc.detail + example_detail = exc.detail + + if not example_detail: + with contextlib.suppress(Exception): + example_detail = HTTPStatus(status_code).phrase + + exceptions_schemas.append( + Schema( + type=OpenAPIType.OBJECT, + required=["detail", "status_code"], + properties={ + "status_code": Schema(type=OpenAPIType.INTEGER), + "detail": Schema(type=OpenAPIType.STRING), + "extra": Schema( + type=[OpenAPIType.NULL, OpenAPIType.OBJECT, OpenAPIType.ARRAY], + additional_properties=Schema(), + ), + }, + description=pascal_case_to_text(get_name(exc)), + examples=[{"status_code": status_code, "detail": example_detail, "extra": {}}], + ) + ) + if len(exceptions_schemas) > 1: # noqa: SIM108 + schema = Schema(one_of=exceptions_schemas) + else: + schema = exceptions_schemas[0] + + if not group_description: + with contextlib.suppress(Exception): + group_description = HTTPStatus(status_code).description + + yield ( + str(status_code), + OpenAPIResponse( + description=group_description, + content={MediaType.JSON: OpenAPIMediaType(schema=schema)}, + ), + ) + + +def create_responses_for_handler( + context: OpenAPIContext, route_handler: HTTPRouteHandler, raises_validation_error: bool +) -> Responses | None: + """Create the schema for responses, if any. + + Args: + context: An OpenAPIContext instance. + route_handler: An HTTPRouteHandler instance. + raises_validation_error: Boolean flag indicating whether the handler raises a ValidationException. + + Returns: + Responses + """ + return ResponseFactory(context, route_handler).create_responses(raises_validation_error=raises_validation_error) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__init__.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__init__.py new file mode 100644 index 0000000..8b9183e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__init__.py @@ -0,0 +1,7 @@ +from .plugins import openapi_schema_plugins +from .schema import SchemaCreator + +__all__ = ( + "SchemaCreator", + "openapi_schema_plugins", +) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..60cdb7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/constrained_fields.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/constrained_fields.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..156b683 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/constrained_fields.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/examples.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/examples.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1d1a327 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/examples.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/schema.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c382161 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/schema.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8f1ed6b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/constrained_fields.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/constrained_fields.py new file mode 100644 index 0000000..80f355d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/constrained_fields.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from datetime import date, datetime, timezone +from re import Pattern +from typing import TYPE_CHECKING + +from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType +from litestar.openapi.spec.schema import Schema + +if TYPE_CHECKING: + from decimal import Decimal + + from litestar.params import KwargDefinition + +__all__ = ( + "create_date_constrained_field_schema", + "create_numerical_constrained_field_schema", + "create_string_constrained_field_schema", +) + + +def create_numerical_constrained_field_schema( + field_type: type[int] | type[float] | type[Decimal], + kwarg_definition: KwargDefinition, +) -> Schema: + """Create Schema from Constrained Int/Float/Decimal field.""" + schema = Schema(type=OpenAPIType.INTEGER if issubclass(field_type, int) else OpenAPIType.NUMBER) + if kwarg_definition.le is not None: + schema.maximum = float(kwarg_definition.le) + if kwarg_definition.lt is not None: + schema.exclusive_maximum = float(kwarg_definition.lt) + if kwarg_definition.ge is not None: + schema.minimum = float(kwarg_definition.ge) + if kwarg_definition.gt is not None: + schema.exclusive_minimum = float(kwarg_definition.gt) + if kwarg_definition.multiple_of is not None: + schema.multiple_of = float(kwarg_definition.multiple_of) + return schema + + +def create_date_constrained_field_schema( + field_type: type[date] | type[datetime], + kwarg_definition: KwargDefinition, +) -> Schema: + """Create Schema from Constrained Date Field.""" + schema = Schema( + type=OpenAPIType.STRING, format=OpenAPIFormat.DATE if issubclass(field_type, date) else OpenAPIFormat.DATE_TIME + ) + for kwarg_definition_attr, schema_attr in [ + ("le", "maximum"), + ("lt", "exclusive_maximum"), + ("ge", "minimum"), + ("gt", "exclusive_minimum"), + ]: + if attr := getattr(kwarg_definition, kwarg_definition_attr): + setattr( + schema, + schema_attr, + datetime.combine( + datetime.fromtimestamp(attr, tz=timezone.utc) if isinstance(attr, (float, int)) else attr, + datetime.min.time(), + tzinfo=timezone.utc, + ).timestamp(), + ) + + return schema + + +def create_string_constrained_field_schema( + field_type: type[str] | type[bytes], + kwarg_definition: KwargDefinition, +) -> Schema: + """Create Schema from Constrained Str/Bytes field.""" + schema = Schema(type=OpenAPIType.STRING) + if issubclass(field_type, bytes): + schema.content_encoding = "utf-8" + if kwarg_definition.min_length: + schema.min_length = kwarg_definition.min_length + if kwarg_definition.max_length: + schema.max_length = kwarg_definition.max_length + if kwarg_definition.pattern: + schema.pattern = ( + kwarg_definition.pattern.pattern # type: ignore[attr-defined] + if isinstance(kwarg_definition.pattern, Pattern) # type: ignore[unreachable] + else kwarg_definition.pattern + ) + if kwarg_definition.lower_case: + schema.description = "must be in lower case" + if kwarg_definition.upper_case: + schema.description = "must be in upper case" + return schema diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/examples.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/examples.py new file mode 100644 index 0000000..49edf72 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/examples.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import typing +from dataclasses import replace +from decimal import Decimal +from enum import Enum +from typing import TYPE_CHECKING, Any + +import msgspec +from polyfactory.exceptions import ParameterException +from polyfactory.factories import DataclassFactory +from polyfactory.field_meta import FieldMeta, Null +from polyfactory.utils.helpers import unwrap_annotation +from polyfactory.utils.predicates import is_union +from typing_extensions import get_args + +from litestar.contrib.pydantic.utils import is_pydantic_model_instance +from litestar.openapi.spec import Example +from litestar.types import Empty + +if TYPE_CHECKING: + from litestar.typing import FieldDefinition + + +class ExampleFactory(DataclassFactory[Example]): + __model__ = Example + __random_seed__ = 10 + + +def _normalize_example_value(value: Any) -> Any: + """Normalize the example value to make it look a bit prettier.""" + # if UnsetType is part of the union, then it might get chosen as the value + # but that will not be properly serialized by msgspec unless it is for a field + # in a msgspec Struct + if is_union(value): + args = list(get_args(value)) + try: + args.remove(msgspec.UnsetType) + value = typing.Union[tuple(args)] # pyright: ignore + except ValueError: + # UnsetType not part of the Union + pass + + value = unwrap_annotation(annotation=value, random=ExampleFactory.__random__) + if isinstance(value, (Decimal, float)): + value = round(float(value), 2) + if isinstance(value, Enum): + value = value.value + if is_pydantic_model_instance(value): + from litestar.contrib.pydantic import _model_dump + + value = _model_dump(value) + if isinstance(value, (list, set)): + value = [_normalize_example_value(v) for v in value] + if isinstance(value, dict): + for k, v in value.items(): + value[k] = _normalize_example_value(v) + return value + + +def _create_field_meta(field: FieldDefinition) -> FieldMeta: + return FieldMeta.from_type( + annotation=field.annotation, + default=field.default if field.default is not Empty else Null, + name=field.name, + random=ExampleFactory.__random__, + ) + + +def create_examples_for_field(field: FieldDefinition) -> list[Example]: + """Create an OpenAPI Example instance. + + Args: + field: A signature field. + + Returns: + A list including a single example. + """ + try: + field_meta = _create_field_meta(replace(field, annotation=_normalize_example_value(field.annotation))) + value = ExampleFactory.get_field_value(field_meta) + return [Example(description=f"Example {field.name} value", value=value)] + except ParameterException: + return [] diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__init__.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__init__.py new file mode 100644 index 0000000..1b12b1e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__init__.py @@ -0,0 +1,16 @@ +from .dataclass import DataclassSchemaPlugin +from .pagination import PaginationSchemaPlugin +from .struct import StructSchemaPlugin +from .typed_dict import TypedDictSchemaPlugin + +__all__ = ("openapi_schema_plugins",) + +# NOTE: The Pagination type plugin has to come before the Dataclass plugin since the Pagination +# classes are dataclasses, but we want to handle them differently from how dataclasses are normally +# handled. +openapi_schema_plugins = [ + PaginationSchemaPlugin(), + StructSchemaPlugin(), + DataclassSchemaPlugin(), + TypedDictSchemaPlugin(), +] diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f022bdd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/dataclass.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/dataclass.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7cbfe6f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/dataclass.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/pagination.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/pagination.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0bc4add --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/pagination.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/struct.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/struct.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..cef9d12 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/struct.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/typed_dict.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/typed_dict.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3d534ba --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/__pycache__/typed_dict.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/dataclass.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/dataclass.py new file mode 100644 index 0000000..fb5da35 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/dataclass.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import MISSING, fields +from typing import TYPE_CHECKING + +from litestar.plugins import OpenAPISchemaPlugin +from litestar.typing import FieldDefinition +from litestar.utils.predicates import is_optional_union + +if TYPE_CHECKING: + from litestar._openapi.schema_generation import SchemaCreator + from litestar.openapi.spec import Schema + + +class DataclassSchemaPlugin(OpenAPISchemaPlugin): + def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: + return field_definition.is_dataclass_type + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True) + dataclass_fields = fields(field_definition.type_) + return schema_creator.create_component_schema( + field_definition, + required=sorted( + field.name + for field in dataclass_fields + if ( + field.default is MISSING + and field.default_factory is MISSING + and not is_optional_union(type_hints[field.name]) + ) + ), + property_fields={ + field.name: FieldDefinition.from_kwarg(type_hints[field.name], field.name) for field in dataclass_fields + }, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/pagination.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/pagination.py new file mode 100644 index 0000000..9b4f6c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/pagination.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.openapi.spec import OpenAPIType, Schema +from litestar.pagination import ClassicPagination, CursorPagination, OffsetPagination +from litestar.plugins import OpenAPISchemaPlugin + +if TYPE_CHECKING: + from litestar._openapi.schema_generation import SchemaCreator + from litestar.typing import FieldDefinition + + +class PaginationSchemaPlugin(OpenAPISchemaPlugin): + def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: + return field_definition.origin in (ClassicPagination, CursorPagination, OffsetPagination) + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + if field_definition.origin is ClassicPagination: + return Schema( + type=OpenAPIType.OBJECT, + properties={ + "items": Schema( + type=OpenAPIType.ARRAY, + items=schema_creator.for_field_definition(field_definition.inner_types[0]), + ), + "page_size": Schema(type=OpenAPIType.INTEGER, description="Number of items per page."), + "current_page": Schema(type=OpenAPIType.INTEGER, description="Current page number."), + "total_pages": Schema(type=OpenAPIType.INTEGER, description="Total number of pages."), + }, + ) + + if field_definition.origin is OffsetPagination: + return Schema( + type=OpenAPIType.OBJECT, + properties={ + "items": Schema( + type=OpenAPIType.ARRAY, + items=schema_creator.for_field_definition(field_definition.inner_types[0]), + ), + "limit": Schema(type=OpenAPIType.INTEGER, description="Maximal number of items to send."), + "offset": Schema(type=OpenAPIType.INTEGER, description="Offset from the beginning of the query."), + "total": Schema(type=OpenAPIType.INTEGER, description="Total number of items."), + }, + ) + + cursor_schema = schema_creator.not_generating_examples.for_field_definition(field_definition.inner_types[0]) + cursor_schema.description = "Unique ID, designating the last identifier in the given data set. This value can be used to request the 'next' batch of records." + + return Schema( + type=OpenAPIType.OBJECT, + properties={ + "items": Schema( + type=OpenAPIType.ARRAY, + items=schema_creator.for_field_definition(field_definition=field_definition.inner_types[1]), + ), + "cursor": cursor_schema, + "results_per_page": Schema(type=OpenAPIType.INTEGER, description="Maximal number of items to send."), + }, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/struct.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/struct.py new file mode 100644 index 0000000..aabfdb3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/struct.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from msgspec import Struct +from msgspec.structs import fields + +from litestar.plugins import OpenAPISchemaPlugin +from litestar.types.empty import Empty +from litestar.typing import FieldDefinition +from litestar.utils.predicates import is_optional_union + +if TYPE_CHECKING: + from msgspec.structs import FieldInfo + + from litestar._openapi.schema_generation import SchemaCreator + from litestar.openapi.spec import Schema + + +class StructSchemaPlugin(OpenAPISchemaPlugin): + def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: + return not field_definition.is_union and field_definition.is_subclass_of(Struct) + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + def is_field_required(field: FieldInfo) -> bool: + return field.required or field.default_factory is Empty + + type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True) + struct_fields = fields(field_definition.type_) + + return schema_creator.create_component_schema( + field_definition, + required=sorted( + [ + field.encode_name + for field in struct_fields + if is_field_required(field=field) and not is_optional_union(type_hints[field.name]) + ] + ), + property_fields={ + field.encode_name: FieldDefinition.from_kwarg(type_hints[field.name], field.encode_name) + for field in struct_fields + }, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/typed_dict.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/typed_dict.py new file mode 100644 index 0000000..ef34e2b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/plugins/typed_dict.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.plugins import OpenAPISchemaPlugin +from litestar.typing import FieldDefinition + +if TYPE_CHECKING: + from litestar._openapi.schema_generation import SchemaCreator + from litestar.openapi.spec import Schema + + +class TypedDictSchemaPlugin(OpenAPISchemaPlugin): + def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: + return field_definition.is_typeddict_type + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True) + + return schema_creator.create_component_schema( + field_definition, + required=sorted(getattr(field_definition.type_, "__required_keys__", [])), + property_fields={k: FieldDefinition.from_kwarg(v, k) for k, v in type_hints.items()}, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/schema.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/schema.py new file mode 100644 index 0000000..0b7d6c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/schema.py @@ -0,0 +1,616 @@ +from __future__ import annotations + +from collections import deque +from copy import copy +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum, EnumMeta +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + DefaultDict, + Deque, + Dict, + FrozenSet, + Hashable, + Iterable, + List, + Literal, + Mapping, + MutableMapping, + MutableSequence, + OrderedDict, + Pattern, + Sequence, + Set, + Tuple, + Union, + cast, +) +from uuid import UUID + +from typing_extensions import Self, get_args + +from litestar._openapi.datastructures import SchemaRegistry +from litestar._openapi.schema_generation.constrained_fields import ( + create_date_constrained_field_schema, + create_numerical_constrained_field_schema, + create_string_constrained_field_schema, +) +from litestar._openapi.schema_generation.utils import ( + _get_normalized_schema_key, + _should_create_enum_schema, + _should_create_literal_schema, + _type_or_first_not_none_inner_type, + get_json_schema_formatted_examples, +) +from litestar.datastructures import UploadFile +from litestar.exceptions import ImproperlyConfiguredException +from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType +from litestar.openapi.spec.schema import Schema, SchemaDataContainer +from litestar.params import BodyKwarg, ParameterKwarg +from litestar.plugins import OpenAPISchemaPlugin +from litestar.types import Empty +from litestar.types.builtin_types import NoneType +from litestar.typing import FieldDefinition +from litestar.utils.helpers import get_name +from litestar.utils.predicates import ( + is_class_and_subclass, + is_undefined_sentinel, +) +from litestar.utils.typing import ( + get_origin_or_inner_type, + make_non_optional_union, +) + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.openapi.spec import Example, Reference + from litestar.plugins import OpenAPISchemaPluginProtocol + +KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP: dict[str, str] = { + "content_encoding": "content_encoding", + "default": "default", + "description": "description", + "enum": "enum", + "examples": "examples", + "external_docs": "external_docs", + "format": "format", + "ge": "minimum", + "gt": "exclusive_minimum", + "le": "maximum", + "lt": "exclusive_maximum", + "max_items": "max_items", + "max_length": "max_length", + "min_items": "min_items", + "min_length": "min_length", + "multiple_of": "multiple_of", + "pattern": "pattern", + "title": "title", + "read_only": "read_only", +} + +TYPE_MAP: dict[type[Any] | None | Any, Schema] = { + Decimal: Schema(type=OpenAPIType.NUMBER), + DefaultDict: Schema(type=OpenAPIType.OBJECT), + Deque: Schema(type=OpenAPIType.ARRAY), + Dict: Schema(type=OpenAPIType.OBJECT), + FrozenSet: Schema(type=OpenAPIType.ARRAY), + IPv4Address: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.IPV4), + IPv4Interface: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.IPV4), + IPv4Network: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.IPV4), + IPv6Address: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.IPV6), + IPv6Interface: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.IPV6), + IPv6Network: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.IPV6), + Iterable: Schema(type=OpenAPIType.ARRAY), + List: Schema(type=OpenAPIType.ARRAY), + Mapping: Schema(type=OpenAPIType.OBJECT), + MutableMapping: Schema(type=OpenAPIType.OBJECT), + MutableSequence: Schema(type=OpenAPIType.ARRAY), + None: Schema(type=OpenAPIType.NULL), + NoneType: Schema(type=OpenAPIType.NULL), + OrderedDict: Schema(type=OpenAPIType.OBJECT), + Path: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI), + Pattern: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.REGEX), + Sequence: Schema(type=OpenAPIType.ARRAY), + Set: Schema(type=OpenAPIType.ARRAY), + Tuple: Schema(type=OpenAPIType.ARRAY), + UUID: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.UUID), + bool: Schema(type=OpenAPIType.BOOLEAN), + bytearray: Schema(type=OpenAPIType.STRING), + bytes: Schema(type=OpenAPIType.STRING), + date: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DATE), + datetime: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DATE_TIME), + deque: Schema(type=OpenAPIType.ARRAY), + dict: Schema(type=OpenAPIType.OBJECT), + float: Schema(type=OpenAPIType.NUMBER), + frozenset: Schema(type=OpenAPIType.ARRAY), + int: Schema(type=OpenAPIType.INTEGER), + list: Schema(type=OpenAPIType.ARRAY), + set: Schema(type=OpenAPIType.ARRAY), + str: Schema(type=OpenAPIType.STRING), + time: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DURATION), + timedelta: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DURATION), + tuple: Schema(type=OpenAPIType.ARRAY), +} + + +def _types_in_list(lst: list[Any]) -> list[OpenAPIType] | OpenAPIType: + """Extract unique OpenAPITypes present in the values of a list. + + Args: + lst: A list of values + + Returns: + OpenAPIType in the given list. If more then one exists, return + a list of OpenAPITypes. + """ + schema_types: list[OpenAPIType] = [] + for item in lst: + schema_type = TYPE_MAP[type(item)].type + if isinstance(schema_type, OpenAPIType): + schema_types.append(schema_type) + else: + raise RuntimeError("Unexpected type for schema item") # pragma: no cover + schema_types = list(set(schema_types)) + return schema_types[0] if len(schema_types) == 1 else schema_types + + +def _get_type_schema_name(field_definition: FieldDefinition) -> str: + """Extract the schema name from a data container. + + Args: + field_definition: A field definition instance. + + Returns: + A string + """ + + if name := getattr(field_definition.annotation, "__schema_name__", None): + return cast("str", name) + + name = get_name(field_definition.annotation) + if field_definition.inner_types: + inner_parts = ", ".join(_get_type_schema_name(t) for t in field_definition.inner_types) + return f"{name}[{inner_parts}]" + + return name + + +def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema: + """Create a schema instance for an enum. + + Args: + annotation: An enum. + include_null: Whether to include null as a possible value. + + Returns: + A schema instance. + """ + enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated] + if include_null and None not in enum_values: + enum_values.append(None) + return Schema(type=_types_in_list(enum_values), enum=enum_values) + + +def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]: + """Iterate over the flattened arguments of a Literal. + + Args: + annotation: An Literal annotation. + + Yields: + The flattened arguments of the Literal. + """ + for arg in get_args(annotation): + if get_origin_or_inner_type(arg) is Literal: + yield from _iter_flat_literal_args(arg) + else: + yield arg.value if isinstance(arg, Enum) else arg + + +def create_literal_schema(annotation: Any, include_null: bool = False) -> Schema: + """Create a schema instance for a Literal. + + Args: + annotation: An Literal annotation. + include_null: Whether to include null as a possible value. + + Returns: + A schema instance. + """ + args = list(_iter_flat_literal_args(annotation)) + if include_null and None not in args: + args.append(None) + schema = Schema(type=_types_in_list(args)) + if len(args) > 1: + schema.enum = args + else: + schema.const = args[0] + return schema + + +def create_schema_for_annotation(annotation: Any) -> Schema: + """Get a schema from the type mapping - if possible. + + Args: + annotation: A type annotation. + + Returns: + A schema instance or None. + """ + + return copy(TYPE_MAP[annotation]) if annotation in TYPE_MAP else Schema() + + +class SchemaCreator: + __slots__ = ("generate_examples", "plugins", "prefer_alias", "schema_registry") + + def __init__( + self, + generate_examples: bool = False, + plugins: Iterable[OpenAPISchemaPluginProtocol] | None = None, + prefer_alias: bool = True, + schema_registry: SchemaRegistry | None = None, + ) -> None: + """Instantiate a SchemaCreator. + + Args: + generate_examples: Whether to generate examples if none are given. + plugins: A list of plugins. + prefer_alias: Whether to prefer the alias name for the schema. + schema_registry: A SchemaRegistry instance. + """ + self.generate_examples = generate_examples + self.plugins = plugins if plugins is not None else [] + self.prefer_alias = prefer_alias + self.schema_registry = schema_registry or SchemaRegistry() + + @classmethod + def from_openapi_context(cls, context: OpenAPIContext, prefer_alias: bool = True, **kwargs: Any) -> Self: + kwargs.setdefault("generate_examples", context.openapi_config.create_examples) + kwargs.setdefault("plugins", context.plugins) + kwargs.setdefault("schema_registry", context.schema_registry) + return cls(**kwargs, prefer_alias=prefer_alias) + + @property + def not_generating_examples(self) -> SchemaCreator: + """Return a SchemaCreator with generate_examples set to False.""" + if not self.generate_examples: + return self + return type(self)(generate_examples=False, plugins=self.plugins, prefer_alias=False) + + @staticmethod + def plugin_supports_field(plugin: OpenAPISchemaPluginProtocol, field: FieldDefinition) -> bool: + if predicate := getattr(plugin, "is_plugin_supported_field", None): + return predicate(field) # type: ignore[no-any-return] + return plugin.is_plugin_supported_type(field.annotation) + + def get_plugin_for(self, field_definition: FieldDefinition) -> OpenAPISchemaPluginProtocol | None: + return next( + (plugin for plugin in self.plugins if self.plugin_supports_field(plugin, field_definition)), + None, + ) + + def is_constrained_field(self, field_definition: FieldDefinition) -> bool: + """Return if the field is constrained, taking into account constraints defined by plugins""" + return ( + isinstance(field_definition.kwarg_definition, (ParameterKwarg, BodyKwarg)) + and field_definition.kwarg_definition.is_constrained + ) or any( + p.is_constrained_field(field_definition) + for p in self.plugins + if isinstance(p, OpenAPISchemaPlugin) and p.is_plugin_supported_field(field_definition) + ) + + def is_undefined(self, value: Any) -> bool: + """Return if the field is undefined, taking into account undefined types defined by plugins""" + return is_undefined_sentinel(value) or any( + p.is_undefined_sentinel(value) for p in self.plugins if isinstance(p, OpenAPISchemaPlugin) + ) + + def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Reference: + """Create a Schema for a given FieldDefinition. + + Args: + field_definition: A signature field instance. + + Returns: + A schema instance. + """ + + result: Schema | Reference + + if plugin_for_annotation := self.get_plugin_for(field_definition): + result = self.for_plugin(field_definition, plugin_for_annotation) + elif _should_create_enum_schema(field_definition): + annotation = _type_or_first_not_none_inner_type(field_definition) + result = create_enum_schema(annotation, include_null=field_definition.is_optional) + elif _should_create_literal_schema(field_definition): + annotation = ( + make_non_optional_union(field_definition.annotation) + if field_definition.is_optional + else field_definition.annotation + ) + result = create_literal_schema(annotation, include_null=field_definition.is_optional) + elif field_definition.is_optional: + result = self.for_optional_field(field_definition) + elif field_definition.is_union: + result = self.for_union_field(field_definition) + elif field_definition.is_type_var: + result = self.for_typevar() + elif field_definition.inner_types and not field_definition.is_generic: + result = self.for_object_type(field_definition) + elif self.is_constrained_field(field_definition): + result = self.for_constrained_field(field_definition) + elif field_definition.is_subclass_of(UploadFile): + result = self.for_upload_file(field_definition) + else: + result = create_schema_for_annotation(field_definition.annotation) + + return self.process_schema_result(field_definition, result) if isinstance(result, Schema) else result + + @staticmethod + def for_upload_file(field_definition: FieldDefinition) -> Schema: + """Create schema for UploadFile. + + Args: + field_definition: A field definition instance. + + Returns: + A Schema instance. + """ + + property_key = "file" + schema = Schema( + type=OpenAPIType.STRING, + content_media_type="application/octet-stream", + format=OpenAPIFormat.BINARY, + ) + + # If the type is `dict[str, UploadFile]`, then it's the same as a `list[UploadFile]` + # but we will internally convert that into a `dict[str, UploadFile]`. + if field_definition.is_non_string_sequence or field_definition.is_mapping: + property_key = "files" + schema = Schema(type=OpenAPIType.ARRAY, items=schema) + + # If the uploadfile is annotated directly on the handler, then the + # 'properties' needs to be created. Else, the 'properties' will be + # created by the corresponding plugin. + is_defined_on_handler = field_definition.name == "data" and isinstance( + field_definition.kwarg_definition, BodyKwarg + ) + if is_defined_on_handler: + return Schema(type=OpenAPIType.OBJECT, properties={property_key: schema}) + + return schema + + @staticmethod + def for_typevar() -> Schema: + """Create a schema for a TypeVar. + + Returns: + A schema instance. + """ + + return Schema(type=OpenAPIType.OBJECT) + + def for_optional_field(self, field_definition: FieldDefinition) -> Schema: + """Create a Schema for an optional FieldDefinition. + + Args: + field_definition: A signature field instance. + + Returns: + A schema instance. + """ + schema_or_reference = self.for_field_definition( + FieldDefinition.from_kwarg( + annotation=make_non_optional_union(field_definition.annotation), + name=field_definition.name, + default=field_definition.default, + ) + ) + if isinstance(schema_or_reference, Schema) and isinstance(schema_or_reference.one_of, list): + result = schema_or_reference.one_of + else: + result = [schema_or_reference] + + return Schema(one_of=[Schema(type=OpenAPIType.NULL), *result]) + + def for_union_field(self, field_definition: FieldDefinition) -> Schema: + """Create a Schema for a union FieldDefinition. + + Args: + field_definition: A signature field instance. + + Returns: + A schema instance. + """ + inner_types = (f for f in (field_definition.inner_types or []) if not self.is_undefined(f.annotation)) + values = list(map(self.for_field_definition, inner_types)) + return Schema(one_of=values) + + def for_object_type(self, field_definition: FieldDefinition) -> Schema: + """Create schema for object types (dict, Mapping, list, Sequence etc.) types. + + Args: + field_definition: A signature field instance. + + Returns: + A schema instance. + """ + if field_definition.has_inner_subclass_of(UploadFile): + return self.for_upload_file(field_definition) + + if field_definition.is_mapping: + return Schema( + type=OpenAPIType.OBJECT, + additional_properties=( + self.for_field_definition(field_definition.inner_types[1]) + if field_definition.inner_types and len(field_definition.inner_types) == 2 + else None + ), + ) + + if field_definition.is_non_string_sequence or field_definition.is_non_string_iterable: + # filters out ellipsis from tuple[int, ...] type annotations + inner_types = (f for f in field_definition.inner_types if f.annotation is not Ellipsis) + items = list(map(self.for_field_definition, inner_types or ())) + + return Schema( + type=OpenAPIType.ARRAY, + items=Schema(one_of=items) if len(items) > 1 else items[0], + ) + + raise ImproperlyConfiguredException( # pragma: no cover + f"Parameter '{field_definition.name}' with type '{field_definition.annotation}' could not be mapped to an Open API type. " + f"This can occur if a user-defined generic type is resolved as a parameter. If '{field_definition.name}' should " + "not be documented as a parameter, annotate it using the `Dependency` function, e.g., " + f"`{field_definition.name}: ... = Dependency(...)`." + ) + + def for_plugin(self, field_definition: FieldDefinition, plugin: OpenAPISchemaPluginProtocol) -> Schema | Reference: + """Create a schema using a plugin. + + Args: + field_definition: A signature field instance. + plugin: A plugin for the field type. + + Returns: + A schema instance. + """ + key = _get_normalized_schema_key(field_definition.annotation) + if (ref := self.schema_registry.get_reference_for_key(key)) is not None: + return ref + + schema = plugin.to_openapi_schema(field_definition=field_definition, schema_creator=self) + if isinstance(schema, SchemaDataContainer): # pragma: no cover + return self.for_field_definition( + FieldDefinition.from_kwarg( + annotation=schema.data_container, + name=field_definition.name, + default=field_definition.default, + extra=field_definition.extra, + kwarg_definition=field_definition.kwarg_definition, + ) + ) + return schema + + def for_constrained_field(self, field: FieldDefinition) -> Schema: + """Create Schema for Pydantic Constrained fields (created using constr(), conint() and so forth, or by subclassing + Constrained*) + + Args: + field: A signature field instance. + + Returns: + A schema instance. + """ + kwarg_definition = cast(Union[ParameterKwarg, BodyKwarg], field.kwarg_definition) + if any(is_class_and_subclass(field.annotation, t) for t in (int, float, Decimal)): + return create_numerical_constrained_field_schema(field.annotation, kwarg_definition) + if any(is_class_and_subclass(field.annotation, t) for t in (str, bytes)): # type: ignore[arg-type] + return create_string_constrained_field_schema(field.annotation, kwarg_definition) + if any(is_class_and_subclass(field.annotation, t) for t in (date, datetime)): + return create_date_constrained_field_schema(field.annotation, kwarg_definition) + return self.for_collection_constrained_field(field) + + def for_collection_constrained_field(self, field_definition: FieldDefinition) -> Schema: + """Create Schema from Constrained List/Set field. + + Args: + field_definition: A signature field instance. + + Returns: + A schema instance. + """ + schema = Schema(type=OpenAPIType.ARRAY) + kwarg_definition = cast(Union[ParameterKwarg, BodyKwarg], field_definition.kwarg_definition) + if kwarg_definition.min_items: + schema.min_items = kwarg_definition.min_items + if kwarg_definition.max_items: + schema.max_items = kwarg_definition.max_items + if any(is_class_and_subclass(field_definition.annotation, t) for t in (set, frozenset)): # type: ignore[arg-type] + schema.unique_items = True + + item_creator = self.not_generating_examples + if field_definition.inner_types: + items = list(map(item_creator.for_field_definition, field_definition.inner_types)) + schema.items = Schema(one_of=items) if len(items) > 1 else items[0] + else: + schema.items = item_creator.for_field_definition( + FieldDefinition.from_kwarg( + field_definition.annotation.item_type, f"{field_definition.annotation.__name__}Field" + ) + ) + return schema + + def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference: + if field.kwarg_definition and field.is_const and field.has_default and schema.const is None: + schema.const = field.default + + if field.kwarg_definition: + for kwarg_definition_key, schema_key in KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP.items(): + if (value := getattr(field.kwarg_definition, kwarg_definition_key, Empty)) and ( + not isinstance(value, Hashable) or not self.is_undefined(value) + ): + if schema_key == "examples": + value = get_json_schema_formatted_examples(cast("list[Example]", value)) + + # we only want to transfer values from the `KwargDefinition` to `Schema` if the schema object + # doesn't already have a value for that property. For example, if a field is a constrained date, + # by this point, we have already set the `exclusive_minimum` and/or `exclusive_maximum` fields + # to floating point timestamp values on the schema object. However, the original `date` objects + # that define those constraints on `KwargDefinition` are still `date` objects. We don't want to + # overwrite them here. + if getattr(schema, schema_key, None) is None: + setattr(schema, schema_key, value) + + if not schema.examples and self.generate_examples: + from litestar._openapi.schema_generation.examples import create_examples_for_field + + schema.examples = get_json_schema_formatted_examples(create_examples_for_field(field)) + + if schema.title and schema.type == OpenAPIType.OBJECT: + key = _get_normalized_schema_key(field.annotation) + return self.schema_registry.get_reference_for_key(key) or schema + return schema + + def create_component_schema( + self, + type_: FieldDefinition, + /, + required: list[str], + property_fields: Mapping[str, FieldDefinition], + openapi_type: OpenAPIType = OpenAPIType.OBJECT, + title: str | None = None, + examples: list[Any] | None = None, + ) -> Schema: + """Create a schema for the components/schemas section of the OpenAPI spec. + + These are schemas that can be referenced by other schemas in the document, including self references. + + To support self referencing schemas, the schema is added to the registry before schemas for its properties + are created. This allows the schema to be referenced by its properties. + + Args: + type_: ``FieldDefinition`` instance of the type to create a schema for. + required: A list of required fields. + property_fields: Mapping of name to ``FieldDefinition`` instances for the properties of the schema. + openapi_type: The OpenAPI type, defaults to ``OpenAPIType.OBJECT``. + title: The schema title, generated if not provided. + examples: A mapping of example names to ``Example`` instances, not required. + + Returns: + A schema instance. + """ + schema = self.schema_registry.get_schema_for_key(_get_normalized_schema_key(type_.annotation)) + schema.title = title or _get_type_schema_name(type_) + schema.required = required + schema.type = openapi_type + schema.properties = {k: self.for_field_definition(v) for k, v in property_fields.items()} + schema.examples = examples + return schema diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/utils.py b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/utils.py new file mode 100644 index 0000000..7ce27ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/schema_generation/utils.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Any, Mapping, _GenericAlias # type: ignore[attr-defined] + +from litestar.utils.helpers import get_name + +if TYPE_CHECKING: + from collections.abc import Sequence + + from litestar.openapi.spec import Example + from litestar.typing import FieldDefinition + +__all__ = ( + "_type_or_first_not_none_inner_type", + "_should_create_enum_schema", + "_should_create_literal_schema", + "_get_normalized_schema_key", +) + + +def _type_or_first_not_none_inner_type(field_definition: FieldDefinition) -> Any: + """Get the first inner type that is not None. + + This is a narrow focussed utility to be used when we know that a field definition either represents + a single type, or a single type in a union with `None`, and we want the single type. + + Args: + field_definition: A field definition instance. + + Returns: + A field definition instance. + """ + if not field_definition.is_optional: + return field_definition.annotation + inner = next((t for t in field_definition.inner_types if not t.is_none_type), None) + if inner is None: + raise ValueError("Field definition has no inner type that is not None") + return inner.annotation + + +def _should_create_enum_schema(field_definition: FieldDefinition) -> bool: + """Predicate to determine if we should create an enum schema for the field def, or not. + + This returns true if the field definition is an enum, or if the field definition is a union + of an enum and ``None``. + + When an annotation is ``SomeEnum | None`` we should create a schema for the enum that includes ``null`` + in the enum values. + + Args: + field_definition: A field definition instance. + + Returns: + A boolean + """ + return field_definition.is_subclass_of(Enum) or ( + field_definition.is_optional + and len(field_definition.args) == 2 + and field_definition.has_inner_subclass_of(Enum) + ) + + +def _should_create_literal_schema(field_definition: FieldDefinition) -> bool: + """Predicate to determine if we should create a literal schema for the field def, or not. + + This returns ``True`` if the field definition is an literal, or if the field definition is a union + of a literal and None. + + When an annotation is `Literal["anything"] | None` we should create a schema for the literal that includes `null` + in the enum values. + + Args: + field_definition: A field definition instance. + + Returns: + A boolean + """ + return ( + field_definition.is_literal + or field_definition.is_optional + and all(inner.is_literal for inner in field_definition.inner_types if not inner.is_none_type) + ) + + +def _get_normalized_schema_key(annotation: Any) -> tuple[str, ...]: + """Create a key for a type annotation. + + The key should be a tuple such as ``("path", "to", "type", "TypeName")``. + + Args: + annotation: a type annotation + + Returns: + A tuple of strings. + """ + module = getattr(annotation, "__module__", "") + name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__ + name = name.replace(".<locals>.", ".") + return *module.split("."), name + + +def get_formatted_examples(field_definition: FieldDefinition, examples: Sequence[Example]) -> Mapping[str, Example]: + """Format the examples into the OpenAPI schema format.""" + + name = field_definition.name or get_name(field_definition.type_) + name = name.lower() + + return {f"{name}-example-{i}": example for i, example in enumerate(examples, 1)} + + +def get_json_schema_formatted_examples(examples: Sequence[Example]) -> list[Any]: + """Format the examples into the JSON schema format.""" + return [example.value for example in examples] diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__init__.py b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7fc0cb2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/converter.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/converter.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5bc5015 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/converter.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/schema_parsing.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/schema_parsing.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f6e0196 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/schema_parsing.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ffe7efb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/converter.py b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/converter.py new file mode 100644 index 0000000..4782dbe --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/converter.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +from copy import copy +from dataclasses import fields +from typing import Any, TypeVar, cast + +from litestar._openapi.typescript_converter.schema_parsing import ( + normalize_typescript_namespace, + parse_schema, +) +from litestar._openapi.typescript_converter.types import ( + TypeScriptInterface, + TypeScriptNamespace, + TypeScriptPrimitive, + TypeScriptProperty, + TypeScriptType, + TypeScriptUnion, +) +from litestar.enums import HttpMethod, ParamType +from litestar.openapi.spec import ( + Components, + OpenAPI, + Operation, + Parameter, + Reference, + RequestBody, + Responses, + Schema, +) + +__all__ = ( + "convert_openapi_to_typescript", + "deref_container", + "get_openapi_type", + "parse_params", + "parse_request_body", + "parse_responses", + "resolve_ref", +) + +from litestar.openapi.spec.base import BaseSchemaObject + +T = TypeVar("T") + + +def _deref_schema_object(value: BaseSchemaObject, components: Components) -> BaseSchemaObject: + for field in fields(value): + if field_value := getattr(value, field.name, None): + if isinstance(field_value, Reference): + setattr( + value, + field.name, + deref_container(resolve_ref(field_value, components=components), components=components), + ) + elif isinstance(field_value, (Schema, dict, list)): + setattr(value, field.name, deref_container(field_value, components=components)) + return value + + +def _deref_dict(value: dict[str, Any], components: Components) -> dict[str, Any]: + for k, v in value.items(): + if isinstance(v, Reference): + value[k] = deref_container(resolve_ref(v, components=components), components=components) + elif isinstance(v, (Schema, dict, list)): + value[k] = deref_container(v, components=components) + return value + + +def _deref_list(values: list[Any], components: Components) -> list[Any]: + for i, value in enumerate(values): + if isinstance(value, Reference): + values[i] = deref_container(resolve_ref(value, components=components), components=components) + elif isinstance(value, (Schema, (dict, list))): + values[i] = deref_container(value, components=components) + return values + + +def deref_container(open_api_container: T, components: Components) -> T: + """Dereference an object that may contain Reference instances. + + Args: + open_api_container: Either an OpenAPI content, a dict or a list. + components: The OpenAPI schema Components section. + + Returns: + A dereferenced object. + """ + if isinstance(open_api_container, BaseSchemaObject): + return cast("T", _deref_schema_object(open_api_container, components)) + + if isinstance(open_api_container, dict): + return cast("T", _deref_dict(copy(open_api_container), components)) + + if isinstance(open_api_container, list): + return cast("T", _deref_list(copy(open_api_container), components)) + raise ValueError(f"unexpected container type {type(open_api_container).__name__}") # pragma: no cover + + +def resolve_ref(ref: Reference, components: Components) -> Schema: + """Resolve a reference object into the actual value it points at. + + Args: + ref: A Reference instance. + components: The OpenAPI schema Components section. + + Returns: + An OpenAPI schema instance. + """ + current: Any = components + for path in [p for p in ref.ref.split("/") if p not in {"#", "components"}]: + current = current[path] if isinstance(current, dict) else getattr(current, path, None) + + if not isinstance(current, Schema): # pragma: no cover + raise ValueError( + f"unexpected value type, expected schema but received {type(current).__name__ if current is not None else 'None'}" + ) + + return current + + +def get_openapi_type(value: Reference | T, components: Components) -> T: + """Extract or dereference an OpenAPI container type. + + Args: + value: Either a reference or a container type. + components: The OpenAPI schema Components section. + + Returns: + The extracted container. + """ + if isinstance(value, Reference): + resolved_ref = resolve_ref(value, components=components) + return cast("T", deref_container(open_api_container=resolved_ref, components=components)) + + return deref_container(open_api_container=value, components=components) + + +def parse_params( + params: list[Parameter], + components: Components, +) -> tuple[TypeScriptInterface, ...]: + """Parse request parameters. + + Args: + params: An OpenAPI Operation parameters. + components: The OpenAPI schema Components section. + + Returns: + A tuple of resolved interfaces. + """ + cookie_params: list[TypeScriptProperty] = [] + header_params: list[TypeScriptProperty] = [] + path_params: list[TypeScriptProperty] = [] + query_params: list[TypeScriptProperty] = [] + + for param in params: + if param.schema: + schema = get_openapi_type(param.schema, components) + ts_prop = TypeScriptProperty( + key=normalize_typescript_namespace(param.name, allow_quoted=True), + required=param.required, + value=parse_schema(schema), + ) + if param.param_in == ParamType.COOKIE: + cookie_params.append(ts_prop) + elif param.param_in == ParamType.HEADER: + header_params.append(ts_prop) + elif param.param_in == ParamType.PATH: + path_params.append(ts_prop) + else: + query_params.append(ts_prop) + + result: list[TypeScriptInterface] = [] + + if cookie_params: + result.append(TypeScriptInterface("CookieParameters", tuple(cookie_params))) + if header_params: + result.append(TypeScriptInterface("HeaderParameters", tuple(header_params))) + if path_params: + result.append(TypeScriptInterface("PathParameters", tuple(path_params))) + if query_params: + result.append(TypeScriptInterface("QueryParameters", tuple(query_params))) + + return tuple(result) + + +def parse_request_body(body: RequestBody, components: Components) -> TypeScriptType: + """Parse the schema request body. + + Args: + body: An OpenAPI RequestBody instance. + components: The OpenAPI schema Components section. + + Returns: + A TypeScript type. + """ + undefined = TypeScriptPrimitive("undefined") + if not body.content: + return TypeScriptType("RequestBody", undefined) + + if content := [get_openapi_type(v.schema, components) for v in body.content.values() if v.schema]: + schema = content[0] + return TypeScriptType( + "RequestBody", + parse_schema(schema) if body.required else TypeScriptUnion((parse_schema(schema), undefined)), + ) + + return TypeScriptType("RequestBody", undefined) + + +def parse_responses(responses: Responses, components: Components) -> tuple[TypeScriptNamespace, ...]: + """Parse a given Operation's Responses object. + + Args: + responses: An OpenAPI Responses object. + components: The OpenAPI schema Components section. + + Returns: + A tuple of namespaces, mapping response codes to data. + """ + result: list[TypeScriptNamespace] = [] + for http_status, response in [ + (status, get_openapi_type(res, components=components)) for status, res in responses.items() + ]: + if response.content and ( + content := [get_openapi_type(v.schema, components) for v in response.content.values() if v.schema] + ): + ts_type = parse_schema(content[0]) + else: + ts_type = TypeScriptPrimitive("undefined") + + containers = [ + TypeScriptType("ResponseBody", ts_type), + TypeScriptInterface( + "ResponseHeaders", + tuple( + TypeScriptProperty( + required=get_openapi_type(header, components=components).required, + key=normalize_typescript_namespace(key, allow_quoted=True), + value=TypeScriptPrimitive("string"), + ) + for key, header in response.headers.items() + ), + ) + if response.headers + else None, + ] + + result.append(TypeScriptNamespace(f"Http{http_status}", tuple(c for c in containers if c))) + + return tuple(result) + + +def convert_openapi_to_typescript(openapi_schema: OpenAPI, namespace: str = "API") -> TypeScriptNamespace: + """Convert an OpenAPI Schema instance to a TypeScript namespace. This function is the main entry point for the + TypeScript converter. + + Args: + openapi_schema: An OpenAPI Schema instance. + namespace: The namespace to use. + + Returns: + A string representing the generated types. + """ + if not openapi_schema.paths: # pragma: no cover + raise ValueError("OpenAPI schema has no paths") + if not openapi_schema.components: # pragma: no cover + raise ValueError("OpenAPI schema has no components") + + operations: list[TypeScriptNamespace] = [] + + for path_item in openapi_schema.paths.values(): + shared_params = [ + get_openapi_type(p, components=openapi_schema.components) for p in (path_item.parameters or []) + ] + for method in HttpMethod: + if ( + operation := cast("Operation | None", getattr(path_item, method.lower(), "None")) + ) and operation.operation_id: + params = parse_params( + [ + *( + get_openapi_type(p, components=openapi_schema.components) + for p in (operation.parameters or []) + ), + *shared_params, + ], + components=openapi_schema.components, + ) + request_body = ( + parse_request_body( + get_openapi_type(operation.request_body, components=openapi_schema.components), + components=openapi_schema.components, + ) + if operation.request_body + else None + ) + + responses = parse_responses(operation.responses or {}, components=openapi_schema.components) + + operations.append( + TypeScriptNamespace( + normalize_typescript_namespace(operation.operation_id, allow_quoted=False), + tuple(container for container in (*params, request_body, *responses) if container), + ) + ) + + return TypeScriptNamespace(namespace, tuple(operations)) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/schema_parsing.py b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/schema_parsing.py new file mode 100644 index 0000000..c5cbbd0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/schema_parsing.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any, Literal, overload + +from litestar._openapi.typescript_converter.types import ( + TypeScriptAnonymousInterface, + TypeScriptArray, + TypeScriptElement, + TypeScriptInterface, + TypeScriptIntersection, + TypeScriptLiteral, + TypeScriptPrimitive, + TypeScriptProperty, + TypeScriptUnion, +) +from litestar.openapi.spec import Schema +from litestar.openapi.spec.enums import OpenAPIType + +__all__ = ("create_interface", "is_schema_value", "normalize_typescript_namespace", "parse_schema", "parse_type_schema") + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + +openapi_typescript_equivalent_types = Literal[ + "string", "boolean", "number", "null", "Record<string, unknown>", "unknown[]" +] + +openapi_to_typescript_type_map: dict[OpenAPIType, openapi_typescript_equivalent_types] = { + OpenAPIType.ARRAY: "unknown[]", + OpenAPIType.BOOLEAN: "boolean", + OpenAPIType.INTEGER: "number", + OpenAPIType.NULL: "null", + OpenAPIType.NUMBER: "number", + OpenAPIType.OBJECT: "Record<string, unknown>", + OpenAPIType.STRING: "string", +} + +invalid_namespace_re = re.compile(r"[^\w+_$]*") +allowed_key_re = re.compile(r"[\w+_$]*") + + +def normalize_typescript_namespace(value: str, allow_quoted: bool) -> str: + """Normalize a namespace, e.g. variable name, or object key, to values supported by TS. + + Args: + value: A string to normalize. + allow_quoted: Whether to allow quoting the value. + + Returns: + A normalized value + """ + if not allow_quoted and not value[0].isalpha() and value[0] not in {"_", "$"}: + raise ValueError(f"invalid typescript namespace {value}") + if allow_quoted: + return value if allowed_key_re.fullmatch(value) else f'"{value}"' + return invalid_namespace_re.sub("", value) + + +def is_schema_value(value: Any) -> TypeGuard[Schema]: + """Typeguard for a schema value. + + Args: + value: An arbitrary value + + Returns: + A typeguard boolean dictating whether the passed in value is a Schema. + """ + return isinstance(value, Schema) + + +@overload +def create_interface(properties: dict[str, Schema], required: set[str] | None) -> TypeScriptAnonymousInterface: ... + + +@overload +def create_interface(properties: dict[str, Schema], required: set[str] | None, name: str) -> TypeScriptInterface: ... + + +def create_interface( + properties: dict[str, Schema], required: set[str] | None = None, name: str | None = None +) -> TypeScriptAnonymousInterface | TypeScriptInterface: + """Create a typescript interface from the given schema.properties values. + + Args: + properties: schema.properties mapping. + required: An optional list of required properties. + name: An optional string representing the interface name. + + Returns: + A typescript interface or anonymous interface. + """ + parsed_properties = tuple( + TypeScriptProperty( + key=normalize_typescript_namespace(key, allow_quoted=True), + value=parse_schema(schema), + required=key in required if required is not None else True, + ) + for key, schema in properties.items() + ) + return ( + TypeScriptInterface(name=name, properties=parsed_properties) + if name is not None + else TypeScriptAnonymousInterface(properties=parsed_properties) + ) + + +def parse_type_schema(schema: Schema) -> TypeScriptPrimitive | TypeScriptLiteral | TypeScriptUnion: + """Parse an OpenAPI schema representing a primitive type(s). + + Args: + schema: An OpenAPI schema. + + Returns: + A typescript type. + """ + if schema.enum: + return TypeScriptUnion(types=tuple(TypeScriptLiteral(value=value) for value in schema.enum)) + if schema.const: + return TypeScriptLiteral(value=schema.const) + if isinstance(schema.type, list): + return TypeScriptUnion( + tuple(TypeScriptPrimitive(openapi_to_typescript_type_map[s_type]) for s_type in schema.type) + ) + if schema.type in openapi_to_typescript_type_map and isinstance(schema.type, OpenAPIType): + return TypeScriptPrimitive(openapi_to_typescript_type_map[schema.type]) + raise TypeError(f"received an unexpected openapi type: {schema.type}") # pragma: no cover + + +def parse_schema(schema: Schema) -> TypeScriptElement: + """Parse an OpenAPI schema object recursively to create typescript types. + + Args: + schema: An OpenAPI Schema object. + + Returns: + A typescript type. + """ + if schema.all_of: + return TypeScriptIntersection(tuple(parse_schema(s) for s in schema.all_of if is_schema_value(s))) + if schema.one_of: + return TypeScriptUnion(tuple(parse_schema(s) for s in schema.one_of if is_schema_value(s))) + if is_schema_value(schema.items): + return TypeScriptArray(parse_schema(schema.items)) + if schema.type == OpenAPIType.OBJECT: + return create_interface( + properties={k: v for k, v in schema.properties.items() if is_schema_value(v)} if schema.properties else {}, + required=set(schema.required) if schema.required else None, + ) + return parse_type_schema(schema=schema) diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/types.py b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/types.py new file mode 100644 index 0000000..ff265d4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/typescript_converter/types.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Literal + +__all__ = ( + "TypeScriptAnonymousInterface", + "TypeScriptArray", + "TypeScriptConst", + "TypeScriptContainer", + "TypeScriptElement", + "TypeScriptEnum", + "TypeScriptInterface", + "TypeScriptIntersection", + "TypeScriptLiteral", + "TypeScriptNamespace", + "TypeScriptPrimitive", + "TypeScriptProperty", + "TypeScriptType", + "TypeScriptUnion", +) + + +def _as_string(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + + if isinstance(value, bool): + return "true" if value else "false" + + return "null" if value is None else str(value) + + +class TypeScriptElement(ABC): + """A class representing a TypeScript type element.""" + + @abstractmethod + def write(self) -> str: + """Write a typescript value corresponding to the given typescript element. + + Returns: + A typescript string + """ + raise NotImplementedError("") + + +class TypeScriptContainer(TypeScriptElement): + """A class representing a TypeScript type container.""" + + name: str + + @abstractmethod + def write(self) -> str: + """Write a typescript value corresponding to the given typescript container. + + Returns: + A typescript string + """ + raise NotImplementedError("") + + +@dataclass(unsafe_hash=True) +class TypeScriptIntersection(TypeScriptElement): + """A class representing a TypeScript intersection type.""" + + types: tuple[TypeScriptElement, ...] + + def write(self) -> str: + """Write a typescript intersection value. + + Example: + { prop: string } & { another: number } + + Returns: + A typescript string + """ + return " & ".join(t.write() for t in self.types) + + +@dataclass(unsafe_hash=True) +class TypeScriptUnion(TypeScriptElement): + """A class representing a TypeScript union type.""" + + types: tuple[TypeScriptElement, ...] + + def write(self) -> str: + """Write a typescript union value. + + Example: + string | number + + Returns: + A typescript string + """ + return " | ".join(sorted(t.write() for t in self.types)) + + +@dataclass(unsafe_hash=True) +class TypeScriptPrimitive(TypeScriptElement): + """A class representing a TypeScript primitive type.""" + + type: Literal[ + "string", "number", "boolean", "any", "null", "undefined", "symbol", "Record<string, unknown>", "unknown[]" + ] + + def write(self) -> str: + """Write a typescript primitive type. + + Example: + null + + Returns: + A typescript string + """ + return self.type + + +@dataclass(unsafe_hash=True) +class TypeScriptLiteral(TypeScriptElement): + """A class representing a TypeScript literal type.""" + + value: str | int | float | bool | None + + def write(self) -> str: + """Write a typescript literal type. + + Example: + "someValue" + + Returns: + A typescript string + """ + return _as_string(self.value) + + +@dataclass(unsafe_hash=True) +class TypeScriptArray(TypeScriptElement): + """A class representing a TypeScript array type.""" + + item_type: TypeScriptElement + + def write(self) -> str: + """Write a typescript array type. + + Example: + number[] + + Returns: + A typescript string + """ + value = ( + f"({self.item_type.write()})" + if isinstance(self.item_type, (TypeScriptUnion, TypeScriptIntersection)) + else self.item_type.write() + ) + return f"{value}[]" + + +@dataclass(unsafe_hash=True) +class TypeScriptProperty(TypeScriptElement): + """A class representing a TypeScript interface property.""" + + required: bool + key: str + value: TypeScriptElement + + def write(self) -> str: + """Write a typescript property. This class is used exclusively inside interfaces. + + Example: + key: string; + optional?: number; + + Returns: + A typescript string + """ + return f"{self.key}{':' if self.required else '?:'} {self.value.write()};" + + +@dataclass(unsafe_hash=True) +class TypeScriptAnonymousInterface(TypeScriptElement): + """A class representing a TypeScript anonymous interface.""" + + properties: tuple[TypeScriptProperty, ...] + + def write(self) -> str: + """Write a typescript interface object, without a name. + + Example: + { + key: string; + optional?: number; + } + + Returns: + A typescript string + """ + props = "\t" + "\n\t".join([prop.write() for prop in sorted(self.properties, key=lambda prop: prop.key)]) + return f"{{\n{props}\n}}" + + +@dataclass(unsafe_hash=True) +class TypeScriptInterface(TypeScriptContainer): + """A class representing a TypeScript interface.""" + + name: str + properties: tuple[TypeScriptProperty, ...] + + def write(self) -> str: + """Write a typescript interface. + + Example: + export interface MyInterface { + key: string; + optional?: number; + }; + + Returns: + A typescript string + """ + interface = TypeScriptAnonymousInterface(properties=self.properties) + return f"export interface {self.name} {interface.write()};" + + +@dataclass(unsafe_hash=True) +class TypeScriptEnum(TypeScriptContainer): + """A class representing a TypeScript enum.""" + + name: str + values: tuple[tuple[str, str], ...] | tuple[tuple[str, int | float], ...] + + def write(self) -> str: + """Write a typescript enum. + + Example: + export enum MyEnum { + DOG = "canine", + CAT = "feline", + }; + + Returns: + A typescript string + """ + members = "\t" + "\n\t".join( + [f"{key} = {_as_string(value)}," for key, value in sorted(self.values, key=lambda member: member[0])] + ) + return f"export enum {self.name} {{\n{members}\n}};" + + +@dataclass(unsafe_hash=True) +class TypeScriptType(TypeScriptContainer): + """A class representing a TypeScript type.""" + + name: str + value: TypeScriptElement + + def write(self) -> str: + """Write a typescript type. + + Example: + export type MyType = number | "42"; + + Returns: + A typescript string + """ + return f"export type {self.name} = {self.value.write()};" + + +@dataclass(unsafe_hash=True) +class TypeScriptConst(TypeScriptContainer): + """A class representing a TypeScript const.""" + + name: str + value: TypeScriptPrimitive | TypeScriptLiteral + + def write(self) -> str: + """Write a typescript const. + + Example: + export const MyConst: number; + + Returns: + A typescript string + """ + return f"export const {self.name}: {self.value.write()};" + + +@dataclass(unsafe_hash=True) +class TypeScriptNamespace(TypeScriptContainer): + """A class representing a TypeScript namespace.""" + + name: str + values: tuple[TypeScriptContainer, ...] + + def write(self) -> str: + """Write a typescript namespace. + + Example: + export MyNamespace { + export const MyConst: number; + } + + Returns: + A typescript string + """ + members = "\t" + "\n\n\t".join([value.write() for value in sorted(self.values, key=lambda el: el.name)]) + return f"export namespace {self.name} {{\n{members}\n}};" diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/utils.py b/venv/lib/python3.11/site-packages/litestar/_openapi/utils.py new file mode 100644 index 0000000..b1950fa --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/utils.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from litestar.types.internal_types import PathParameterDefinition + +if TYPE_CHECKING: + from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.types import Method + + +__all__ = ("default_operation_id_creator", "SEPARATORS_CLEANUP_PATTERN") + +SEPARATORS_CLEANUP_PATTERN = re.compile(r"[!#$%&'*+\-.^_`|~:]+") + + +def default_operation_id_creator( + route_handler: HTTPRouteHandler, + http_method: Method, + path_components: list[str | PathParameterDefinition], +) -> str: + """Create a unique 'operationId' for an OpenAPI PathItem entry. + + Args: + route_handler: The HTTP Route Handler instance. + http_method: The HTTP method for the given PathItem. + path_components: A list of path components. + + Returns: + A camelCased operationId created from the handler function name, + http method and path components. + """ + + handler_namespace = ( + http_method.title() + route_handler.handler_name.title() + if len(route_handler.http_methods) > 1 + else route_handler.handler_name.title() + ) + + components_namespace = "" + for component in (c.name if isinstance(c, PathParameterDefinition) else c for c in path_components): + if component.title() not in components_namespace: + components_namespace += component.title() + + return SEPARATORS_CLEANUP_PATTERN.sub("", components_namespace + handler_namespace) diff --git a/venv/lib/python3.11/site-packages/litestar/_parsers.py b/venv/lib/python3.11/site-packages/litestar/_parsers.py new file mode 100644 index 0000000..9b9f459 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_parsers.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +from http.cookies import _unquote as unquote_cookie +from urllib.parse import unquote + +try: + from fast_query_parsers import parse_query_string as parse_qsl +except ImportError: + from urllib.parse import parse_qsl as _parse_qsl + + def parse_qsl(qs: bytes, separator: str) -> list[tuple[str, str]]: + return _parse_qsl(qs.decode("latin-1"), keep_blank_values=True, separator=separator) + + +__all__ = ("parse_cookie_string", "parse_query_string", "parse_url_encoded_form_data") + + +@lru_cache(1024) +def parse_url_encoded_form_data(encoded_data: bytes) -> dict[str, str | list[str]]: + """Parse an url encoded form data dict. + + Args: + encoded_data: The encoded byte string. + + Returns: + A parsed dict. + """ + decoded_dict: defaultdict[str, list[str]] = defaultdict(list) + for k, v in parse_qsl(encoded_data, separator="&"): + decoded_dict[k].append(v) + return {k: v if len(v) > 1 else v[0] for k, v in decoded_dict.items()} + + +@lru_cache(1024) +def parse_query_string(query_string: bytes) -> tuple[tuple[str, str], ...]: + """Parse a query string into a tuple of key value pairs. + + Args: + query_string: A query string. + + Returns: + A tuple of key value pairs. + """ + return tuple(parse_qsl(query_string, separator="&")) + + +@lru_cache(1024) +def parse_cookie_string(cookie_string: str) -> dict[str, str]: + """Parse a cookie string into a dictionary of values. + + Args: + cookie_string: A cookie string. + + Returns: + A string keyed dictionary of values + """ + cookies = [cookie.split("=", 1) if "=" in cookie else ("", cookie) for cookie in cookie_string.split(";")] + output: dict[str, str] = { + k: unquote(unquote_cookie(v)) + for k, v in filter( + lambda x: x[0] or x[1], + ((k.strip(), v.strip()) for k, v in cookies), + ) + } + return output diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/__init__.py b/venv/lib/python3.11/site-packages/litestar/_signature/__init__.py new file mode 100644 index 0000000..418e3b5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/__init__.py @@ -0,0 +1,3 @@ +from .model import SignatureModel + +__all__ = ("SignatureModel",) diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8282668 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/model.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/model.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9b8fbb7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/model.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5abb937 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b8fcafb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/model.py b/venv/lib/python3.11/site-packages/litestar/_signature/model.py new file mode 100644 index 0000000..42c7994 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/model.py @@ -0,0 +1,316 @@ +# ruff: noqa: UP006, UP007 +from __future__ import annotations + +import re +from functools import partial +from pathlib import Path, PurePath +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Literal, + Optional, + Sequence, + Set, + Type, + TypedDict, + Union, + cast, +) +from uuid import UUID + +from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct +from msgspec.structs import asdict +from typing_extensions import Annotated + +from litestar._signature.types import ExtendedMsgSpecValidationError +from litestar._signature.utils import ( + _get_decoder_for_type, + _normalize_annotation, + _validate_signature_dependencies, +) +from litestar.datastructures.state import ImmutableState +from litestar.datastructures.url import URL +from litestar.dto import AbstractDTO, DTOData +from litestar.enums import ParamType, ScopeType +from litestar.exceptions import InternalServerException, ValidationException +from litestar.params import KwargDefinition, ParameterKwarg +from litestar.typing import FieldDefinition # noqa +from litestar.utils import is_class_and_subclass +from litestar.utils.dataclass import simple_asdict + +if TYPE_CHECKING: + from typing_extensions import NotRequired + + from litestar.connection import ASGIConnection + from litestar.types import AnyCallable, TypeDecodersSequence + from litestar.utils.signature import ParsedSignature + +__all__ = ( + "ErrorMessage", + "SignatureModel", +) + + +class ErrorMessage(TypedDict): + # key may not be set in some cases, like when a query param is set but + # doesn't match the required length during `attrs` validation + # in this case, we don't show a key at all as it will be empty + key: NotRequired[str] + message: str + source: NotRequired[Literal["body"] | ParamType] + + +MSGSPEC_CONSTRAINT_FIELDS = ( + "gt", + "ge", + "lt", + "le", + "multiple_of", + "pattern", + "min_length", + "max_length", +) + +ERR_RE = re.compile(r"`\$\.(.+)`$") + +DEFAULT_TYPE_DECODERS = [ + (lambda x: is_class_and_subclass(x, (Path, PurePath, ImmutableState, UUID)), lambda t, v: t(v)), +] + + +def _deserializer(target_type: Any, value: Any, default_deserializer: Callable[[Any, Any], Any]) -> Any: + if isinstance(value, DTOData): + return value + + if isinstance(value, target_type): + return value + + if decoder := getattr(target_type, "_decoder", None): + return decoder(target_type, value) + + return default_deserializer(target_type, value) + + +class SignatureModel(Struct): + """Model that represents a function signature that uses a msgspec specific type or types.""" + + _data_dto: ClassVar[Optional[Type[AbstractDTO]]] + _dependency_name_set: ClassVar[Set[str]] + _fields: ClassVar[Dict[str, FieldDefinition]] + _return_annotation: ClassVar[Any] + + @classmethod + def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception: + """Create an exception class - either a ValidationException or an InternalServerException, depending on whether + the failure is in client provided values or injected dependencies. + + Args: + connection: An ASGI connection instance. + messages: A list of error messages. + + Returns: + An Exception + """ + method = connection.method if hasattr(connection, "method") else ScopeType.WEBSOCKET # pyright: ignore + if client_errors := [ + err_message + for err_message in messages + if ("key" in err_message and err_message["key"] not in cls._dependency_name_set) or "key" not in err_message + ]: + path = URL.from_components( + path=connection.url.path, + query=connection.url.query, + ) + return ValidationException(detail=f"Validation failed for {method} {path}", extra=client_errors) + return InternalServerException() + + @classmethod + def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage: + """Build an error message. + + Args: + keys: A list of keys. + exc_msg: A message. + connection: An ASGI connection instance. + + Returns: + An ErrorMessage + """ + + message: ErrorMessage = {"message": exc_msg.split(" - ")[0]} + + if keys: + message["key"] = key = ".".join(keys) + if keys[0].startswith("data"): + message["key"] = message["key"].replace("data.", "") + message["source"] = "body" + elif key in connection.query_params: + message["source"] = ParamType.QUERY + + elif key in cls._fields and isinstance(cls._fields[key].kwarg_definition, ParameterKwarg): + if cast(ParameterKwarg, cls._fields[key].kwarg_definition).cookie: + message["source"] = ParamType.COOKIE + elif cast(ParameterKwarg, cls._fields[key].kwarg_definition).header: + message["source"] = ParamType.HEADER + else: + message["source"] = ParamType.QUERY + + return message + + @classmethod + def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any) -> list[tuple[str, Exception]]: + exceptions: list[tuple[str, Exception]] = [] + for field_name in cls._fields: + try: + raw_value = kwargs[field_name] + annotation = cls.__annotations__[field_name] + convert(raw_value, type=annotation, strict=False, dec_hook=deserializer, str_keys=True) + except Exception as e: # noqa: BLE001 + exceptions.append((field_name, e)) + + return exceptions + + @classmethod + def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: + """Extract values from the connection instance and return a dict of parsed values. + + Args: + connection: The ASGI connection instance. + **kwargs: A dictionary of kwargs. + + Raises: + ValidationException: If validation failed. + InternalServerException: If another exception has been raised. + + Returns: + A dictionary of parsed values + """ + messages: list[ErrorMessage] = [] + deserializer = partial(_deserializer, default_deserializer=connection.route_handler.default_deserializer) + try: + return convert(kwargs, cls, strict=False, dec_hook=deserializer, str_keys=True).to_dict() + except ExtendedMsgSpecValidationError as e: + for exc in e.errors: + keys = [str(loc) for loc in exc["loc"]] + message = cls._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection) + messages.append(message) + raise cls._create_exception(messages=messages, connection=connection) from e + except ValidationError as e: + for field_name, exc in cls._collect_errors(deserializer=deserializer, **kwargs): # type: ignore[assignment] + match = ERR_RE.search(str(exc)) + keys = [field_name, str(match.group(1))] if match else [field_name] + message = cls._build_error_message(keys=keys, exc_msg=str(exc), connection=connection) + messages.append(message) + raise cls._create_exception(messages=messages, connection=connection) from e + + def to_dict(self) -> dict[str, Any]: + """Normalize access to the signature model's dictionary method, because different backends use different methods + for this. + + Returns: A dictionary of string keyed values. + """ + return asdict(self) + + @classmethod + def create( + cls, + dependency_name_set: set[str], + fn: AnyCallable, + parsed_signature: ParsedSignature, + type_decoders: TypeDecodersSequence, + data_dto: type[AbstractDTO] | None = None, + ) -> type[SignatureModel]: + fn_name = ( + fn_name if (fn_name := getattr(fn, "__name__", "anonymous")) and fn_name != "<lambda>" else "anonymous" + ) + + dependency_names = _validate_signature_dependencies( + dependency_name_set=dependency_name_set, fn_name=fn_name, parsed_signature=parsed_signature + ) + + struct_fields: list[tuple[str, Any, Any]] = [] + + for field_definition in parsed_signature.parameters.values(): + meta_data: Meta | None = None + + if isinstance(field_definition.kwarg_definition, KwargDefinition): + meta_kwargs: dict[str, Any] = {"extra": {}} + + kwarg_definition = simple_asdict(field_definition.kwarg_definition, exclude_empty=True) + if min_items := kwarg_definition.pop("min_items", None): + meta_kwargs["min_length"] = min_items + if max_items := kwarg_definition.pop("max_items", None): + meta_kwargs["max_length"] = max_items + + for k, v in kwarg_definition.items(): + if hasattr(Meta, k) and v is not None: + meta_kwargs[k] = v + else: + meta_kwargs["extra"][k] = v + + meta_data = Meta(**meta_kwargs) + + annotation = cls._create_annotation( + field_definition=field_definition, + type_decoders=[*(type_decoders or []), *DEFAULT_TYPE_DECODERS], + meta_data=meta_data, + data_dto=data_dto, + ) + + default = field_definition.default if field_definition.has_default else NODEFAULT + struct_fields.append((field_definition.name, annotation, default)) + + return defstruct( # type:ignore[return-value] + f"{fn_name}_signature_model", + struct_fields, + bases=(cls,), + module=getattr(fn, "__module__", None), + namespace={ + "_return_annotation": parsed_signature.return_type.annotation, + "_dependency_name_set": dependency_names, + "_fields": parsed_signature.parameters, + "_data_dto": data_dto, + }, + kw_only=True, + ) + + @classmethod + def _create_annotation( + cls, + field_definition: FieldDefinition, + type_decoders: TypeDecodersSequence, + meta_data: Meta | None = None, + data_dto: type[AbstractDTO] | None = None, + ) -> Any: + # DTOs have already validated their data, so we can just use Any here + if field_definition.name == "data" and data_dto: + return Any + + annotation = _normalize_annotation(field_definition=field_definition) + + if annotation is Any: + return annotation + + if field_definition.is_union: + types = [ + cls._create_annotation( + field_definition=inner_type, + type_decoders=type_decoders, + meta_data=meta_data, + ) + for inner_type in field_definition.inner_types + if not inner_type.is_none_type + ] + return Optional[Union[tuple(types)]] if field_definition.is_optional else Union[tuple(types)] # pyright: ignore + + if decoder := _get_decoder_for_type(annotation, type_decoders=type_decoders): + # FIXME: temporary (hopefully) hack, see: https://github.com/jcrist/msgspec/issues/497 + setattr(annotation, "_decoder", decoder) + + if meta_data: + annotation = Annotated[annotation, meta_data] # pyright: ignore + + return annotation diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/types.py b/venv/lib/python3.11/site-packages/litestar/_signature/types.py new file mode 100644 index 0000000..ac174cc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/types.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import Any + +from msgspec import ValidationError + + +class ExtendedMsgSpecValidationError(ValidationError): + def __init__(self, errors: list[dict[str, Any]]) -> None: + self.errors = errors + super().__init__(errors) diff --git a/venv/lib/python3.11/site-packages/litestar/_signature/utils.py b/venv/lib/python3.11/site-packages/litestar/_signature/utils.py new file mode 100644 index 0000000..8c0d15f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_signature/utils.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from litestar.constants import SKIP_VALIDATION_NAMES +from litestar.exceptions import ImproperlyConfiguredException +from litestar.params import DependencyKwarg +from litestar.types import Empty, TypeDecodersSequence + +if TYPE_CHECKING: + from litestar.typing import FieldDefinition + from litestar.utils.signature import ParsedSignature + + +__all__ = ("_validate_signature_dependencies", "_normalize_annotation", "_get_decoder_for_type") + + +def _validate_signature_dependencies( + dependency_name_set: set[str], fn_name: str, parsed_signature: ParsedSignature +) -> set[str]: + """Validate dependencies of ``parsed_signature``. + + Args: + dependency_name_set: A set of dependency names + fn_name: A callable's name. + parsed_signature: A parsed signature. + + Returns: + A set of validated dependency names. + """ + dependency_names: set[str] = set(dependency_name_set) + + for parameter in parsed_signature.parameters.values(): + if isinstance(parameter.kwarg_definition, DependencyKwarg) and parameter.name not in dependency_name_set: + if not parameter.is_optional and parameter.default is Empty: + raise ImproperlyConfiguredException( + f"Explicit dependency '{parameter.name}' for '{fn_name}' has no default value, " + f"or provided dependency." + ) + dependency_names.add(parameter.name) + return dependency_names + + +def _normalize_annotation(field_definition: FieldDefinition) -> Any: + if field_definition.name in SKIP_VALIDATION_NAMES or ( + isinstance(field_definition.kwarg_definition, DependencyKwarg) + and field_definition.kwarg_definition.skip_validation + ): + return Any + + return field_definition.annotation + + +def _get_decoder_for_type(target_type: Any, type_decoders: TypeDecodersSequence) -> Callable[[type, Any], Any] | None: + return next( + (decoder for predicate, decoder in type_decoders if predicate(target_type)), + None, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/app.py b/venv/lib/python3.11/site-packages/litestar/app.py new file mode 100644 index 0000000..e1bd989 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/app.py @@ -0,0 +1,880 @@ +from __future__ import annotations + +import inspect +import logging +import os +from contextlib import ( + AbstractAsyncContextManager, + AsyncExitStack, + asynccontextmanager, + suppress, +) +from datetime import date, datetime, time, timedelta +from functools import partial +from itertools import chain +from pathlib import Path +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable, Mapping, Sequence, TypedDict, cast + +from litestar._asgi import ASGIRouter +from litestar._asgi.utils import get_route_handlers, wrap_in_exception_handler +from litestar._openapi.plugin import OpenAPIPlugin +from litestar._openapi.schema_generation import openapi_schema_plugins +from litestar.config.allowed_hosts import AllowedHostsConfig +from litestar.config.app import AppConfig +from litestar.config.response_cache import ResponseCacheConfig +from litestar.connection import Request, WebSocket +from litestar.datastructures.state import State +from litestar.events.emitter import BaseEventEmitterBackend, SimpleEventEmitter +from litestar.exceptions import ( + MissingDependencyException, + NoRouteMatchFoundException, +) +from litestar.logging.config import LoggingConfig, get_logger_placeholder +from litestar.middleware.cors import CORSMiddleware +from litestar.openapi.config import OpenAPIConfig +from litestar.plugins import ( + CLIPluginProtocol, + InitPluginProtocol, + OpenAPISchemaPluginProtocol, + PluginProtocol, + PluginRegistry, + SerializationPluginProtocol, +) +from litestar.plugins.base import CLIPlugin +from litestar.router import Router +from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute +from litestar.static_files.base import StaticFiles +from litestar.stores.registry import StoreRegistry +from litestar.types import Empty, TypeDecodersSequence +from litestar.types.internal_types import PathParameterDefinition, TemplateConfigType +from litestar.utils import deprecated, ensure_async_callable, join_paths, unique +from litestar.utils.dataclass import extract_dataclass_items +from litestar.utils.predicates import is_async_callable +from litestar.utils.warnings import warn_pdb_on_exception + +if TYPE_CHECKING: + from typing_extensions import Self + + from litestar.config.app import ExperimentalFeatures + from litestar.config.compression import CompressionConfig + from litestar.config.cors import CORSConfig + from litestar.config.csrf import CSRFConfig + from litestar.datastructures import CacheControlHeader, ETag + from litestar.dto import AbstractDTO + from litestar.events.listener import EventListener + from litestar.logging.config import BaseLoggingConfig + from litestar.openapi.spec import SecurityRequirement + from litestar.openapi.spec.open_api import OpenAPI + from litestar.response import Response + from litestar.static_files.config import StaticFilesConfig + from litestar.stores.base import Store + from litestar.types import ( + AfterExceptionHookHandler, + AfterRequestHookHandler, + AfterResponseHookHandler, + AnyCallable, + ASGIApp, + BeforeMessageSendHookHandler, + BeforeRequestHookHandler, + ControllerRouterHandler, + Dependencies, + EmptyType, + ExceptionHandlersMap, + GetLogger, + Guard, + LifeSpanReceive, + LifeSpanScope, + LifeSpanSend, + Logger, + Message, + Middleware, + OnAppInitHandler, + ParametersMap, + Receive, + ResponseCookies, + ResponseHeaders, + RouteHandlerType, + Scope, + Send, + TypeEncodersMap, + ) + from litestar.types.callable_types import LifespanHook + + +__all__ = ("HandlerIndex", "Litestar", "DEFAULT_OPENAPI_CONFIG") + +DEFAULT_OPENAPI_CONFIG = OpenAPIConfig(title="Litestar API", version="1.0.0") +"""The default OpenAPI config used if not configuration is explicitly passed to the +:class:`Litestar <.app.Litestar>` instance constructor. +""" + + +class HandlerIndex(TypedDict): + """Map route handler names to a mapping of paths + route handler. + + It's returned from the 'get_handler_index_by_name' utility method. + """ + + paths: list[str] + """Full route paths to the route handler.""" + handler: RouteHandlerType + """Route handler instance.""" + identifier: str + """Unique identifier of the handler. + + Either equal to :attr`__name__ <obj.__name__>` attribute or ``__str__`` value of the handler. + """ + + +class Litestar(Router): + """The Litestar application. + + ``Litestar`` is the root level of the app - it has the base path of ``/`` and all root level Controllers, Routers + and Route Handlers should be registered on it. + """ + + __slots__ = ( + "_lifespan_managers", + "_server_lifespan_managers", + "_debug", + "_openapi_schema", + "_static_files_config", + "plugins", + "after_exception", + "allowed_hosts", + "asgi_handler", + "asgi_router", + "before_send", + "compression_config", + "cors_config", + "csrf_config", + "event_emitter", + "get_logger", + "include_in_schema", + "logger", + "logging_config", + "multipart_form_part_limit", + "on_shutdown", + "on_startup", + "openapi_config", + "request_class", + "response_cache_config", + "route_map", + "signature_namespace", + "state", + "stores", + "template_engine", + "websocket_class", + "pdb_on_exception", + "experimental_features", + ) + + def __init__( + self, + route_handlers: Sequence[ControllerRouterHandler] | None = None, + *, + after_exception: Sequence[AfterExceptionHookHandler] | None = None, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + allowed_hosts: Sequence[str] | AllowedHostsConfig | None = None, + before_request: BeforeRequestHookHandler | None = None, + before_send: Sequence[BeforeMessageSendHookHandler] | None = None, + cache_control: CacheControlHeader | None = None, + compression_config: CompressionConfig | None = None, + cors_config: CORSConfig | None = None, + csrf_config: CSRFConfig | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + debug: bool | None = None, + dependencies: Dependencies | None = None, + etag: ETag | None = None, + event_emitter_backend: type[BaseEventEmitterBackend] = SimpleEventEmitter, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + listeners: Sequence[EventListener] | None = None, + logging_config: BaseLoggingConfig | EmptyType | None = Empty, + middleware: Sequence[Middleware] | None = None, + multipart_form_part_limit: int = 1000, + on_app_init: Sequence[OnAppInitHandler] | None = None, + on_shutdown: Sequence[LifespanHook] | None = None, + on_startup: Sequence[LifespanHook] | None = None, + openapi_config: OpenAPIConfig | None = DEFAULT_OPENAPI_CONFIG, + opt: Mapping[str, Any] | None = None, + parameters: ParametersMap | None = None, + plugins: Sequence[PluginProtocol] | None = None, + request_class: type[Request] | None = None, + response_cache_config: ResponseCacheConfig | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + security: Sequence[SecurityRequirement] | None = None, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + state: State | None = None, + static_files_config: Sequence[StaticFilesConfig] | None = None, + stores: StoreRegistry | dict[str, Store] | None = None, + tags: Sequence[str] | None = None, + template_config: TemplateConfigType | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + lifespan: Sequence[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] + | None = None, + pdb_on_exception: bool | None = None, + experimental_features: Iterable[ExperimentalFeatures] | None = None, + ) -> None: + """Initialize a ``Litestar`` application. + + Args: + after_exception: A sequence of :class:`exception hook handlers <.types.AfterExceptionHookHandler>`. This + hook is called after an exception occurs. In difference to exception handlers, it is not meant to + return a response - only to process the exception (e.g. log it, send it to Sentry etc.). + after_request: A sync or async function executed after the route handler function returned and the response + object has been resolved. Receives the response object. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + allowed_hosts: A sequence of allowed hosts, or an + :class:`AllowedHostsConfig <.config.allowed_hosts.AllowedHostsConfig>` instance. Enables the builtin + allowed hosts middleware. + before_request: A sync or async function called immediately before calling the route handler. Receives the + :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + before_send: A sequence of :class:`before send hook handlers <.types.BeforeMessageSendHookHandler>`. Called + when the ASGI send function is called. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <litestar.datastructures.CacheControlHeader>` to add to route handlers of + this app. Can be overridden by route handlers. + compression_config: Configures compression behaviour of the application, this enabled a builtin or user + defined Compression middleware. + cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. + debug: If ``True``, app errors rendered as HTML with a stack trace. + dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + Can be overridden by route handlers. + event_emitter_backend: A subclass of + :class:`BaseEventEmitterBackend <.events.emitter.BaseEventEmitterBackend>`. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + lifespan: A list of callables returning async context managers, wrapping the lifespan of the ASGI application + listeners: A sequence of :class:`EventListener <.events.listener.EventListener>`. + logging_config: A subclass of :class:`BaseLoggingConfig <.logging.config.BaseLoggingConfig>`. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request. This limit + is intended to protect from DoS attacks. + on_app_init: A sequence of :class:`OnAppInitHandler <.types.OnAppInitHandler>` instances. Handlers receive + an instance of :class:`AppConfig <.config.app.AppConfig>` that will have been initially populated with + the parameters passed to :class:`Litestar <litestar.app.Litestar>`, and must return an instance of same. + If more than one handler is registered they are called in the order they are provided. + on_shutdown: A sequence of :class:`LifespanHook <.types.LifespanHook>` called during application + shutdown. + on_startup: A sequence of :class:`LifespanHook <litestar.types.LifespanHook>` called during + application startup. + openapi_config: Defaults to :attr:`DEFAULT_OPENAPI_CONFIG` + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <litestar.connection.request.Request>` or + :class:`ASGI Scope <.types.Scope>`. + parameters: A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application + paths. + pdb_on_exception: Drop into the PDB when an exception occurs. + plugins: Sequence of plugins. + request_class: An optional subclass of :class:`Request <.connection.Request>` to use for http connections. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the app's default + response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>`. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + response_cache_config: Configures caching behavior of the application. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + route_handlers: A sequence of route handlers, which can include instances of + :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any + callable decorated by the route handler decorators. + security: A sequence of dicts that will be added to the schema of all route handlers in the application. + See + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for details. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + state: An optional :class:`State <.datastructures.State>` for application state. + static_files_config: A sequence of :class:`StaticFilesConfig <.static_files.StaticFilesConfig>` + stores: Central registry of :class:`Store <.stores.base.Store>` that will be available throughout the + application. If this is a dictionary to it will be passed to a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`. If it is a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`, this instance will be used directly. + tags: A sequence of string tags that will be appended to the schema of all route handlers under the + application. + template_config: An instance of :class:`TemplateConfig <.template.TemplateConfig>` + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket + connections. + experimental_features: An iterable of experimental features to enable + """ + + if logging_config is Empty: + logging_config = LoggingConfig() + + if debug is None: + debug = os.getenv("LITESTAR_DEBUG", "0") == "1" + + if pdb_on_exception is None: + pdb_on_exception = os.getenv("LITESTAR_PDB", "0") == "1" + + config = AppConfig( + after_exception=list(after_exception or []), + after_request=after_request, + after_response=after_response, + allowed_hosts=allowed_hosts if isinstance(allowed_hosts, AllowedHostsConfig) else list(allowed_hosts or []), + before_request=before_request, + before_send=list(before_send or []), + cache_control=cache_control, + compression_config=compression_config, + cors_config=cors_config, + csrf_config=csrf_config, + debug=debug, + dependencies=dict(dependencies or {}), + dto=dto, + etag=etag, + event_emitter_backend=event_emitter_backend, + exception_handlers=exception_handlers or {}, + guards=list(guards or []), + include_in_schema=include_in_schema, + lifespan=list(lifespan or []), + listeners=list(listeners or []), + logging_config=logging_config, + middleware=list(middleware or []), + multipart_form_part_limit=multipart_form_part_limit, + on_shutdown=list(on_shutdown or []), + on_startup=list(on_startup or []), + openapi_config=openapi_config, + opt=dict(opt or {}), + parameters=parameters or {}, + pdb_on_exception=pdb_on_exception, + plugins=self._get_default_plugins(list(plugins or [])), + request_class=request_class, + response_cache_config=response_cache_config or ResponseCacheConfig(), + response_class=response_class, + response_cookies=response_cookies or [], + response_headers=response_headers or [], + return_dto=return_dto, + route_handlers=list(route_handlers) if route_handlers is not None else [], + security=list(security or []), + signature_namespace=dict(signature_namespace or {}), + signature_types=list(signature_types or []), + state=state or State(), + static_files_config=list(static_files_config or []), + stores=stores, + tags=list(tags or []), + template_config=template_config, + type_encoders=type_encoders, + type_decoders=type_decoders, + websocket_class=websocket_class, + experimental_features=list(experimental_features or []), + ) + + config.plugins.extend([OpenAPIPlugin(self), *openapi_schema_plugins]) + + for handler in chain( + on_app_init or [], + (p.on_app_init for p in config.plugins if isinstance(p, InitPluginProtocol)), + ): + config = handler(config) # pyright: ignore + self.plugins = PluginRegistry(config.plugins) + + self._openapi_schema: OpenAPI | None = None + self._debug: bool = True + self.stores: StoreRegistry = ( + config.stores if isinstance(config.stores, StoreRegistry) else StoreRegistry(config.stores) + ) + self._lifespan_managers = config.lifespan + for store in self.stores._stores.values(): + self._lifespan_managers.append(store) + self._server_lifespan_managers = [p.server_lifespan for p in config.plugins or [] if isinstance(p, CLIPlugin)] + self.experimental_features = frozenset(config.experimental_features or []) + self.get_logger: GetLogger = get_logger_placeholder + self.logger: Logger | None = None + self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = [] + self.asgi_router = ASGIRouter(app=self) + + self.after_exception = [ensure_async_callable(h) for h in config.after_exception] + self.allowed_hosts = cast("AllowedHostsConfig | None", config.allowed_hosts) + self.before_send = [ensure_async_callable(h) for h in config.before_send] + self.compression_config = config.compression_config + self.cors_config = config.cors_config + self.csrf_config = config.csrf_config + self.event_emitter = config.event_emitter_backend(listeners=config.listeners) + self.logging_config = config.logging_config + self.multipart_form_part_limit = config.multipart_form_part_limit + self.on_shutdown = config.on_shutdown + self.on_startup = config.on_startup + self.openapi_config = config.openapi_config + self.request_class: type[Request] = config.request_class or Request + self.response_cache_config = config.response_cache_config + self.state = config.state + self._static_files_config = config.static_files_config + self.template_engine = config.template_config.engine_instance if config.template_config else None + self.websocket_class: type[WebSocket] = config.websocket_class or WebSocket + self.debug = config.debug + self.pdb_on_exception: bool = config.pdb_on_exception + self.include_in_schema = include_in_schema + + if self.pdb_on_exception: + warn_pdb_on_exception() + + try: + from starlette.exceptions import HTTPException as StarletteHTTPException + + from litestar.middleware.exceptions.middleware import _starlette_exception_handler + + config.exception_handlers.setdefault(StarletteHTTPException, _starlette_exception_handler) + except ImportError: + pass + + super().__init__( + after_request=config.after_request, + after_response=config.after_response, + before_request=config.before_request, + cache_control=config.cache_control, + dependencies=config.dependencies, + dto=config.dto, + etag=config.etag, + exception_handlers=config.exception_handlers, + guards=config.guards, + middleware=config.middleware, + opt=config.opt, + parameters=config.parameters, + path="", + request_class=self.request_class, + response_class=config.response_class, + response_cookies=config.response_cookies, + response_headers=config.response_headers, + return_dto=config.return_dto, + # route handlers are registered below + route_handlers=[], + security=config.security, + signature_namespace=config.signature_namespace, + signature_types=config.signature_types, + tags=config.tags, + type_encoders=config.type_encoders, + type_decoders=config.type_decoders, + include_in_schema=config.include_in_schema, + websocket_class=self.websocket_class, + ) + + for route_handler in config.route_handlers: + self.register(route_handler) + + if self.logging_config: + self.get_logger = self.logging_config.configure() + self.logger = self.get_logger("litestar") + + for static_config in self._static_files_config: + self.register(static_config.to_static_files_app()) + + self.asgi_handler = self._create_asgi_handler() + + @property + @deprecated(version="2.6.0", kind="property", info="Use create_static_files router instead") + def static_files_config(self) -> list[StaticFilesConfig]: + return self._static_files_config + + @property + @deprecated(version="2.0", alternative="Litestar.plugins.cli", kind="property") + def cli_plugins(self) -> list[CLIPluginProtocol]: + return list(self.plugins.cli) + + @property + @deprecated(version="2.0", alternative="Litestar.plugins.openapi", kind="property") + def openapi_schema_plugins(self) -> list[OpenAPISchemaPluginProtocol]: + return list(self.plugins.openapi) + + @property + @deprecated(version="2.0", alternative="Litestar.plugins.serialization", kind="property") + def serialization_plugins(self) -> list[SerializationPluginProtocol]: + return list(self.plugins.serialization) + + @staticmethod + def _get_default_plugins(plugins: list[PluginProtocol]) -> list[PluginProtocol]: + from litestar.plugins.core import MsgspecDIPlugin + + plugins.append(MsgspecDIPlugin()) + + with suppress(MissingDependencyException): + from litestar.contrib.pydantic import ( + PydanticDIPlugin, + PydanticInitPlugin, + PydanticPlugin, + PydanticSchemaPlugin, + ) + + pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins) + pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins) + pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins) + pydantic_serialization_plugin_found = any(isinstance(plugin, PydanticDIPlugin) for plugin in plugins) + if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found: + plugins.append(PydanticPlugin()) + elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found: + plugins.append(PydanticSchemaPlugin()) + elif not pydantic_plugin_found and not pydantic_init_plugin_found: + plugins.append(PydanticInitPlugin()) + if not pydantic_plugin_found and not pydantic_serialization_plugin_found: + plugins.append(PydanticDIPlugin()) + with suppress(MissingDependencyException): + from litestar.contrib.attrs import AttrsSchemaPlugin + + pre_configured = any(isinstance(plugin, AttrsSchemaPlugin) for plugin in plugins) + if not pre_configured: + plugins.append(AttrsSchemaPlugin()) + return plugins + + @property + def debug(self) -> bool: + return self._debug + + @debug.setter + def debug(self, value: bool) -> None: + """Sets the debug logging level for the application. + + When possible, it calls the `self.logging_config.set_level` method. This allows for implementation specific code and APIs to be called. + """ + if self.logger and self.logging_config: + self.logging_config.set_level(self.logger, logging.DEBUG if value else logging.INFO) + elif self.logger and hasattr(self.logger, "setLevel"): # pragma: no cover + self.logger.setLevel(logging.DEBUG if value else logging.INFO) # pragma: no cover + if isinstance(self.logging_config, LoggingConfig): + self.logging_config.loggers["litestar"]["level"] = "DEBUG" if value else "INFO" + self._debug = value + + async def __call__( + self, + scope: Scope | LifeSpanScope, + receive: Receive | LifeSpanReceive, + send: Send | LifeSpanSend, + ) -> None: + """Application entry point. + + Lifespan events (startup / shutdown) are sent to the lifespan handler, otherwise the ASGI handler is used + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if scope["type"] == "lifespan": + await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type] + return + + scope["app"] = self + scope.setdefault("state", {}) + await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type] + + async def _call_lifespan_hook(self, hook: LifespanHook) -> None: + ret = hook(self) if inspect.signature(hook).parameters else hook() # type: ignore[call-arg] + + if is_async_callable(hook): # pyright: ignore[reportGeneralTypeIssues] + await ret + + @asynccontextmanager + async def lifespan(self) -> AsyncGenerator[None, None]: + """Context manager handling the ASGI lifespan. + + It will be entered when the ``lifespan`` message has been received from the + server, and exit after the ``asgi.shutdown`` message. During this period, it is + responsible for calling the ``on_startup``, ``on_shutdown`` hooks, as well as + custom lifespan managers. + """ + async with AsyncExitStack() as exit_stack: + for hook in self.on_shutdown[::-1]: + exit_stack.push_async_callback(partial(self._call_lifespan_hook, hook)) + + await exit_stack.enter_async_context(self.event_emitter) + + for manager in self._lifespan_managers: + if not isinstance(manager, AbstractAsyncContextManager): + manager = manager(self) + await exit_stack.enter_async_context(manager) + + for hook in self.on_startup: + await self._call_lifespan_hook(hook) + + yield + + @property + def openapi_schema(self) -> OpenAPI: + """Access the OpenAPI schema of the application. + + Returns: + The :class:`OpenAPI` + <pydantic_openapi_schema.open_api.OpenAPI> instance of the + application. + + Raises: + ImproperlyConfiguredException: If the application ``openapi_config`` attribute is ``None``. + """ + return self.plugins.get(OpenAPIPlugin).provide_openapi() + + @classmethod + def from_config(cls, config: AppConfig) -> Self: + """Initialize a ``Litestar`` application from a configuration instance. + + Args: + config: An instance of :class:`AppConfig` <.config.AppConfig> + + Returns: + An instance of ``Litestar`` application. + """ + return cls(**dict(extract_dataclass_items(config))) + + def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override] + """Register a route handler on the app. + + This method can be used to dynamically add endpoints to an application. + + Args: + value: An instance of :class:`Router <.router.Router>`, a subclass of + :class:`Controller <.controller.Controller>` or any function decorated by the route handler decorators. + + Returns: + None + """ + routes = super().register(value=value) + + for route in routes: + route_handlers = get_route_handlers(route) + + for route_handler in route_handlers: + route_handler.on_registration(self) + + if isinstance(route, HTTPRoute): + route.create_handler_map() + + elif isinstance(route, WebSocketRoute): + route.handler_parameter_model = route.create_handler_kwargs_model(route.route_handler) + + for plugin in self.plugins.receive_route: + plugin.receive_route(route) + + self.asgi_router.construct_routing_trie() + + def get_handler_index_by_name(self, name: str) -> HandlerIndex | None: + """Receives a route handler name and returns an optional dictionary containing the route handler instance and + list of paths sorted lexically. + + Examples: + .. code-block:: python + + from litestar import Litestar, get + + + @get("/", name="my-handler") + def handler() -> None: + pass + + + app = Litestar(route_handlers=[handler]) + + handler_index = app.get_handler_index_by_name("my-handler") + + # { "paths": ["/"], "handler" ... } + + Args: + name: A route handler unique name. + + Returns: + A :class:`HandlerIndex <.app.HandlerIndex>` instance or ``None``. + """ + handler = self.asgi_router.route_handler_index.get(name) + if not handler: + return None + + identifier = handler.name or str(handler) + routes = self.asgi_router.route_mapping[identifier] + paths = sorted(unique([route.path for route in routes])) + + return HandlerIndex(handler=handler, paths=paths, identifier=identifier) + + def route_reverse(self, name: str, **path_parameters: Any) -> str: + """Receives a route handler name, path parameter values and returns url path to the handler with filled path + parameters. + + Examples: + .. code-block:: python + + from litestar import Litestar, get + + + @get("/group/{group_id:int}/user/{user_id:int}", name="get_membership_details") + def get_membership_details(group_id: int, user_id: int) -> None: + pass + + + app = Litestar(route_handlers=[get_membership_details]) + + path = app.route_reverse("get_membership_details", user_id=100, group_id=10) + + # /group/10/user/100 + + Args: + name: A route handler unique name. + **path_parameters: Actual values for path parameters in the route. + + Raises: + NoRouteMatchFoundException: If route with 'name' does not exist, path parameters are missing in + ``**path_parameters or have wrong type``. + + Returns: + A fully formatted url path. + """ + handler_index = self.get_handler_index_by_name(name) + if handler_index is None: + raise NoRouteMatchFoundException(f"Route {name} can not be found") + + allow_str_instead = {datetime, date, time, timedelta, float, Path} + routes = sorted( + self.asgi_router.route_mapping[handler_index["identifier"]], + key=lambda r: len(r.path_parameters), + reverse=True, + ) + passed_parameters = set(path_parameters.keys()) + + selected_route = next( + ( + route + for route in routes + if passed_parameters.issuperset({param.name for param in route.path_parameters}) + ), + routes[-1], + ) + output: list[str] = [] + for component in selected_route.path_components: + if isinstance(component, PathParameterDefinition): + val = path_parameters.get(component.name) + if not isinstance(val, component.type) and ( + component.type not in allow_str_instead or not isinstance(val, str) + ): + raise NoRouteMatchFoundException( + f"Received type for path parameter {component.name} doesn't match declared type {component.type}" + ) + output.append(str(val)) + else: + output.append(component) + + return join_paths(output) + + @deprecated( + "2.6.0", info="Use create_static_files router instead of StaticFilesConfig, which works with route_reverse" + ) + def url_for_static_asset(self, name: str, file_path: str) -> str: + """Receives a static files handler name, an asset file path and returns resolved url path to the asset. + + Examples: + .. code-block:: python + + from litestar import Litestar + from litestar.static_files.config import StaticFilesConfig + + app = Litestar( + static_files_config=[ + StaticFilesConfig(directories=["css"], path="/static/css", name="css") + ] + ) + + path = app.url_for_static_asset("css", "main.css") + + # /static/css/main.css + + Args: + name: A static handler unique name. + file_path: a string containing path to an asset. + + Raises: + NoRouteMatchFoundException: If static files handler with ``name`` does not exist. + + Returns: + A url path to the asset. + """ + + handler_index = self.get_handler_index_by_name(name) + if handler_index is None: + raise NoRouteMatchFoundException(f"Static handler {name} can not be found") + + handler_fn = cast("AnyCallable", handler_index["handler"].fn) + if not isinstance(handler_fn, StaticFiles): + raise NoRouteMatchFoundException(f"Handler with name {name} is not a static files handler") + + return join_paths([handler_index["paths"][0], file_path]) # type: ignore[unreachable] + + @property + def route_handler_method_view(self) -> dict[str, list[str]]: + """Map route handlers to paths. + + Returns: + A dictionary of router handlers and lists of paths as strings + """ + route_map: dict[str, list[str]] = { + handler: [route.path for route in routes] for handler, routes in self.asgi_router.route_mapping.items() + } + return route_map + + def _create_asgi_handler(self) -> ASGIApp: + """Create an ASGIApp that wraps the ASGI router inside an exception handler. + + If CORS or TrustedHost configs are provided to the constructor, they will wrap the router as well. + """ + asgi_handler: ASGIApp = self.asgi_router + if self.cors_config: + asgi_handler = CORSMiddleware(app=asgi_handler, config=self.cors_config) + + return wrap_in_exception_handler( + app=asgi_handler, + exception_handlers=self.exception_handlers or {}, # pyright: ignore + ) + + def _wrap_send(self, send: Send, scope: Scope) -> Send: + """Wrap the ASGI send and handles any 'before send' hooks. + + Args: + send: The ASGI send function. + scope: The ASGI scope. + + Returns: + An ASGI send function. + """ + if self.before_send: + + async def wrapped_send(message: Message) -> None: + for hook in self.before_send: + await hook(message, scope) + await send(message) + + return wrapped_send + return send + + def update_openapi_schema(self) -> None: + """Update the OpenAPI schema to reflect the route handlers registered on the app. + + Returns: + None + """ + self.plugins.get(OpenAPIPlugin)._build_openapi_schema() + + def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None: + """Emit an event to all attached listeners. + + Args: + event_id: The ID of the event to emit, e.g ``my_event``. + args: args to pass to the listener(s). + kwargs: kwargs to pass to the listener(s) + + Returns: + None + """ + self.event_emitter.emit(event_id, *args, **kwargs) diff --git a/venv/lib/python3.11/site-packages/litestar/background_tasks.py b/venv/lib/python3.11/site-packages/litestar/background_tasks.py new file mode 100644 index 0000000..a475836 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/background_tasks.py @@ -0,0 +1,74 @@ +from typing import Any, Callable, Iterable + +from anyio import create_task_group +from typing_extensions import ParamSpec + +from litestar.utils.sync import ensure_async_callable + +__all__ = ("BackgroundTask", "BackgroundTasks") + + +P = ParamSpec("P") + + +class BackgroundTask: + """A container for a 'background' task function. + + Background tasks are called once a Response finishes. + """ + + __slots__ = ("fn", "args", "kwargs") + + def __init__(self, fn: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: + """Initialize ``BackgroundTask``. + + Args: + fn: A sync or async function to call as the background task. + *args: Args to pass to the func. + **kwargs: Kwargs to pass to the func + """ + self.fn = ensure_async_callable(fn) + self.args = args + self.kwargs = kwargs + + async def __call__(self) -> None: + """Call the wrapped function with the passed in arguments. + + Returns: + None + """ + await self.fn(*self.args, **self.kwargs) + + +class BackgroundTasks: + """A container for multiple 'background' task functions. + + Background tasks are called once a Response finishes. + """ + + __slots__ = ("tasks", "run_in_task_group") + + def __init__(self, tasks: Iterable[BackgroundTask], run_in_task_group: bool = False) -> None: + """Initialize ``BackgroundTasks``. + + Args: + tasks: An iterable of :class:`BackgroundTask <.background_tasks.BackgroundTask>` instances. + run_in_task_group: If you set this value to ``True`` than the tasks will run concurrently, using + a :class:`TaskGroup <anyio.abc.TaskGroup>`. Note: This will not preserve execution order. + """ + self.tasks = tasks + self.run_in_task_group = run_in_task_group + + async def __call__(self) -> None: + """Call the wrapped background tasks. + + Returns: + None + """ + if self.run_in_task_group: + async with create_task_group() as task_group: + for task in self.tasks: + task_group.start_soon(task) + else: + for task in self.tasks: + await task() diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__init__.py b/venv/lib/python3.11/site-packages/litestar/channels/__init__.py new file mode 100644 index 0000000..0167223 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__init__.py @@ -0,0 +1,5 @@ +from .backends.base import ChannelsBackend +from .plugin import ChannelsPlugin +from .subscriber import Subscriber + +__all__ = ("ChannelsPlugin", "ChannelsBackend", "Subscriber") diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bf9d6bd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..08361dc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/subscriber.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/subscriber.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8d609b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/subscriber.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__init__.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ab4e477 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/asyncpg.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/asyncpg.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a577096 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/asyncpg.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..334d295 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/memory.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/memory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9a87da5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/memory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/psycopg.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/psycopg.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f663280 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/psycopg.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/redis.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/redis.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bf86a3e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/redis.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_flushall_streams.lua b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_flushall_streams.lua new file mode 100644 index 0000000..a3faa6e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_flushall_streams.lua @@ -0,0 +1,15 @@ +local key_pattern = ARGV[1] + +local cursor = 0 +local deleted_streams = 0 + +repeat + local result = redis.call('SCAN', cursor, 'MATCH', key_pattern) + for _,key in ipairs(result[2]) do + redis.call('DEL', key) + deleted_streams = deleted_streams + 1 + end + cursor = tonumber(result[1]) +until cursor == 0 + +return deleted_streams diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_pubsub_publish.lua b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_pubsub_publish.lua new file mode 100644 index 0000000..8402d08 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_pubsub_publish.lua @@ -0,0 +1,5 @@ +local data = ARGV[1] + +for _, channel in ipairs(KEYS) do + redis.call("PUBLISH", channel, data) +end diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_xadd_expire.lua b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_xadd_expire.lua new file mode 100644 index 0000000..f6b322f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_xadd_expire.lua @@ -0,0 +1,13 @@ +local data = ARGV[1] +local limit = ARGV[2] +local exp = ARGV[3] +local maxlen_approx = ARGV[4] + +for i, key in ipairs(KEYS) do + if maxlen_approx == 1 then + redis.call("XADD", key, "MAXLEN", "~", limit, "*", "data", data, "channel", ARGV[i + 4]) + else + redis.call("XADD", key, "MAXLEN", limit, "*", "data", data, "channel", ARGV[i + 4]) + end + redis.call("PEXPIRE", key, exp) +end diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py new file mode 100644 index 0000000..967b208 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from functools import partial +from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload + +import asyncpg + +from litestar.channels import ChannelsBackend +from litestar.exceptions import ImproperlyConfiguredException + + +class AsyncPgChannelsBackend(ChannelsBackend): + _listener_conn: asyncpg.Connection + + @overload + def __init__(self, dsn: str) -> None: ... + + @overload + def __init__( + self, + *, + make_connection: Callable[[], Awaitable[asyncpg.Connection]], + ) -> None: ... + + def __init__( + self, + dsn: str | None = None, + *, + make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None, + ) -> None: + if not (dsn or make_connection): + raise ImproperlyConfiguredException("Need to specify dsn or make_connection") + + self._subscribed_channels: set[str] = set() + self._exit_stack = AsyncExitStack() + self._connect = make_connection or partial(asyncpg.connect, dsn=dsn) + self._queue: asyncio.Queue[tuple[str, bytes]] | None = None + + async def on_startup(self) -> None: + self._queue = asyncio.Queue() + self._listener_conn = await self._connect() + + async def on_shutdown(self) -> None: + await self._listener_conn.close() + self._queue = None + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + dec_data = data.decode("utf-8") + + conn = await self._connect() + try: + for channel in channels: + await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data) + finally: + await conn.close() + + async def subscribe(self, channels: Iterable[str]) -> None: + for channel in set(channels) - self._subscribed_channels: + await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type] + self._subscribed_channels.add(channel) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + for channel in channels: + await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type] + self._subscribed_channels = self._subscribed_channels - set(channels) + + async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + while True: + channel, message = await self._queue.get() + self._queue.task_done() + # an UNLISTEN may be in transit while we're getting here, so we double-check + # that we are actually supposed to deliver this message + if channel in self._subscribed_channels: + yield channel, message + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + raise NotImplementedError() + + def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None: + if not isinstance(payload, str): + raise RuntimeError("Invalid data received") + self._queue.put_nowait((channel, payload.encode("utf-8"))) # type: ignore[union-attr] diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/base.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/base.py new file mode 100644 index 0000000..ce7ee81 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/base.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Iterable + + +class ChannelsBackend(ABC): + @abstractmethod + async def on_startup(self) -> None: + """Called by the plugin on application startup""" + ... + + @abstractmethod + async def on_shutdown(self) -> None: + """Called by the plugin on application shutdown""" + ... + + @abstractmethod + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish the message ``data`` to all ``channels``""" + ... + + @abstractmethod + async def subscribe(self, channels: Iterable[str]) -> None: + """Start listening for events on ``channels``""" + ... + + @abstractmethod + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Stop listening for events on ``channels``""" + ... + + @abstractmethod + def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + """Return a generator, iterating over events of subscribed channels as they become available""" + ... + + @abstractmethod + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Return the event history of ``channel``, at most ``limit`` entries""" + ... diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/memory.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/memory.py new file mode 100644 index 0000000..a96a66b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/memory.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from asyncio import Queue +from collections import defaultdict, deque +from typing import Any, AsyncGenerator, Iterable + +from litestar.channels.backends.base import ChannelsBackend + + +class MemoryChannelsBackend(ChannelsBackend): + """An in-memory channels backend""" + + def __init__(self, history: int = 0) -> None: + self._max_history_length = history + self._channels: set[str] = set() + self._queue: Queue[tuple[str, bytes]] | None = None + self._history: defaultdict[str, deque[bytes]] = defaultdict(lambda: deque(maxlen=self._max_history_length)) + + async def on_startup(self) -> None: + self._queue = Queue() + + async def on_shutdown(self) -> None: + self._queue = None + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish ``data`` to ``channels``. If a channel has not yet been subscribed to, + this will be a no-op. + + Args: + data: Data to publish + channels: Channels to publish to + + Returns: + None + + Raises: + RuntimeError: If ``on_startup`` has not been called yet + """ + if not self._queue: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + for channel in channels: + if channel not in self._channels: + continue + + self._queue.put_nowait((channel, data)) + if self._max_history_length: + for channel in channels: + self._history[channel].append(data) + + async def subscribe(self, channels: Iterable[str]) -> None: + """Subscribe to ``channels``, and enable publishing to them""" + self._channels.update(channels) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Unsubscribe from ``channels``""" + self._channels -= set(channels) + try: + for channel in channels: + del self._history[channel] + except KeyError: + pass + + async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: + """Return a generator, iterating over events of subscribed channels as they become available""" + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + while True: + channel, message = await self._queue.get() + self._queue.task_done() + + # if a message is published to a channel and the channel is then + # unsubscribed before retrieving that message from the stream, it can still + # end up here, so we double-check if we still are interested in this message + if channel in self._channels: + yield channel, message + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Return the event history of ``channel``, at most ``limit`` entries""" + history = list(self._history[channel]) + if limit: + history = history[-limit:] + return history diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py new file mode 100644 index 0000000..14b53bc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import AsyncGenerator, Iterable + +import psycopg + +from .base import ChannelsBackend + + +def _safe_quote(ident: str) -> str: + return '"{}"'.format(ident.replace('"', '""')) # sourcery skip + + +class PsycoPgChannelsBackend(ChannelsBackend): + _listener_conn: psycopg.AsyncConnection + + def __init__(self, pg_dsn: str) -> None: + self._pg_dsn = pg_dsn + self._subscribed_channels: set[str] = set() + self._exit_stack = AsyncExitStack() + + async def on_startup(self) -> None: + self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True) + await self._exit_stack.enter_async_context(self._listener_conn) + + async def on_shutdown(self) -> None: + await self._exit_stack.aclose() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + dec_data = data.decode("utf-8") + async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn: + for channel in channels: + await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data)) + + async def subscribe(self, channels: Iterable[str]) -> None: + for channel in set(channels) - self._subscribed_channels: + # can't use placeholders in LISTEN + await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore + + self._subscribed_channels.add(channel) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + for channel in channels: + # can't use placeholders in UNLISTEN + await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore + self._subscribed_channels = self._subscribed_channels - set(channels) + + async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + async for notify in self._listener_conn.notifies(): + yield notify.channel, notify.payload.encode("utf-8") + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + raise NotImplementedError() diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/redis.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/redis.py new file mode 100644 index 0000000..f03c9f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/redis.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import asyncio +import sys + +if sys.version_info < (3, 9): + import importlib_resources # pyright: ignore +else: + import importlib.resources as importlib_resources +from abc import ABC +from datetime import timedelta +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, cast + +from litestar.channels.backends.base import ChannelsBackend + +if TYPE_CHECKING: + from redis.asyncio import Redis + from redis.asyncio.client import PubSub + +_resource_path = importlib_resources.files("litestar.channels.backends") +_PUBSUB_PUBLISH_SCRIPT = (_resource_path / "_redis_pubsub_publish.lua").read_text() +_FLUSHALL_STREAMS_SCRIPT = (_resource_path / "_redis_flushall_streams.lua").read_text() +_XADD_EXPIRE_SCRIPT = (_resource_path / "_redis_xadd_expire.lua").read_text() + + +class _LazyEvent: + """A lazy proxy to asyncio.Event that only creates the event once it's accessed. + + It ensures that the Event is created within a running event loop. If it's not, there can be an issue where a future + within the event itself is attached to a different loop. + + This happens in our tests and could also happen when a user creates an instance of the backend outside an event loop + in their application. + """ + + def __init__(self) -> None: + self.__event: asyncio.Event | None = None + + @property + def _event(self) -> asyncio.Event: + if self.__event is None: + self.__event = asyncio.Event() + return self.__event + + def set(self) -> None: + self._event.set() + + def clear(self) -> None: + self._event.clear() + + async def wait(self) -> None: + await self._event.wait() + + +class RedisChannelsBackend(ChannelsBackend, ABC): + def __init__(self, *, redis: Redis, key_prefix: str, stream_sleep_no_subscriptions: int) -> None: + """Base redis channels backend. + + Args: + redis: A :class:`redis.asyncio.Redis` instance + key_prefix: Key prefix to use for storing data in redis + stream_sleep_no_subscriptions: Amount of time in milliseconds to pause the + :meth:`stream_events` generator, should no subscribers exist + """ + self._redis = redis + self._key_prefix = key_prefix + self._stream_sleep_no_subscriptions = stream_sleep_no_subscriptions + + def _make_key(self, channel: str) -> str: + return f"{self._key_prefix}_{channel.upper()}" + + +class RedisChannelsPubSubBackend(RedisChannelsBackend): + def __init__( + self, *, redis: Redis, stream_sleep_no_subscriptions: int = 1, key_prefix: str = "LITESTAR_CHANNELS" + ) -> None: + """Redis channels backend, `Pub/Sub <https://redis.io/docs/manual/pubsub/>`_. + + This backend provides low overhead and resource usage but no support for history. + + Args: + redis: A :class:`redis.asyncio.Redis` instance + key_prefix: Key prefix to use for storing data in redis + stream_sleep_no_subscriptions: Amount of time in milliseconds to pause the + :meth:`stream_events` generator, should no subscribers exist + """ + super().__init__( + redis=redis, stream_sleep_no_subscriptions=stream_sleep_no_subscriptions, key_prefix=key_prefix + ) + self.__pub_sub: PubSub | None = None + self._publish_script = self._redis.register_script(_PUBSUB_PUBLISH_SCRIPT) + self._has_subscribed = _LazyEvent() + + @property + def _pub_sub(self) -> PubSub: + if self.__pub_sub is None: + self.__pub_sub = self._redis.pubsub() + return self.__pub_sub + + async def on_startup(self) -> None: + # this method should not do anything in this case + pass + + async def on_shutdown(self) -> None: + await self._pub_sub.reset() + + async def subscribe(self, channels: Iterable[str]) -> None: + """Subscribe to ``channels``, and enable publishing to them""" + await self._pub_sub.subscribe(*channels) + self._has_subscribed.set() + + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Stop listening for events on ``channels``""" + await self._pub_sub.unsubscribe(*channels) + # if we have no active subscriptions, or only subscriptions which are pending + # to be unsubscribed we consider the backend to be unsubscribed from all + # channels, so we reset the event + if not self._pub_sub.channels.keys() - self._pub_sub.pending_unsubscribe_channels: + self._has_subscribed.clear() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish ``data`` to ``channels`` + + .. note:: + This operation is performed atomically, using a lua script + """ + await self._publish_script(keys=list(set(channels)), args=[data]) + + async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: + """Return a generator, iterating over events of subscribed channels as they become available. + + If no channels have been subscribed to yet via :meth:`subscribe`, sleep for ``stream_sleep_no_subscriptions`` + milliseconds. + """ + + while True: + await self._has_subscribed.wait() + message = await self._pub_sub.get_message( + ignore_subscribe_messages=True, timeout=self._stream_sleep_no_subscriptions + ) + if message is None: + continue + + channel: str = message["channel"].decode() + data: bytes = message["data"] + # redis handles the unsubscibes with a queue; Unsubscribing doesn't mean the + # unsubscribe will happen immediately after requesting it, so we could + # receive a message on a channel that, from a client's perspective, it's not + # subscribed to anymore + if channel.encode() in self._pub_sub.channels.keys() - self._pub_sub.pending_unsubscribe_channels: + yield channel, data + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Not implemented""" + raise NotImplementedError() + + +class RedisChannelsStreamBackend(RedisChannelsBackend): + def __init__( + self, + history: int, + *, + redis: Redis, + stream_sleep_no_subscriptions: int = 1, + cap_streams_approximate: bool = True, + stream_ttl: int | timedelta = timedelta(seconds=60), + key_prefix: str = "LITESTAR_CHANNELS", + ) -> None: + """Redis channels backend, `streams <https://redis.io/docs/data-types/streams/>`_. + + Args: + history: Amount of messages to keep. This will set a ``MAXLEN`` to the streams + redis: A :class:`redis.asyncio.Redis` instance + key_prefix: Key prefix to use for streams + stream_sleep_no_subscriptions: Amount of time in milliseconds to pause the + :meth:`stream_events` generator, should no subscribers exist + cap_streams_approximate: Set the streams ``MAXLEN`` using the ``~`` approximation + operator for improved performance + stream_ttl: TTL of a stream in milliseconds or as a timedelta. A streams TTL will be set on each publishing + operation using ``PEXPIRE`` + """ + super().__init__( + redis=redis, stream_sleep_no_subscriptions=stream_sleep_no_subscriptions, key_prefix=key_prefix + ) + + self._history_limit = history + self._subscribed_channels: set[str] = set() + self._cap_streams_approximate = cap_streams_approximate + self._stream_ttl = stream_ttl if isinstance(stream_ttl, int) else int(stream_ttl.total_seconds() * 1000) + self._flush_all_streams_script = self._redis.register_script(_FLUSHALL_STREAMS_SCRIPT) + self._publish_script = self._redis.register_script(_XADD_EXPIRE_SCRIPT) + self._has_subscribed_channels = _LazyEvent() + + async def on_startup(self) -> None: + """Called on application startup""" + + async def on_shutdown(self) -> None: + """Called on application shutdown""" + + async def subscribe(self, channels: Iterable[str]) -> None: + """Subscribe to ``channels``""" + self._subscribed_channels.update(channels) + self._has_subscribed_channels.set() + + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Unsubscribe from ``channels``""" + self._subscribed_channels -= set(channels) + if not len(self._subscribed_channels): + self._has_subscribed_channels.clear() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish ``data`` to ``channels``. + + .. note:: + This operation is performed atomically, using a Lua script + """ + channels = set(channels) + await self._publish_script( + keys=[self._make_key(key) for key in channels], + args=[ + data, + self._history_limit, + self._stream_ttl, + int(self._cap_streams_approximate), + *channels, + ], + ) + + async def _get_subscribed_channels(self) -> set[str]: + """Get subscribed channels. If no channels are currently subscribed, wait""" + await self._has_subscribed_channels.wait() + return self._subscribed_channels + + async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: + """Return a generator, iterating over events of subscribed channels as they become available. + + If no channels have been subscribed to yet via :meth:`subscribe`, sleep for ``stream_sleep_no_subscriptions`` + milliseconds. + """ + stream_ids: dict[str, bytes] = {} + while True: + # We wait for subscribed channels, because we can't pass an empty dict to + # xread and block for subscribers + stream_keys = [self._make_key(c) for c in await self._get_subscribed_channels()] + + data: list[tuple[bytes, list[tuple[bytes, dict[bytes, bytes]]]]] = await self._redis.xread( + {key: stream_ids.get(key, 0) for key in stream_keys}, block=self._stream_sleep_no_subscriptions + ) + + if not data: + continue + + for stream_key, channel_events in data: + for event in channel_events: + event_data = event[1][b"data"] + channel_name = event[1][b"channel"].decode() + stream_ids[stream_key.decode()] = event[0] + yield channel_name, event_data + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Return the history of ``channels``, returning at most ``limit`` messages""" + data: Iterable[tuple[bytes, dict[bytes, bytes]]] + if limit: + data = reversed(await self._redis.xrevrange(self._make_key(channel), count=limit)) + else: + data = await self._redis.xrange(self._make_key(channel)) + + return [event[b"data"] for _, event in data] + + async def flush_all(self) -> int: + """Delete all stream keys with the ``key_prefix``. + + .. important:: + This method is incompatible with redis clusters + """ + deleted_streams = await self._flush_all_streams_script(keys=[], args=[f"{self._key_prefix}*"]) + return cast("int", deleted_streams) diff --git a/venv/lib/python3.11/site-packages/litestar/channels/plugin.py b/venv/lib/python3.11/site-packages/litestar/channels/plugin.py new file mode 100644 index 0000000..5988445 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/plugin.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +import asyncio +from asyncio import CancelledError, Queue, Task, create_task +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress +from functools import partial +from typing import TYPE_CHECKING, AsyncGenerator, Awaitable, Callable, Iterable + +import msgspec.json + +from litestar.di import Provide +from litestar.exceptions import ImproperlyConfiguredException, LitestarException +from litestar.handlers import WebsocketRouteHandler +from litestar.plugins import InitPluginProtocol +from litestar.serialization import default_serializer + +from .subscriber import BacklogStrategy, EventCallback, Subscriber + +if TYPE_CHECKING: + from types import TracebackType + + from litestar.channels.backends.base import ChannelsBackend + from litestar.config.app import AppConfig + from litestar.connection import WebSocket + from litestar.types import LitestarEncodableType, TypeEncodersMap + from litestar.types.asgi_types import WebSocketMode + + +class ChannelsException(LitestarException): + pass + + +class ChannelsPlugin(InitPluginProtocol, AbstractAsyncContextManager): + def __init__( + self, + backend: ChannelsBackend, + *, + channels: Iterable[str] | None = None, + arbitrary_channels_allowed: bool = False, + create_ws_route_handlers: bool = False, + ws_handler_send_history: int = 0, + ws_handler_base_path: str = "/", + ws_send_mode: WebSocketMode = "text", + subscriber_max_backlog: int | None = None, + subscriber_backlog_strategy: BacklogStrategy = "backoff", + subscriber_class: type[Subscriber] = Subscriber, + type_encoders: TypeEncodersMap | None = None, + ) -> None: + """Plugin to handle broadcasting to WebSockets with support for channels. + + This plugin is available as an injected dependency using the ``channels`` key. + + Args: + backend: Backend to store data in + channels: Channels to serve. If ``arbitrary_channels_allowed`` is ``False`` (the default), trying to + subscribe to a channel not in ``channels`` will raise an exception + arbitrary_channels_allowed: Allow the creation of channels on the fly + create_ws_route_handlers: If ``True``, websocket route handlers will be created for all channels defined in + ``channels``. If ``arbitrary_channels_allowed`` is ``True``, a single handler will be created instead, + handling all channels. The handlers created will accept WebSocket connections and start sending received + events for their respective channels. + ws_handler_send_history: Amount of history entries to send from the generated websocket route handlers after + a client has connected. A value of ``0`` indicates to not send a history + ws_handler_base_path: Path prefix used for the generated route handlers + ws_send_mode: Send mode to use for sending data through a :class:`WebSocket <.connection.WebSocket>`. + This will be used when sending within generated route handlers or :meth:`Subscriber.run_in_background` + subscriber_max_backlog: Maximum amount of unsent messages to be held in memory for a given subscriber. If + that limit is reached, new messages will be treated accordingly to ``backlog_strategy`` + subscriber_backlog_strategy: Define the behaviour if ``max_backlog`` is reached for a subscriber. ` + `backoff`` will result in new messages being dropped until older ones have been processed. ``dropleft`` + will drop older messages in favour of new ones. + subscriber_class: A :class:`Subscriber` subclass to return from :meth:`subscribe` + type_encoders: An additional mapping of type encoders used to encode data before sending + + """ + self._backend = backend + self._pub_queue: Queue[tuple[bytes, list[str]]] | None = None + self._pub_task: Task | None = None + self._sub_task: Task | None = None + + if not (channels or arbitrary_channels_allowed): + raise ImproperlyConfiguredException("Must define either channels or set arbitrary_channels_allowed=True") + + # make the path absolute, so we can simply concatenate it later + if not ws_handler_base_path.endswith("/"): + ws_handler_base_path += "/" + + self._arbitrary_channels_allowed = arbitrary_channels_allowed + self._create_route_handlers = create_ws_route_handlers + self._handler_root_path = ws_handler_base_path + self._socket_send_mode: WebSocketMode = ws_send_mode + self._encode_json = msgspec.json.Encoder( + enc_hook=partial(default_serializer, type_encoders=type_encoders) + ).encode + self._handler_should_send_history = bool(ws_handler_send_history) + self._history_limit = None if ws_handler_send_history < 0 else ws_handler_send_history + self._max_backlog = subscriber_max_backlog + self._backlog_strategy: BacklogStrategy = subscriber_backlog_strategy + self._subscriber_class = subscriber_class + + self._channels: dict[str, set[Subscriber]] = {channel: set() for channel in channels or []} + + def encode_data(self, data: LitestarEncodableType) -> bytes: + """Encode data before storing it in the backend""" + if isinstance(data, bytes): + return data + + return data.encode() if isinstance(data, str) else self._encode_json(data) + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Plugin hook. Set up a ``channels`` dependency, add route handlers and register application hooks""" + app_config.dependencies["channels"] = Provide(lambda: self, use_cache=True, sync_to_thread=False) + app_config.lifespan.append(self) + app_config.signature_namespace.update(ChannelsPlugin=ChannelsPlugin) + + if self._create_route_handlers: + if self._arbitrary_channels_allowed: + path = self._handler_root_path + "{channel_name:str}" + route_handlers = [WebsocketRouteHandler(path)(self._ws_handler_func)] + else: + route_handlers = [ + WebsocketRouteHandler(self._handler_root_path + channel_name)( + self._create_ws_handler_func(channel_name) + ) + for channel_name in self._channels + ] + app_config.route_handlers.extend(route_handlers) + + return app_config + + def publish(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None: + """Schedule ``data`` to be published to ``channels``. + + .. note:: + This is a synchronous method that returns immediately. There are no + guarantees that when this method returns the data will have been published + to the backend. For that, use :meth:`wait_published` + + """ + if isinstance(channels, str): + channels = [channels] + data = self.encode_data(data) + try: + self._pub_queue.put_nowait((data, list(channels))) # type: ignore[union-attr] + except AttributeError as e: + raise RuntimeError("Plugin not yet initialized. Did you forget to call on_startup?") from e + + async def wait_published(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None: + """Publish ``data`` to ``channels``""" + if isinstance(channels, str): + channels = [channels] + data = self.encode_data(data) + + await self._backend.publish(data, channels) + + async def subscribe(self, channels: str | Iterable[str], history: int | None = None) -> Subscriber: + """Create a :class:`Subscriber`, providing a stream of all events in ``channels``. + + The created subscriber will be passive by default and has to be consumed manually, + either by using :meth:`Subscriber.run_in_background` or iterating over events + using :meth:`Subscriber.iter_events`. + + Args: + channels: Channel(s) to subscribe to + history: If a non-negative integer, add this amount of history entries from + each channel to the subscriber's event stream. Note that this will wait + until all history entries are fetched and pushed to the subscriber's + stream. For more control use :meth:`put_subscriber_history`. + + Returns: + A :class:`Subscriber` + + Raises: + ChannelsException: If a channel in ``channels`` has not been declared on this backend and + ``arbitrary_channels_allowed`` has not been set to ``True`` + """ + if isinstance(channels, str): + channels = [channels] + + subscriber = self._subscriber_class( + plugin=self, + max_backlog=self._max_backlog, + backlog_strategy=self._backlog_strategy, + ) + channels_to_subscribe = set() + + for channel in channels: + if channel not in self._channels: + if not self._arbitrary_channels_allowed: + raise ChannelsException( + f"Unknown channel: {channel!r}. Either explicitly defined the channel or set " + "arbitrary_channels_allowed=True" + ) + self._channels[channel] = set() + channel_subscribers = self._channels[channel] + if not channel_subscribers: + channels_to_subscribe.add(channel) + + channel_subscribers.add(subscriber) + + if channels_to_subscribe: + await self._backend.subscribe(channels_to_subscribe) + + if history: + await self.put_subscriber_history(subscriber=subscriber, limit=history, channels=channels) + + return subscriber + + async def unsubscribe(self, subscriber: Subscriber, channels: str | Iterable[str] | None = None) -> None: + """Unsubscribe a :class:`Subscriber` from ``channels``. If the subscriber has a running sending task, it will + be stopped. + + Args: + channels: Channels to unsubscribe from. If ``None``, unsubscribe from all channels + subscriber: :class:`Subscriber` to unsubscribe + """ + if channels is None: + channels = list(self._channels.keys()) + elif isinstance(channels, str): + channels = [channels] + + channels_to_unsubscribe: set[str] = set() + + for channel in channels: + channel_subscribers = self._channels[channel] + + try: + channel_subscribers.remove(subscriber) + except KeyError: # subscriber was not subscribed to this channel. This may happen if channels is None + continue + + if not channel_subscribers: + channels_to_unsubscribe.add(channel) + + if all(subscriber not in queues for queues in self._channels.values()): + await subscriber.put(None) # this will stop any running task or generator by breaking the inner loop + if subscriber.is_running: + await subscriber.stop() + + if channels_to_unsubscribe: + await self._backend.unsubscribe(channels_to_unsubscribe) + + @asynccontextmanager + async def start_subscription( + self, channels: str | Iterable[str], history: int | None = None + ) -> AsyncGenerator[Subscriber, None]: + """Create a :class:`Subscriber` and tie its subscriptions to a context manager; + Upon exiting the context, :meth:`unsubscribe` will be called. + + Args: + channels: Channel(s) to subscribe to + history: If a non-negative integer, add this amount of history entries from + each channel to the subscriber's event stream. Note that this will wait + until all history entries are fetched and pushed to the subscriber's + stream. For more control use :meth:`put_subscriber_history`. + + Returns: + A :class:`Subscriber` + """ + subscriber = await self.subscribe(channels, history=history) + + try: + yield subscriber + finally: + await self.unsubscribe(subscriber, channels) + + async def put_subscriber_history( + self, subscriber: Subscriber, channels: str | Iterable[str], limit: int | None = None + ) -> None: + """Fetch the history of ``channels`` from the backend and put them in the + subscriber's stream + """ + if isinstance(channels, str): + channels = [channels] + + for channel in channels: + history = await self._backend.get_history(channel, limit) + for entry in history: + await subscriber.put(entry) + + async def _ws_handler_func(self, channel_name: str, socket: WebSocket) -> None: + await socket.accept() + + # the ternary operator triggers a mypy bug: https://github.com/python/mypy/issues/10740 + on_event: EventCallback = socket.send_text if self._socket_send_mode == "text" else socket.send_bytes # type: ignore[assignment] + + async with self.start_subscription(channel_name) as subscriber: + if self._handler_should_send_history: + await self.put_subscriber_history(subscriber, channels=channel_name, limit=self._history_limit) + + # use the background task, so we can block on receive(), breaking the loop when a connection closes + async with subscriber.run_in_background(on_event): + while (await socket.receive())["type"] != "websocket.disconnect": + continue + + def _create_ws_handler_func(self, channel_name: str) -> Callable[[WebSocket], Awaitable[None]]: + async def ws_handler_func(socket: WebSocket) -> None: + await self._ws_handler_func(channel_name=channel_name, socket=socket) + + return ws_handler_func + + async def _pub_worker(self) -> None: + while self._pub_queue: + data, channels = await self._pub_queue.get() + await self._backend.publish(data, channels) + self._pub_queue.task_done() + + async def _sub_worker(self) -> None: + async for channel, payload in self._backend.stream_events(): + for subscriber in self._channels.get(channel, []): + subscriber.put_nowait(payload) + + async def _on_startup(self) -> None: + await self._backend.on_startup() + self._pub_queue = Queue() + self._pub_task = create_task(self._pub_worker()) + self._sub_task = create_task(self._sub_worker()) + if self._channels: + await self._backend.subscribe(list(self._channels)) + + async def _on_shutdown(self) -> None: + if self._pub_queue: + await self._pub_queue.join() + self._pub_queue = None + + await asyncio.gather( + *[ + subscriber.stop(join=False) + for subscribers in self._channels.values() + for subscriber in subscribers + if subscriber.is_running + ] + ) + + if self._sub_task: + self._sub_task.cancel() + with suppress(CancelledError): + await self._sub_task + self._sub_task = None + + if self._pub_task: + self._pub_task.cancel() + with suppress(CancelledError): + await self._pub_task + self._sub_task = None + + await self._backend.on_shutdown() + + async def __aenter__(self) -> ChannelsPlugin: + await self._on_startup() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._on_shutdown() diff --git a/venv/lib/python3.11/site-packages/litestar/channels/subscriber.py b/venv/lib/python3.11/site-packages/litestar/channels/subscriber.py new file mode 100644 index 0000000..b358bc4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/subscriber.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import asyncio +from asyncio import CancelledError, Queue, QueueFull +from collections import deque +from contextlib import AsyncExitStack, asynccontextmanager, suppress +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generic, Literal, TypeVar + +if TYPE_CHECKING: + from litestar.channels import ChannelsPlugin + + +T = TypeVar("T") + +BacklogStrategy = Literal["backoff", "dropleft"] + +EventCallback = Callable[[bytes], Awaitable[Any]] + + +class AsyncDeque(Queue, Generic[T]): + def __init__(self, maxsize: int | None) -> None: + self._deque_maxlen = maxsize + super().__init__() + + def _init(self, maxsize: int) -> None: + self._queue: deque[T] = deque(maxlen=self._deque_maxlen) + + +class Subscriber: + """A wrapper around a stream of events published to subscribed channels""" + + def __init__( + self, + plugin: ChannelsPlugin, + max_backlog: int | None = None, + backlog_strategy: BacklogStrategy = "backoff", + ) -> None: + self._task: asyncio.Task | None = None + self._plugin = plugin + self._backend = plugin._backend + self._queue: Queue[bytes | None] | AsyncDeque[bytes | None] + + if max_backlog and backlog_strategy == "dropleft": + self._queue = AsyncDeque(maxsize=max_backlog or 0) + else: + self._queue = Queue(maxsize=max_backlog or 0) + + async def put(self, item: bytes | None) -> None: + await self._queue.put(item) + + def put_nowait(self, item: bytes | None) -> bool: + """Put an item in the subscriber's stream without waiting""" + try: + self._queue.put_nowait(item) + return True + except QueueFull: + return False + + @property + def qsize(self) -> int: + return self._queue.qsize() + + async def iter_events(self) -> AsyncGenerator[bytes, None]: + """Iterate over the stream of events. If no items are available, block until + one becomes available + """ + while True: + item = await self._queue.get() + if item is None: + self._queue.task_done() + break + yield item + self._queue.task_done() + + @asynccontextmanager + async def run_in_background(self, on_event: EventCallback, join: bool = True) -> AsyncGenerator[None, None]: + """Start a task in the background that sends events from the subscriber's stream + to ``socket`` as they become available. On exit, it will prevent the stream from + accepting new events and wait until the currently enqueued ones are processed. + Should the context be left with an exception, the task will be cancelled + immediately. + + Args: + on_event: Callback to invoke with the event data for every event + join: If ``True``, wait for all items in the stream to be processed before + stopping the worker. Note that an error occurring within the context + will always lead to the immediate cancellation of the worker + """ + self._start_in_background(on_event=on_event) + async with AsyncExitStack() as exit_stack: + exit_stack.push_async_callback(self.stop, join=False) + yield + exit_stack.pop_all() + await self.stop(join=join) + + async def _worker(self, on_event: EventCallback) -> None: + async for event in self.iter_events(): + await on_event(event) + + def _start_in_background(self, on_event: EventCallback) -> None: + """Start a task in the background that sends events from the subscriber's stream + to ``socket`` as they become available. + + Args: + on_event: Callback to invoke with the event data for every event + """ + if self._task is not None: + raise RuntimeError("Subscriber is already running") + self._task = asyncio.create_task(self._worker(on_event)) + + @property + def is_running(self) -> bool: + """Return whether a sending task is currently running""" + return self._task is not None + + async def stop(self, join: bool = False) -> None: + """Stop a task was previously started with :meth:`run_in_background`. If the + task is not yet done it will be cancelled and awaited + + Args: + join: If ``True`` wait for all items to be processed before stopping the task + """ + if not self._task: + return + + if join: + await self._queue.join() + + if not self._task.done(): + self._task.cancel() + + with suppress(CancelledError): + await self._task + + self._task = None diff --git a/venv/lib/python3.11/site-packages/litestar/cli/__init__.py b/venv/lib/python3.11/site-packages/litestar/cli/__init__.py new file mode 100644 index 0000000..f6c366e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/__init__.py @@ -0,0 +1,29 @@ +"""Litestar CLI.""" + +from __future__ import annotations + +from importlib.util import find_spec + +# Ensure `rich_click` patching occurs before we do any imports from `click`. +if find_spec("rich_click") is not None: # pragma: no cover + import rich_click as click + from rich_click.cli import patch as rich_click_patch + + rich_click_patch() + click.rich_click.USE_RICH_MARKUP = True + click.rich_click.USE_MARKDOWN = False + click.rich_click.SHOW_ARGUMENTS = True + click.rich_click.GROUP_ARGUMENTS_OPTIONS = True + click.rich_click.SHOW_ARGUMENTS = True + click.rich_click.GROUP_ARGUMENTS_OPTIONS = True + click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic" + click.rich_click.ERRORS_SUGGESTION = "" + click.rich_click.ERRORS_EPILOGUE = "" + click.rich_click.MAX_WIDTH = 80 + click.rich_click.SHOW_METAVARS_COLUMN = True + click.rich_click.APPEND_METAVARS_HELP = True + + +from .main import litestar_group + +__all__ = ["litestar_group"] diff --git a/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..808a2fb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4baba11 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/main.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/main.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7ede592 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/__pycache__/main.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/_utils.py b/venv/lib/python3.11/site-packages/litestar/cli/_utils.py new file mode 100644 index 0000000..f36cd77 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/_utils.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +import contextlib +import importlib +import inspect +import os +import re +import sys +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from functools import wraps +from importlib.util import find_spec +from itertools import chain +from os import getenv +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Sequence, TypeVar, cast + +from click import ClickException, Command, Context, Group, pass_context +from rich import get_console +from rich.table import Table +from typing_extensions import ParamSpec, get_type_hints + +from litestar import Litestar, __version__ +from litestar.middleware import DefineMiddleware +from litestar.utils import get_name + +if sys.version_info >= (3, 10): + from importlib.metadata import entry_points +else: + from importlib_metadata import entry_points + + +if TYPE_CHECKING: + from litestar.openapi import OpenAPIConfig + from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from litestar.types import AnyCallable + + +UVICORN_INSTALLED = find_spec("uvicorn") is not None +JSBEAUTIFIER_INSTALLED = find_spec("jsbeautifier") is not None + + +__all__ = ( + "UVICORN_INSTALLED", + "JSBEAUTIFIER_INSTALLED", + "LoadedApp", + "LitestarCLIException", + "LitestarEnv", + "LitestarExtensionGroup", + "LitestarGroup", + "show_app_info", +) + + +P = ParamSpec("P") +T = TypeVar("T") + + +AUTODISCOVERY_FILE_NAMES = ["app", "application"] + +console = get_console() + + +class LitestarCLIException(ClickException): + """Base class for Litestar CLI exceptions.""" + + def __init__(self, message: str) -> None: + """Initialize exception and style error message.""" + super().__init__(message) + + +@dataclass +class LitestarEnv: + """Information about the current Litestar environment variables.""" + + app_path: str + debug: bool + app: Litestar + cwd: Path + host: str | None = None + port: int | None = None + fd: int | None = None + uds: str | None = None + reload: bool | None = None + reload_dirs: tuple[str, ...] | None = None + reload_include: tuple[str, ...] | None = None + reload_exclude: tuple[str, ...] | None = None + web_concurrency: int | None = None + is_app_factory: bool = False + certfile_path: str | None = None + keyfile_path: str | None = None + create_self_signed_cert: bool = False + + @classmethod + def from_env(cls, app_path: str | None, app_dir: Path | None = None) -> LitestarEnv: + """Load environment variables. + + If ``python-dotenv`` is installed, use it to populate environment first + """ + cwd = Path().cwd() if app_dir is None else app_dir + cwd_str_path = str(cwd) + if cwd_str_path not in sys.path: + sys.path.append(cwd_str_path) + + with contextlib.suppress(ImportError): + import dotenv + + dotenv.load_dotenv() + app_path = app_path or getenv("LITESTAR_APP") + if app_path and getenv("LITESTAR_APP") is None: + os.environ["LITESTAR_APP"] = app_path + if app_path: + console.print(f"Using Litestar app from env: [bright_blue]{app_path!r}") + loaded_app = _load_app_from_path(app_path) + else: + loaded_app = _autodiscover_app(cwd) + + port = getenv("LITESTAR_PORT") + web_concurrency = getenv("WEB_CONCURRENCY") + uds = getenv("LITESTAR_UNIX_DOMAIN_SOCKET") + fd = getenv("LITESTAR_FILE_DESCRIPTOR") + reload_dirs = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_DIRS", "").split(",") if s) or None + reload_include = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_INCLUDES", "").split(",") if s) or None + reload_exclude = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_EXCLUDES", "").split(",") if s) or None + + return cls( + app_path=loaded_app.app_path, + app=loaded_app.app, + debug=_bool_from_env("LITESTAR_DEBUG"), + host=getenv("LITESTAR_HOST"), + port=int(port) if port else None, + uds=uds, + fd=int(fd) if fd else None, + reload=_bool_from_env("LITESTAR_RELOAD"), + reload_dirs=reload_dirs, + reload_include=reload_include, + reload_exclude=reload_exclude, + web_concurrency=int(web_concurrency) if web_concurrency else None, + is_app_factory=loaded_app.is_factory, + cwd=cwd, + certfile_path=getenv("LITESTAR_SSL_CERT_PATH"), + keyfile_path=getenv("LITESTAR_SSL_KEY_PATH"), + create_self_signed_cert=_bool_from_env("LITESTAR_CREATE_SELF_SIGNED_CERT"), + ) + + +@dataclass +class LoadedApp: + """Information about a loaded Litestar app.""" + + app: Litestar + app_path: str + is_factory: bool + + +class LitestarGroup(Group): + """:class:`click.Group` subclass that automatically injects ``app`` and ``env` kwargs into commands that request it. + + Use this as the ``cls`` for :class:`click.Group` if you're extending the internal CLI with a group. For ``command``s + added directly to the root group this is not needed. + """ + + def __init__( + self, + name: str | None = None, + commands: dict[str, Command] | Sequence[Command] | None = None, + **attrs: Any, + ) -> None: + """Init ``LitestarGroup``""" + self.group_class = LitestarGroup + super().__init__(name=name, commands=commands, **attrs) + + def add_command(self, cmd: Command, name: str | None = None) -> None: + """Add command. + + If necessary, inject ``app`` and ``env`` kwargs + """ + if cmd.callback: + cmd.callback = _inject_args(cmd.callback) + super().add_command(cmd) + + def command(self, *args: Any, **kwargs: Any) -> Callable[[AnyCallable], Command] | Command: # type: ignore[override] + # For some reason, even when copying the overloads + signature from click 1:1, mypy goes haywire + """Add a function as a command. + + If necessary, inject ``app`` and ``env`` kwargs + """ + + def decorator(f: AnyCallable) -> Command: + f = _inject_args(f) + return cast("Command", Group.command(self, *args, **kwargs)(f)) + + return decorator + + +class LitestarExtensionGroup(LitestarGroup): + """``LitestarGroup`` subclass that will load Litestar-CLI extensions from the `litestar.commands` entry_point. + + This group class should not be used on any group besides the root ``litestar_group``. + """ + + def __init__( + self, + name: str | None = None, + commands: dict[str, Command] | Sequence[Command] | None = None, + **attrs: Any, + ) -> None: + """Init ``LitestarExtensionGroup``""" + super().__init__(name=name, commands=commands, **attrs) + self._prepare_done = False + + for entry_point in entry_points(group="litestar.commands"): + command = entry_point.load() + _wrap_commands([command]) + self.add_command(command, entry_point.name) + + def _prepare(self, ctx: Context) -> None: + if self._prepare_done: + return + + if isinstance(ctx.obj, LitestarEnv): + env: LitestarEnv | None = ctx.obj + else: + try: + env = ctx.obj = LitestarEnv.from_env(ctx.params.get("app_path"), ctx.params.get("app_dir")) + except LitestarCLIException: + env = None + + if env: + for plugin in env.app.plugins.cli: + plugin.on_cli_init(self) + + self._prepare_done = True + + def make_context( + self, + info_name: str | None, + args: list[str], + parent: Context | None = None, + **extra: Any, + ) -> Context: + ctx = super().make_context(info_name, args, parent, **extra) + self._prepare(ctx) + return ctx + + def list_commands(self, ctx: Context) -> list[str]: + self._prepare(ctx) + return super().list_commands(ctx) + + +def _inject_args(func: Callable[P, T]) -> Callable[P, T]: + """Inject the app instance into a ``Command``""" + params = inspect.signature(func).parameters + + @wraps(func) + def wrapped(ctx: Context, /, *args: P.args, **kwargs: P.kwargs) -> T: + needs_app = "app" in params + needs_env = "env" in params + if needs_env or needs_app: + # only resolve this if actually requested. Commands that don't need an env or app should be able to run + # without + if not isinstance(ctx.obj, LitestarEnv): + ctx.obj = ctx.obj() + env = ctx.ensure_object(LitestarEnv) + if needs_app: + kwargs["app"] = env.app + if needs_env: + kwargs["env"] = env + + if "ctx" in params: + kwargs["ctx"] = ctx + + return func(*args, **kwargs) + + return pass_context(wrapped) + + +def _wrap_commands(commands: Iterable[Command]) -> None: + for command in commands: + if isinstance(command, Group): + _wrap_commands(command.commands.values()) + elif command.callback: + command.callback = _inject_args(command.callback) + + +def _bool_from_env(key: str, default: bool = False) -> bool: + value = getenv(key) + if not value: + return default + value = value.lower() + return value in ("true", "1") + + +def _load_app_from_path(app_path: str) -> LoadedApp: + module_path, app_name = app_path.split(":") + module = importlib.import_module(module_path) + app = getattr(module, app_name) + is_factory = False + if not isinstance(app, Litestar) and callable(app): + app = app() + is_factory = True + return LoadedApp(app=app, app_path=app_path, is_factory=is_factory) + + +def _path_to_dotted_path(path: Path) -> str: + if path.stem == "__init__": + path = path.parent + return ".".join(path.with_suffix("").parts) + + +def _arbitrary_autodiscovery_paths(base_dir: Path) -> Generator[Path, None, None]: + yield from _autodiscovery_paths(base_dir, arbitrary=False) + for path in base_dir.iterdir(): + if path.name.startswith(".") or path.name.startswith("_"): + continue + if path.is_file() and path.suffix == ".py": + yield path + + +def _autodiscovery_paths(base_dir: Path, arbitrary: bool = True) -> Generator[Path, None, None]: + for name in AUTODISCOVERY_FILE_NAMES: + path = base_dir / name + + if path.exists() or path.with_suffix(".py").exists(): + yield path + if arbitrary and path.is_dir(): + yield from _arbitrary_autodiscovery_paths(path) + + +def _autodiscover_app(cwd: Path) -> LoadedApp: + for file_path in _autodiscovery_paths(cwd): + import_path = _path_to_dotted_path(file_path.relative_to(cwd)) + module = importlib.import_module(import_path) + + for attr, value in chain( + [("app", getattr(module, "app", None)), ("application", getattr(module, "application", None))], + module.__dict__.items(), + ): + if isinstance(value, Litestar): + app_string = f"{import_path}:{attr}" + os.environ["LITESTAR_APP"] = app_string + console.print(f"Using Litestar app from [bright_blue]{app_string}") + return LoadedApp(app=value, app_path=app_string, is_factory=False) + + if hasattr(module, "create_app"): + app_string = f"{import_path}:create_app" + os.environ["LITESTAR_APP"] = app_string + console.print(f"Using Litestar factory [bright_blue]{app_string}") + return LoadedApp(app=module.create_app(), app_path=app_string, is_factory=True) + + for attr, value in module.__dict__.items(): + if not callable(value): + continue + return_annotation = ( + get_type_hints(value, include_extras=True).get("return") if hasattr(value, "__annotations__") else None + ) + if not return_annotation: + continue + if return_annotation in ("Litestar", Litestar): + app_string = f"{import_path}:{attr}" + os.environ["LITESTAR_APP"] = app_string + console.print(f"Using Litestar factory [bright_blue]{app_string}") + return LoadedApp(app=value(), app_path=f"{app_string}", is_factory=True) + + raise LitestarCLIException("Could not find a Litestar app or factory") + + +def _format_is_enabled(value: Any) -> str: + """Return a coloured string `"Enabled" if ``value`` is truthy, else "Disabled".""" + return "[green]Enabled[/]" if value else "[red]Disabled[/]" + + +def show_app_info(app: Litestar) -> None: # pragma: no cover + """Display basic information about the application and its configuration.""" + + table = Table(show_header=False) + table.add_column("title", style="cyan") + table.add_column("value", style="bright_blue") + + table.add_row("Litestar version", f"{__version__.major}.{__version__.minor}.{__version__.patch}") + table.add_row("Debug mode", _format_is_enabled(app.debug)) + table.add_row("Python Debugger on exception", _format_is_enabled(app.pdb_on_exception)) + table.add_row("CORS", _format_is_enabled(app.cors_config)) + table.add_row("CSRF", _format_is_enabled(app.csrf_config)) + if app.allowed_hosts: + allowed_hosts = app.allowed_hosts + + table.add_row("Allowed hosts", ", ".join(allowed_hosts.allowed_hosts)) + + openapi_enabled = _format_is_enabled(app.openapi_config) + if app.openapi_config: + openapi_enabled += f" path=[yellow]{app.openapi_config.openapi_controller.path}" + table.add_row("OpenAPI", openapi_enabled) + + table.add_row("Compression", app.compression_config.backend if app.compression_config else "[red]Disabled") + + if app.template_engine: + table.add_row("Template engine", type(app.template_engine).__name__) + + if app.static_files_config: + static_files_configs = app.static_files_config + static_files_info = [ + f"path=[yellow]{static_files.path}[/] dirs=[yellow]{', '.join(map(str, static_files.directories))}[/] " + f"html_mode={_format_is_enabled(static_files.html_mode)}" + for static_files in static_files_configs + ] + table.add_row("Static files", "\n".join(static_files_info)) + + middlewares = [] + for middleware in app.middleware: + updated_middleware = middleware.middleware if isinstance(middleware, DefineMiddleware) else middleware + middlewares.append(get_name(updated_middleware)) + if middlewares: + table.add_row("Middlewares", ", ".join(middlewares)) + + console.print(table) + + +def validate_ssl_file_paths(certfile_arg: str | None, keyfile_arg: str | None) -> tuple[str, str] | tuple[None, None]: + """Validate whether given paths exist, are not directories and were both provided or none was. Return the resolved paths. + + Args: + certfile_arg: path argument for the certificate file + keyfile_arg: path argument for the key file + + Returns: + tuple of resolved paths converted to str or tuple of None's if no argument was provided + """ + if certfile_arg is None and keyfile_arg is None: + return (None, None) + + resolved_paths = [] + + for argname, arg in {"--ssl-certfile": certfile_arg, "--ssl-keyfile": keyfile_arg}.items(): + if arg is None: + raise LitestarCLIException(f"No value provided for {argname}") + path = Path(arg).resolve() + if path.is_dir(): + raise LitestarCLIException(f"Path provided for {argname} is a directory: {path}") + if not path.exists(): + raise LitestarCLIException(f"File provided for {argname} was not found: {path}") + resolved_paths.append(str(path)) + + return tuple(resolved_paths) # type: ignore[return-value] + + +def create_ssl_files( + certfile_arg: str | None, keyfile_arg: str | None, common_name: str = "localhost" +) -> tuple[str, str]: + """Validate whether both files were provided, are not directories, their parent dirs exist and either both files exists or none does. + If neither file exists, create a self-signed ssl certificate and a passwordless key at the location. + + Args: + certfile_arg: path argument for the certificate file + keyfile_arg: path argument for the key file + common_name: the CN to be used as cert issuer and subject + + Returns: + resolved paths of the found or generated files + """ + resolved_paths = [] + + for argname, arg in {"--ssl-certfile": certfile_arg, "--ssl-keyfile": keyfile_arg}.items(): + if arg is None: + raise LitestarCLIException(f"No value provided for {argname}") + path = Path(arg).resolve() + if path.is_dir(): + raise LitestarCLIException(f"Path provided for {argname} is a directory: {path}") + if not (parent_dir := path.parent).exists(): + raise LitestarCLIException( + f"Could not create file, parent directory for {argname} doesn't exist: {parent_dir}" + ) + resolved_paths.append(path) + + if (not resolved_paths[0].exists()) ^ (not resolved_paths[1].exists()): + raise LitestarCLIException( + "Both certificate and key file must exists or both must not exists when using --create-self-signed-cert" + ) + + if (not resolved_paths[0].exists()) and (not resolved_paths[1].exists()): + _generate_self_signed_cert(resolved_paths[0], resolved_paths[1], common_name) + + return (str(resolved_paths[0]), str(resolved_paths[1])) + + +def _generate_self_signed_cert(certfile_path: Path, keyfile_path: Path, common_name: str) -> None: + """Create a self-signed certificate using the cryptography modules at given paths""" + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + except ImportError as err: + raise LitestarCLIException( + "Cryptography must be installed when using --create-self-signed-cert\nPlease install the litestar[cryptography] extras" + ) from err + + subject = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Development Certificate"), + ] + ) + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(tz=timezone.utc)) + .not_valid_after(datetime.now(tz=timezone.utc) + timedelta(days=365)) + .add_extension(x509.SubjectAlternativeName([x509.DNSName(common_name)]), critical=False) + .add_extension(x509.ExtendedKeyUsage([x509.OID_SERVER_AUTH]), critical=False) + .sign(key, hashes.SHA256(), default_backend()) + ) + + with certfile_path.open("wb") as cert_file: + cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) + + with keyfile_path.open("wb") as key_file: + key_file.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + +def remove_routes_with_patterns( + routes: list[HTTPRoute | ASGIRoute | WebSocketRoute], patterns: tuple[str, ...] +) -> list[HTTPRoute | ASGIRoute | WebSocketRoute]: + regex_routes = [] + valid_patterns = [] + for pattern in patterns: + try: + check_pattern = re.compile(pattern) + valid_patterns.append(check_pattern) + except re.error as e: + console.print(f"Error: {e}. Invalid regex pattern supplied: '{pattern}'. Omitting from querying results.") + + for route in routes: + checked_pattern_route_matches = [] + for pattern_compile in valid_patterns: + matches = pattern_compile.match(route.path) + checked_pattern_route_matches.append(matches) + + if not any(checked_pattern_route_matches): + regex_routes.append(route) + + return regex_routes + + +def remove_default_schema_routes( + routes: list[HTTPRoute | ASGIRoute | WebSocketRoute], openapi_config: OpenAPIConfig +) -> list[HTTPRoute | ASGIRoute | WebSocketRoute]: + schema_path = openapi_config.openapi_controller.path + return remove_routes_with_patterns(routes, (schema_path,)) diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/__init__.py b/venv/lib/python3.11/site-packages/litestar/cli/commands/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..af651ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/core.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/core.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6f44d52 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/core.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/schema.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d6d7532 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/schema.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/sessions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/sessions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bebfa1a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/__pycache__/sessions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/core.py b/venv/lib/python3.11/site-packages/litestar/cli/commands/core.py new file mode 100644 index 0000000..5b55253 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/core.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import inspect +import multiprocessing +import os +import subprocess +import sys +from contextlib import AbstractContextManager, ExitStack, contextmanager +from typing import TYPE_CHECKING, Any, Iterator + +import click +from click import Context, command, option +from rich.tree import Tree + +from litestar.app import DEFAULT_OPENAPI_CONFIG +from litestar.cli._utils import ( + UVICORN_INSTALLED, + LitestarEnv, + console, + create_ssl_files, + remove_default_schema_routes, + remove_routes_with_patterns, + show_app_info, + validate_ssl_file_paths, +) +from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute +from litestar.utils.helpers import unwrap_partial + +__all__ = ("info_command", "routes_command", "run_command") + +if TYPE_CHECKING: + from litestar import Litestar + + +@contextmanager +def _server_lifespan(app: Litestar) -> Iterator[None]: + """Context manager handling the ASGI server lifespan. + + It will be entered just before the ASGI server is started through the CLI. + """ + with ExitStack() as exit_stack: + for manager in app._server_lifespan_managers: + if not isinstance(manager, AbstractContextManager): + manager = manager(app) # type: ignore[assignment] + exit_stack.enter_context(manager) # type: ignore[arg-type] + + yield + + +def _convert_uvicorn_args(args: dict[str, Any]) -> list[str]: + process_args = [] + for arg, value in args.items(): + if isinstance(value, bool): + if value: + process_args.append(f"--{arg}") + elif isinstance(value, tuple): + process_args.extend(f"--{arg}={item}" for item in value) + else: + process_args.append(f"--{arg}={value}") + + return process_args + + +def _run_uvicorn_in_subprocess( + *, + env: LitestarEnv, + host: str | None, + port: int | None, + workers: int | None, + reload: bool, + reload_dirs: tuple[str, ...] | None, + reload_include: tuple[str, ...] | None, + reload_exclude: tuple[str, ...] | None, + fd: int | None, + uds: str | None, + certfile_path: str | None, + keyfile_path: str | None, +) -> None: + process_args: dict[str, Any] = { + "reload": reload, + "host": host, + "port": port, + "workers": workers, + "factory": env.is_app_factory, + } + if fd is not None: + process_args["fd"] = fd + if uds is not None: + process_args["uds"] = uds + if reload_dirs: + process_args["reload-dir"] = reload_dirs + if reload_include: + process_args["reload-include"] = reload_include + if reload_exclude: + process_args["reload-exclude"] = reload_exclude + if certfile_path is not None: + process_args["ssl-certfile"] = certfile_path + if keyfile_path is not None: + process_args["ssl-keyfile"] = keyfile_path + subprocess.run( + [sys.executable, "-m", "uvicorn", env.app_path, *_convert_uvicorn_args(process_args)], # noqa: S603 + check=True, + ) + + +@command(name="version") +@option("-s", "--short", help="Exclude release level and serial information", is_flag=True, default=False) +def version_command(short: bool) -> None: + """Show the currently installed Litestar version.""" + from litestar import __version__ + + click.echo(__version__.formatted(short=short)) + + +@command(name="info") +def info_command(app: Litestar) -> None: + """Show information about the detected Litestar app.""" + + show_app_info(app) + + +@command(name="run") +@option("-r", "--reload", help="Reload server on changes", default=False, is_flag=True) +@option("-R", "--reload-dir", help="Directories to watch for file changes", multiple=True) +@option( + "-I", "--reload-include", help="Glob patterns for files to include when watching for file changes", multiple=True +) +@option( + "-E", "--reload-exclude", help="Glob patterns for files to exclude when watching for file changes", multiple=True +) +@option("-p", "--port", help="Serve under this port", type=int, default=8000, show_default=True) +@option( + "-W", + "--wc", + "--web-concurrency", + help="The number of HTTP workers to launch", + type=click.IntRange(min=1, max=multiprocessing.cpu_count() + 1), + show_default=True, + default=1, +) +@option("-H", "--host", help="Server under this host", default="127.0.0.1", show_default=True) +@option( + "-F", + "--fd", + "--file-descriptor", + help="Bind to a socket from this file descriptor.", + type=int, + default=None, + show_default=True, +) +@option("-U", "--uds", "--unix-domain-socket", help="Bind to a UNIX domain socket.", default=None, show_default=True) +@option("-d", "--debug", help="Run app in debug mode", is_flag=True) +@option("-P", "--pdb", "--use-pdb", help="Drop into PDB on an exception", is_flag=True) +@option("--ssl-certfile", help="Location of the SSL cert file", default=None) +@option("--ssl-keyfile", help="Location of the SSL key file", default=None) +@option( + "--create-self-signed-cert", + help="If certificate and key are not found at specified locations, create a self-signed certificate and a key", + is_flag=True, +) +def run_command( + reload: bool, + port: int, + wc: int, + host: str, + fd: int | None, + uds: str | None, + debug: bool, + reload_dir: tuple[str, ...], + reload_include: tuple[str, ...], + reload_exclude: tuple[str, ...], + pdb: bool, + ssl_certfile: str | None, + ssl_keyfile: str | None, + create_self_signed_cert: bool, + ctx: Context, +) -> None: + """Run a Litestar app; requires ``uvicorn``. + + The app can be either passed as a module path in the form of <module name>.<submodule>:<app instance or factory>, + set as an environment variable LITESTAR_APP with the same format or automatically discovered from one of these + canonical paths: app.py, asgi.py, application.py or app/__init__.py. When auto-discovering application factories, + functions with the name ``create_app`` are considered, or functions that are annotated as returning a ``Litestar`` + instance. + """ + + if debug: + os.environ["LITESTAR_DEBUG"] = "1" + + if pdb: + os.environ["LITESTAR_PDB"] = "1" + + if not UVICORN_INSTALLED: + console.print( + r"uvicorn is not installed. Please install the standard group, litestar\[standard], to use this command." + ) + sys.exit(1) + + if callable(ctx.obj): + ctx.obj = ctx.obj() + else: + if debug: + ctx.obj.app.debug = True + if pdb: + ctx.obj.app.pdb_on_exception = True + + env: LitestarEnv = ctx.obj + app = env.app + + reload_dirs = env.reload_dirs or reload_dir + reload_include = env.reload_include or reload_include + reload_exclude = env.reload_exclude or reload_exclude + + host = env.host or host + port = env.port if env.port is not None else port + fd = env.fd if env.fd is not None else fd + uds = env.uds or uds + reload = env.reload or reload or bool(reload_dirs) or bool(reload_include) or bool(reload_exclude) + workers = env.web_concurrency or wc + + ssl_certfile = ssl_certfile or env.certfile_path + ssl_keyfile = ssl_keyfile or env.keyfile_path + create_self_signed_cert = create_self_signed_cert or env.create_self_signed_cert + + certfile_path, keyfile_path = ( + create_ssl_files(ssl_certfile, ssl_keyfile, host) + if create_self_signed_cert + else validate_ssl_file_paths(ssl_certfile, ssl_keyfile) + ) + + console.rule("[yellow]Starting server process", align="left") + + show_app_info(app) + with _server_lifespan(app): + if workers == 1 and not reload: + import uvicorn + + # A guard statement at the beginning of this function prevents uvicorn from being unbound + # See "reportUnboundVariable in: + # https://microsoft.github.io/pyright/#/configuration?id=type-check-diagnostics-settings + uvicorn.run( # pyright: ignore + app=env.app_path, + host=host, + port=port, + fd=fd, + uds=uds, + factory=env.is_app_factory, + ssl_certfile=certfile_path, + ssl_keyfile=keyfile_path, + ) + else: + # invoke uvicorn in a subprocess to be able to use the --reload flag. see + # https://github.com/litestar-org/litestar/issues/1191 and https://github.com/encode/uvicorn/issues/1045 + if sys.gettrace() is not None: + console.print( + "[yellow]Debugger detected. Breakpoints might not work correctly inside route handlers when running" + " with the --reload or --workers options[/]" + ) + + _run_uvicorn_in_subprocess( + env=env, + host=host, + port=port, + workers=workers, + reload=reload, + reload_dirs=reload_dirs, + reload_include=reload_include, + reload_exclude=reload_exclude, + fd=fd, + uds=uds, + certfile_path=certfile_path, + keyfile_path=keyfile_path, + ) + + +@command(name="routes") +@option("--schema", help="Include schema routes", is_flag=True, default=False) +@option("--exclude", help="routes to exclude via regex", type=str, is_flag=False, multiple=True) +def routes_command(app: Litestar, exclude: tuple[str, ...], schema: bool) -> None: # pragma: no cover + """Display information about the application's routes.""" + + sorted_routes = sorted(app.routes, key=lambda r: r.path) + if not schema: + openapi_config = app.openapi_config or DEFAULT_OPENAPI_CONFIG + sorted_routes = remove_default_schema_routes(sorted_routes, openapi_config) + if exclude is not None: + sorted_routes = remove_routes_with_patterns(sorted_routes, exclude) + + console.print(_RouteTree(sorted_routes)) + + +class _RouteTree(Tree): + def __init__(self, routes: list[HTTPRoute | ASGIRoute | WebSocketRoute]) -> None: + super().__init__("", hide_root=True) + self._routes = routes + self._build() + + def _build(self) -> None: + for route in self._routes: + if isinstance(route, HTTPRoute): + self._handle_http_route(route) + elif isinstance(route, WebSocketRoute): + self._handle_websocket_route(route) + else: + self._handle_asgi_route(route) + + def _handle_asgi_like_route(self, route: ASGIRoute | WebSocketRoute, route_type: str) -> None: + branch = self.add(f"[green]{route.path}[/green] ({route_type})") + branch.add(f"[blue]{route.route_handler.name or route.route_handler.handler_name}[/blue]") + + def _handle_asgi_route(self, route: ASGIRoute) -> None: + self._handle_asgi_like_route(route, route_type="ASGI") + + def _handle_websocket_route(self, route: WebSocketRoute) -> None: + self._handle_asgi_like_route(route, route_type="WS") + + def _handle_http_route(self, route: HTTPRoute) -> None: + branch = self.add(f"[green]{route.path}[/green] (HTTP)") + for handler in route.route_handlers: + handler_info = [ + f"[blue]{handler.name or handler.handler_name}[/blue]", + ] + + if inspect.iscoroutinefunction(unwrap_partial(handler.fn)): + handler_info.append("[magenta]async[/magenta]") + else: + handler_info.append("[yellow]sync[/yellow]") + + handler_info.append(f'[cyan]{", ".join(sorted(handler.http_methods))}[/cyan]') + + if len(handler.paths) > 1: + for path in handler.paths: + branch.add(" ".join([f"[green]{path}[green]", *handler_info])) + else: + branch.add(" ".join(handler_info)) diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/schema.py b/venv/lib/python3.11/site-packages/litestar/cli/commands/schema.py new file mode 100644 index 0000000..a323bc7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/schema.py @@ -0,0 +1,82 @@ +from pathlib import Path + +import msgspec +from click import Path as ClickPath +from click import group, option +from yaml import dump as dump_yaml + +from litestar import Litestar +from litestar._openapi.typescript_converter.converter import ( + convert_openapi_to_typescript, +) +from litestar.cli._utils import JSBEAUTIFIER_INSTALLED, LitestarCLIException, LitestarGroup +from litestar.serialization import encode_json, get_serializer + +__all__ = ("generate_openapi_schema", "generate_typescript_specs", "schema_group") + + +@group(cls=LitestarGroup, name="schema") +def schema_group() -> None: + """Manage server-side OpenAPI schemas.""" + + +def _generate_openapi_schema(app: Litestar, output: Path) -> None: + """Generate an OpenAPI Schema.""" + serializer = get_serializer(app.type_encoders) + if output.suffix in (".yml", ".yaml"): + content = dump_yaml( + msgspec.to_builtins(app.openapi_schema.to_schema(), enc_hook=serializer), + default_flow_style=False, + encoding="utf-8", + ) + else: + content = msgspec.json.format( + encode_json(app.openapi_schema.to_schema(), serializer=serializer), + indent=4, + ) + + try: + output.write_bytes(content) + except OSError as e: # pragma: no cover + raise LitestarCLIException(f"failed to write schema to path {output}") from e + + +@schema_group.command("openapi") # type: ignore[misc] +@option( + "--output", + help="output file path", + type=ClickPath(dir_okay=False, path_type=Path), + default=Path("openapi_schema.json"), + show_default=True, +) +def generate_openapi_schema(app: Litestar, output: Path) -> None: + """Generate an OpenAPI Schema.""" + _generate_openapi_schema(app, output) + + +@schema_group.command("typescript") # type: ignore[misc] +@option( + "--output", + help="output file path", + type=ClickPath(dir_okay=False, path_type=Path), + default=Path("api-specs.ts"), + show_default=True, +) +@option("--namespace", help="namespace to use for the typescript specs", type=str, default="API") +def generate_typescript_specs(app: Litestar, output: Path, namespace: str) -> None: + """Generate TypeScript specs from the OpenAPI schema.""" + if JSBEAUTIFIER_INSTALLED: # pragma: no cover + from jsbeautifier import Beautifier + + beautifier = Beautifier() + else: + beautifier = None + try: + specs = convert_openapi_to_typescript(app.openapi_schema, namespace) + # beautifier will be defined if JSBEAUTIFIER_INSTALLED is True + specs_output = ( + beautifier.beautify(specs.write()) if JSBEAUTIFIER_INSTALLED and beautifier else specs.write() # pyright: ignore + ) + output.write_text(specs_output) + except OSError as e: # pragma: no cover + raise LitestarCLIException(f"failed to write schema to path {output}") from e diff --git a/venv/lib/python3.11/site-packages/litestar/cli/commands/sessions.py b/venv/lib/python3.11/site-packages/litestar/cli/commands/sessions.py new file mode 100644 index 0000000..f048dd1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/commands/sessions.py @@ -0,0 +1,58 @@ +from click import argument, group +from rich.prompt import Confirm + +from litestar import Litestar +from litestar.cli._utils import LitestarCLIException, LitestarGroup, console +from litestar.middleware import DefineMiddleware +from litestar.middleware.session import SessionMiddleware +from litestar.middleware.session.server_side import ServerSideSessionBackend +from litestar.utils import is_class_and_subclass + +__all__ = ("clear_sessions_command", "delete_session_command", "get_session_backend", "sessions_group") + + +def get_session_backend(app: Litestar) -> ServerSideSessionBackend: + """Get the session backend used by a ``Litestar`` app.""" + for middleware in app.middleware: + if isinstance(middleware, DefineMiddleware): + if not is_class_and_subclass(middleware.middleware, SessionMiddleware): + continue + backend = middleware.kwargs["backend"] + if not isinstance(backend, ServerSideSessionBackend): + raise LitestarCLIException("Only server-side backends are supported") + return backend + raise LitestarCLIException("Session middleware not installed") + + +@group(cls=LitestarGroup, name="sessions") +def sessions_group() -> None: + """Manage server-side sessions.""" + + +@sessions_group.command("delete") # type: ignore[misc] +@argument("session-id") +def delete_session_command(session_id: str, app: Litestar) -> None: + """Delete a specific session.""" + import anyio + + backend = get_session_backend(app) + store = backend.config.get_store_from_app(app) + + if Confirm.ask(f"Delete session {session_id!r}?"): + anyio.run(backend.delete, session_id, store) + console.print(f"[green]Deleted session {session_id!r}") + + +@sessions_group.command("clear") # type: ignore[misc] +def clear_sessions_command(app: Litestar) -> None: + """Delete all sessions.""" + import anyio + + backend = get_session_backend(app) + store = backend.config.get_store_from_app(app) + if not hasattr(store, "delete_all"): + raise LitestarCLIException(f"{type(store)} does not support clearing all sessions") + + if Confirm.ask("[red]Delete all sessions?"): + anyio.run(store.delete_all) # pyright: ignore + console.print("[green]All active sessions deleted") diff --git a/venv/lib/python3.11/site-packages/litestar/cli/main.py b/venv/lib/python3.11/site-packages/litestar/cli/main.py new file mode 100644 index 0000000..32505f6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/cli/main.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from pathlib import Path + +from click import Context, group, option, pass_context +from click import Path as ClickPath + +from ._utils import LitestarEnv, LitestarExtensionGroup +from .commands import core, schema, sessions + +__all__ = ("litestar_group",) + + +@group(cls=LitestarExtensionGroup, context_settings={"help_option_names": ["-h", "--help"]}) +@option("--app", "app_path", help="Module path to a Litestar application") +@option( + "--app-dir", + help="Look for APP in the specified directory, by adding this to the PYTHONPATH. Defaults to the current working directory.", + default=None, + type=ClickPath(dir_okay=True, file_okay=False, path_type=Path), + show_default=False, +) +@pass_context +def litestar_group(ctx: Context, app_path: str | None, app_dir: Path | None = None) -> None: + """Litestar CLI.""" + if ctx.obj is None: # env has not been loaded yet, so we can lazy load it + ctx.obj = lambda: LitestarEnv.from_env(app_path, app_dir=app_dir) + + +# add sub commands here + +litestar_group.add_command(core.info_command) +litestar_group.add_command(core.run_command) +litestar_group.add_command(core.routes_command) +litestar_group.add_command(core.version_command) +litestar_group.add_command(sessions.sessions_group) +litestar_group.add_command(schema.schema_group) diff --git a/venv/lib/python3.11/site-packages/litestar/concurrency.py b/venv/lib/python3.11/site-packages/litestar/concurrency.py new file mode 100644 index 0000000..90eadbf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/concurrency.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import asyncio +import contextvars +from functools import partial +from typing import TYPE_CHECKING, Callable, TypeVar + +import sniffio +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from concurrent.futures import ThreadPoolExecutor + + import trio + + +T = TypeVar("T") +P = ParamSpec("P") + + +__all__ = ( + "sync_to_thread", + "set_asyncio_executor", + "get_asyncio_executor", + "set_trio_capacity_limiter", + "get_trio_capacity_limiter", +) + + +class _State: + EXECUTOR: ThreadPoolExecutor | None = None + LIMITER: trio.CapacityLimiter | None = None + + +async def _run_sync_asyncio(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + ctx = contextvars.copy_context() + bound_fn = partial(ctx.run, fn, *args, **kwargs) + return await asyncio.get_running_loop().run_in_executor(get_asyncio_executor(), bound_fn) # pyright: ignore + + +async def _run_sync_trio(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + import trio + + return await trio.to_thread.run_sync(partial(fn, *args, **kwargs), limiter=get_trio_capacity_limiter()) + + +async def sync_to_thread(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + """Run the synchronous callable ``fn`` asynchronously in a worker thread. + + When called from asyncio, uses :meth:`asyncio.loop.run_in_executor` to + run the callable. No executor is specified by default so the current loop's executor + is used. A specific executor can be set using + :func:`~litestar.concurrency.set_asyncio_executor`. This does not affect the loop's + default executor. + + When called from trio, uses :func:`trio.to_thread.run_sync` to run the callable. No + capacity limiter is specified by default, but one can be set using + :func:`~litestar.concurrency.set_trio_capacity_limiter`. This does not affect trio's + default capacity limiter. + """ + if (library := sniffio.current_async_library()) == "asyncio": + return await _run_sync_asyncio(fn, *args, **kwargs) + + if library == "trio": + return await _run_sync_trio(fn, *args, **kwargs) + + raise RuntimeError("Unsupported async library or not in async context") + + +def set_asyncio_executor(executor: ThreadPoolExecutor | None) -> None: + """Set the executor in which synchronous callables will be run within an asyncio + context + """ + try: + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + raise RuntimeError("Cannot set executor from running loop") + + _State.EXECUTOR = executor + + +def get_asyncio_executor() -> ThreadPoolExecutor | None: + """Get the executor in which synchronous callables will be run within an asyncio + context + """ + return _State.EXECUTOR + + +def set_trio_capacity_limiter(limiter: trio.CapacityLimiter | None) -> None: + """Set the capacity limiter used when running synchronous callable within a trio + context + """ + try: + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + raise RuntimeError("Cannot set limiter while in async context") + + _State.LIMITER = limiter + + +def get_trio_capacity_limiter() -> trio.CapacityLimiter | None: + """Get the capacity limiter used when running synchronous callable within a trio + context + """ + return _State.LIMITER diff --git a/venv/lib/python3.11/site-packages/litestar/config/__init__.py b/venv/lib/python3.11/site-packages/litestar/config/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c6ef8a3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/allowed_hosts.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/allowed_hosts.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f340305 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/allowed_hosts.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/app.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/app.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a84a2f7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/app.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/compression.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/compression.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b6d9382 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/compression.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/cors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/cors.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9965625 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/cors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/csrf.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/csrf.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5056590 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/csrf.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/__pycache__/response_cache.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/response_cache.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0ca0cf9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/__pycache__/response_cache.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/config/allowed_hosts.py b/venv/lib/python3.11/site-packages/litestar/config/allowed_hosts.py new file mode 100644 index 0000000..4c8e6ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/allowed_hosts.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.exceptions import ImproperlyConfiguredException + +__all__ = ("AllowedHostsConfig",) + + +if TYPE_CHECKING: + from litestar.types import Scopes + + +@dataclass +class AllowedHostsConfig: + """Configuration for allowed hosts protection. + + To enable allowed hosts protection, pass an instance of this class to the :class:`Litestar <litestar.app.Litestar>` + constructor using the ``allowed_hosts`` key. + """ + + allowed_hosts: list[str] = field(default_factory=lambda: ["*"]) + """A list of trusted hosts. + + Use ``*.`` to allow all hosts, or prefix domains with ``*.`` to allow all sub domains. + """ + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the Allowed Hosts middleware.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable hosts check for a particular route.""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the middleware, if None both ``http`` and ``websocket`` will be processed.""" + www_redirect: bool = field(default=True) + """A boolean dictating whether to redirect requests that start with ``www.`` and otherwise match a trusted host.""" + + def __post_init__(self) -> None: + """Ensure that the trusted hosts have correct domain wildcards.""" + for host in self.allowed_hosts: + if host != "*" and "*" in host and not host.startswith("*."): + raise ImproperlyConfiguredException( + "domain wildcards can only appear in the beginning of the domain, e.g. ``*.example.com``" + ) diff --git a/venv/lib/python3.11/site-packages/litestar/config/app.py b/venv/lib/python3.11/site-packages/litestar/config/app.py new file mode 100644 index 0000000..0acefb1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/app.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable + +from litestar.config.allowed_hosts import AllowedHostsConfig +from litestar.config.response_cache import ResponseCacheConfig +from litestar.datastructures import State +from litestar.events.emitter import SimpleEventEmitter +from litestar.types.empty import Empty + +if TYPE_CHECKING: + from contextlib import AbstractAsyncContextManager + + from litestar import Litestar, Response + from litestar.config.compression import CompressionConfig + from litestar.config.cors import CORSConfig + from litestar.config.csrf import CSRFConfig + from litestar.connection import Request, WebSocket + from litestar.datastructures import CacheControlHeader, ETag + from litestar.di import Provide + from litestar.dto import AbstractDTO + from litestar.events.emitter import BaseEventEmitterBackend + from litestar.events.listener import EventListener + from litestar.logging.config import BaseLoggingConfig + from litestar.openapi.config import OpenAPIConfig + from litestar.openapi.spec import SecurityRequirement + from litestar.plugins import PluginProtocol + from litestar.static_files.config import StaticFilesConfig + from litestar.stores.base import Store + from litestar.stores.registry import StoreRegistry + from litestar.types import ( + AfterExceptionHookHandler, + AfterRequestHookHandler, + AfterResponseHookHandler, + AnyCallable, + BeforeMessageSendHookHandler, + BeforeRequestHookHandler, + ControllerRouterHandler, + ExceptionHandlersMap, + Guard, + Middleware, + ParametersMap, + ResponseCookies, + ResponseHeaders, + TypeEncodersMap, + ) + from litestar.types.callable_types import LifespanHook + from litestar.types.composite_types import TypeDecodersSequence + from litestar.types.empty import EmptyType + from litestar.types.internal_types import TemplateConfigType + + +__all__ = ( + "AppConfig", + "ExperimentalFeatures", +) + + +@dataclass +class AppConfig: + """The parameters provided to the ``Litestar`` app are used to instantiate an instance, and then the instance is + passed to any callbacks registered to ``on_app_init`` in the order they are provided. + + The final attribute values are used to instantiate the application object. + """ + + after_exception: list[AfterExceptionHookHandler] = field(default_factory=list) + """An application level :class:`exception hook handler <.types.AfterExceptionHookHandler>` or list thereof. + + This hook is called after an exception occurs. In difference to exception handlers, it is not meant to return a + response - only to process the exception (e.g. log it, send it to Sentry etc.). + """ + after_request: AfterRequestHookHandler | None = field(default=None) + """A sync or async function executed after the route handler function returned and the response object has been + resolved. + + Receives the response object which may be any subclass of :class:`Response <.response.Response>`. + """ + after_response: AfterResponseHookHandler | None = field(default=None) + """A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + """ + allowed_hosts: list[str] | AllowedHostsConfig | None = field(default=None) + """If set enables the builtin allowed hosts middleware.""" + before_request: BeforeRequestHookHandler | None = field(default=None) + """A sync or async function called immediately before calling the route handler. Receives the + :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + """ + before_send: list[BeforeMessageSendHookHandler] = field(default_factory=list) + """An application level :class:`before send hook handler <.types.BeforeMessageSendHookHandler>` or list thereof. + + This hook is called when the ASGI send function is called. + """ + cache_control: CacheControlHeader | None = field(default=None) + """A ``cache-control`` header of type :class:`CacheControlHeader <.datastructures.CacheControlHeader>` to add to + route handlers of this app. + + Can be overridden by route handlers. + """ + compression_config: CompressionConfig | None = field(default=None) + """Configures compression behaviour of the application, this enabled a builtin or user defined Compression + middleware. + """ + cors_config: CORSConfig | None = field(default=None) + """If set this enables the builtin CORS middleware.""" + csrf_config: CSRFConfig | None = field(default=None) + """If set this enables the builtin CSRF middleware.""" + debug: bool = field(default=False) + """If ``True``, app errors rendered as HTML with a stack trace.""" + dependencies: dict[str, Provide | AnyCallable] = field(default_factory=dict) + """A string keyed dictionary of dependency :class:`Provider <.di.Provide>` instances.""" + dto: type[AbstractDTO] | None | EmptyType = field(default=Empty) + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and validation of request data.""" + etag: ETag | None = field(default=None) + """An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + + Can be overridden by route handlers. + """ + event_emitter_backend: type[BaseEventEmitterBackend] = field(default=SimpleEventEmitter) + """A subclass of :class:`BaseEventEmitterBackend <.events.emitter.BaseEventEmitterBackend>`.""" + exception_handlers: ExceptionHandlersMap = field(default_factory=dict) + """A dictionary that maps handler functions to status codes and/or exception types.""" + guards: list[Guard] = field(default_factory=list) + """A list of :class:`Guard <.types.Guard>` callables.""" + include_in_schema: bool | EmptyType = field(default=Empty) + """A boolean flag dictating whether the route handler should be documented in the OpenAPI schema""" + lifespan: list[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] = field( + default_factory=list + ) + """A list of callables returning async context managers, wrapping the lifespan of the ASGI application""" + listeners: list[EventListener] = field(default_factory=list) + """A list of :class:`EventListener <.events.listener.EventListener>`.""" + logging_config: BaseLoggingConfig | None = field(default=None) + """An instance of :class:`BaseLoggingConfig <.logging.config.BaseLoggingConfig>` subclass.""" + middleware: list[Middleware] = field(default_factory=list) + """A list of :class:`Middleware <.types.Middleware>`.""" + on_shutdown: list[LifespanHook] = field(default_factory=list) + """A list of :class:`LifespanHook <.types.LifespanHook>` called during application shutdown.""" + on_startup: list[LifespanHook] = field(default_factory=list) + """A list of :class:`LifespanHook <.types.LifespanHook>` called during application startup.""" + openapi_config: OpenAPIConfig | None = field(default=None) + """Defaults to :data:`DEFAULT_OPENAPI_CONFIG <litestar.app.DEFAULT_OPENAPI_CONFIG>`""" + opt: dict[str, Any] = field(default_factory=dict) + """A string keyed dictionary of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <litestar.types.Scope>`. + + Can be overridden by routers and router handlers. + """ + parameters: ParametersMap = field(default_factory=dict) + """A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application paths.""" + pdb_on_exception: bool = field(default=False) + """Drop into the PDB on an exception""" + plugins: list[PluginProtocol] = field(default_factory=list) + """List of :class:`SerializationPluginProtocol <.plugins.SerializationPluginProtocol>`.""" + request_class: type[Request] | None = field(default=None) + """An optional subclass of :class:`Request <.connection.Request>` to use for http connections.""" + response_class: type[Response] | None = field(default=None) + """A custom subclass of :class:`Response <.response.Response>` to be used as the app's default response.""" + response_cookies: ResponseCookies = field(default_factory=list) + """A list of :class:`Cookie <.datastructures.Cookie>`.""" + response_headers: ResponseHeaders = field(default_factory=list) + """A string keyed dictionary mapping :class:`ResponseHeader <.datastructures.ResponseHeader>`.""" + response_cache_config: ResponseCacheConfig = field(default_factory=ResponseCacheConfig) + """Configures caching behavior of the application.""" + return_dto: type[AbstractDTO] | None | EmptyType = field(default=Empty) + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response + data. + """ + route_handlers: list[ControllerRouterHandler] = field(default_factory=list) + """A required list of route handlers, which can include instances of :class:`Router <.router.Router>`, + subclasses of :class:`Controller <.controller.Controller>` or any function decorated by the route handler + decorators. + """ + security: list[SecurityRequirement] = field(default_factory=list) + """A list of dictionaries that will be added to the schema of all route handlers in the application. See + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for details. + """ + signature_namespace: dict[str, Any] = field(default_factory=dict) + """A mapping of names to types for use in forward reference resolution during signature modeling.""" + signature_types: list[Any] = field(default_factory=list) + """A sequence of types for use in forward reference resolution during signature modeling. + + These types will be added to the signature namespace using their ``__name__`` attribute. + """ + state: State = field(default_factory=State) + """A :class:`State` <.datastructures.State>` instance holding application state.""" + static_files_config: list[StaticFilesConfig] = field(default_factory=list) + """An instance or list of :class:`StaticFilesConfig <.static_files.StaticFilesConfig>`.""" + stores: StoreRegistry | dict[str, Store] | None = None + """Central registry of :class:`Store <.stores.base.Store>` to be made available and be used throughout the + application. Can be either a dictionary mapping strings to :class:`Store <.stores.base.Store>` instances, or an + instance of :class:`StoreRegistry <.stores.registry.StoreRegistry>`. + """ + tags: list[str] = field(default_factory=list) + """A list of string tags that will be appended to the schema of all route handlers under the application.""" + template_config: TemplateConfigType | None = field(default=None) + """An instance of :class:`TemplateConfig <.template.TemplateConfig>`.""" + type_decoders: TypeDecodersSequence | None = field(default=None) + """A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.""" + type_encoders: TypeEncodersMap | None = field(default=None) + """A mapping of types to callables that transform them into types supported for serialization.""" + websocket_class: type[WebSocket] | None = field(default=None) + """An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket connections.""" + multipart_form_part_limit: int = field(default=1000) + """The maximal number of allowed parts in a multipart/formdata request. This limit is intended to protect from + DoS attacks.""" + experimental_features: list[ExperimentalFeatures] | None = None + + def __post_init__(self) -> None: + """Normalize the allowed hosts to be a config or None. + + Returns: + Optional config. + """ + if self.allowed_hosts and isinstance(self.allowed_hosts, list): + self.allowed_hosts = AllowedHostsConfig(allowed_hosts=self.allowed_hosts) + + +class ExperimentalFeatures(str, enum.Enum): + DTO_CODEGEN = "DTO_CODEGEN" diff --git a/venv/lib/python3.11/site-packages/litestar/config/compression.py b/venv/lib/python3.11/site-packages/litestar/config/compression.py new file mode 100644 index 0000000..c339329 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/compression.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.middleware.compression import CompressionMiddleware +from litestar.middleware.compression.gzip_facade import GzipCompression + +if TYPE_CHECKING: + from litestar.middleware.compression.facade import CompressionFacade + +__all__ = ("CompressionConfig",) + + +@dataclass +class CompressionConfig: + """Configuration for response compression. + + To enable response compression, pass an instance of this class to the :class:`Litestar <.app.Litestar>` constructor + using the ``compression_config`` key. + """ + + backend: Literal["gzip", "brotli"] | str + """The backend to use. + + If the value given is `gzip` or `brotli`, then the builtin gzip and brotli compression is used. + """ + minimum_size: int = field(default=500) + """Minimum response size (bytes) to enable compression, affects all backends.""" + gzip_compress_level: int = field(default=9) + """Range ``[0-9]``, see :doc:`python:library/gzip`.""" + brotli_quality: int = field(default=5) + """Range ``[0-11]``, Controls the compression-speed vs compression-density tradeoff. + + The higher the quality, the slower the compression. + """ + brotli_mode: Literal["generic", "text", "font"] = "text" + """``MODE_GENERIC``, ``MODE_TEXT`` (for UTF-8 format text input, default) or ``MODE_FONT`` (for WOFF 2.0).""" + brotli_lgwin: int = field(default=22) + """Base 2 logarithm of size. + + Range is 10 to 24. Defaults to 22. + """ + brotli_lgblock: Literal[0, 16, 17, 18, 19, 20, 21, 22, 23, 24] = 0 + """Base 2 logarithm of the maximum input block size. + + Range is ``16`` to ``24``. If set to ``0``, the value will be set based on the quality. Defaults to ``0``. + """ + brotli_gzip_fallback: bool = True + """Use GZIP if Brotli is not supported.""" + middleware_class: type[CompressionMiddleware] = CompressionMiddleware + """Middleware class to use, should be a subclass of :class:`CompressionMiddleware`.""" + exclude: str | list[str] | None = None + """A pattern or list of patterns to skip in the compression middleware.""" + exclude_opt_key: str | None = None + """An identifier to use on routes to disable compression for a particular route.""" + compression_facade: type[CompressionFacade] = GzipCompression + """The compression facade to use for the actual compression.""" + backend_config: Any = None + """Configuration specific to the backend.""" + gzip_fallback: bool = True + """Use GZIP as a fallback if the provided backend is not supported by the client.""" + + def __post_init__(self) -> None: + if self.minimum_size <= 0: + raise ImproperlyConfiguredException("minimum_size must be greater than 0") + + if self.backend == "gzip": + if self.gzip_compress_level < 0 or self.gzip_compress_level > 9: + raise ImproperlyConfiguredException("gzip_compress_level must be a value between 0 and 9") + elif self.backend == "brotli": + # Brotli is not guaranteed to be installed. + from litestar.middleware.compression.brotli_facade import BrotliCompression + + if self.brotli_quality < 0 or self.brotli_quality > 11: + raise ImproperlyConfiguredException("brotli_quality must be a value between 0 and 11") + + if self.brotli_lgwin < 10 or self.brotli_lgwin > 24: + raise ImproperlyConfiguredException("brotli_lgwin must be a value between 10 and 24") + + self.gzip_fallback = self.brotli_gzip_fallback + self.compression_facade = BrotliCompression diff --git a/venv/lib/python3.11/site-packages/litestar/config/cors.py b/venv/lib/python3.11/site-packages/litestar/config/cors.py new file mode 100644 index 0000000..d3e2ccf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/cors.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from functools import cached_property +from typing import TYPE_CHECKING, Literal, Pattern + +from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS + +__all__ = ("CORSConfig",) + + +if TYPE_CHECKING: + from litestar.types import Method + + +@dataclass +class CORSConfig: + """Configuration for CORS (Cross-Origin Resource Sharing). + + To enable CORS, pass an instance of this class to the :class:`Litestar <litestar.app.Litestar>` constructor using the + 'cors_config' key. + """ + + allow_origins: list[str] = field(default_factory=lambda: ["*"]) + """List of origins that are allowed. + + Can use '*' in any component of the path, e.g. 'domain.*'. Sets the 'Access-Control-Allow-Origin' header. + """ + allow_methods: list[Literal["*"] | Method] = field(default_factory=lambda: ["*"]) + """List of allowed HTTP methods. + + Sets the 'Access-Control-Allow-Methods' header. + """ + allow_headers: list[str] = field(default_factory=lambda: ["*"]) + """List of allowed headers. + + Sets the 'Access-Control-Allow-Headers' header. + """ + allow_credentials: bool = field(default=False) + """Boolean dictating whether or not to set the 'Access-Control-Allow-Credentials' header.""" + allow_origin_regex: str | None = field(default=None) + """Regex to match origins against.""" + expose_headers: list[str] = field(default_factory=list) + """List of headers that are exposed via the 'Access-Control-Expose-Headers' header.""" + max_age: int = field(default=600) + """Response caching TTL in seconds, defaults to 600. + + Sets the 'Access-Control-Max-Age' header. + """ + + def __post_init__(self) -> None: + self.allow_headers = [v.lower() for v in self.allow_headers] + + @cached_property + def allowed_origins_regex(self) -> Pattern[str]: + """Get or create a compiled regex for allowed origins. + + Returns: + A compiled regex of the allowed path. + """ + origins = self.allow_origins + if self.allow_origin_regex: + origins.append(self.allow_origin_regex) + return re.compile("|".join([origin.replace("*.", r".*\.") for origin in origins])) + + @cached_property + def is_allow_all_origins(self) -> bool: + """Get a cached boolean flag dictating whether all origins are allowed. + + Returns: + Boolean dictating whether all origins are allowed. + """ + return "*" in self.allow_origins + + @cached_property + def is_allow_all_methods(self) -> bool: + """Get a cached boolean flag dictating whether all methods are allowed. + + Returns: + Boolean dictating whether all methods are allowed. + """ + return "*" in self.allow_methods + + @cached_property + def is_allow_all_headers(self) -> bool: + """Get a cached boolean flag dictating whether all headers are allowed. + + Returns: + Boolean dictating whether all headers are allowed. + """ + return "*" in self.allow_headers + + @cached_property + def preflight_headers(self) -> dict[str, str]: + """Get cached pre-flight headers. + + Returns: + A dictionary of headers to set on the response object. + """ + headers: dict[str, str] = {"Access-Control-Max-Age": str(self.max_age)} + if self.is_allow_all_origins: + headers["Access-Control-Allow-Origin"] = "*" + else: + headers["Vary"] = "Origin" + if self.allow_credentials: + headers["Access-Control-Allow-Credentials"] = str(self.allow_credentials).lower() + if not self.is_allow_all_headers: + headers["Access-Control-Allow-Headers"] = ", ".join( + sorted(set(self.allow_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore + ) + if self.allow_methods: + headers["Access-Control-Allow-Methods"] = ", ".join( + sorted( + {"DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"} + if self.is_allow_all_methods + else set(self.allow_methods) + ) + ) + return headers + + @cached_property + def simple_headers(self) -> dict[str, str]: + """Get cached simple headers. + + Returns: + A dictionary of headers to set on the response object. + """ + simple_headers = {} + if self.is_allow_all_origins: + simple_headers["Access-Control-Allow-Origin"] = "*" + if self.allow_credentials: + simple_headers["Access-Control-Allow-Credentials"] = "true" + if self.expose_headers: + simple_headers["Access-Control-Expose-Headers"] = ", ".join(sorted(set(self.expose_headers))) + return simple_headers + + def is_origin_allowed(self, origin: str) -> bool: + """Check whether a given origin is allowed. + + Args: + origin: An origin header value. + + Returns: + Boolean determining whether an origin is allowed. + """ + return bool(self.is_allow_all_origins or self.allowed_origins_regex.fullmatch(origin)) diff --git a/venv/lib/python3.11/site-packages/litestar/config/csrf.py b/venv/lib/python3.11/site-packages/litestar/config/csrf.py new file mode 100644 index 0000000..5094a5b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/csrf.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal + +__all__ = ("CSRFConfig",) + + +if TYPE_CHECKING: + from litestar.types import Method + + +@dataclass +class CSRFConfig: + """Configuration for CSRF (Cross Site Request Forgery) protection. + + To enable CSRF protection, pass an instance of this class to the :class:`Litestar <litestar.app.Litestar>` constructor using + the 'csrf_config' key. + """ + + secret: str + """A string that is used to create an HMAC to sign the CSRF token.""" + cookie_name: str = field(default="csrftoken") + """The CSRF cookie name.""" + cookie_path: str = field(default="/") + """The CSRF cookie path.""" + header_name: str = field(default="x-csrftoken") + """The header that will be expected in each request.""" + cookie_secure: bool = field(default=False) + """A boolean value indicating whether to set the ``Secure`` attribute on the cookie.""" + cookie_httponly: bool = field(default=False) + """A boolean value indicating whether to set the ``HttpOnly`` attribute on the cookie.""" + cookie_samesite: Literal["lax", "strict", "none"] = field(default="lax") + """The value to set in the ``SameSite`` attribute of the cookie.""" + cookie_domain: str | None = field(default=None) + """Specifies which hosts can receive the cookie.""" + safe_methods: set[Method] = field(default_factory=lambda: {"GET", "HEAD"}) + """A set of "safe methods" that can set the cookie.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the CSRF middleware.""" + exclude_from_csrf_key: str = "exclude_from_csrf" + """An identifier to use on routes to disable CSRF for a particular route.""" diff --git a/venv/lib/python3.11/site-packages/litestar/config/response_cache.py b/venv/lib/python3.11/site-packages/litestar/config/response_cache.py new file mode 100644 index 0000000..4f1dfe9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/config/response_cache.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, final +from urllib.parse import urlencode + +from litestar.status_codes import ( + HTTP_200_OK, + HTTP_300_MULTIPLE_CHOICES, + HTTP_301_MOVED_PERMANENTLY, + HTTP_308_PERMANENT_REDIRECT, +) + +if TYPE_CHECKING: + from litestar import Litestar + from litestar.connection import Request + from litestar.stores.base import Store + from litestar.types import CacheKeyBuilder, HTTPScope + +__all__ = ("ResponseCacheConfig", "default_cache_key_builder", "CACHE_FOREVER") + + +@final +class CACHE_FOREVER: # noqa: N801 + """Sentinel value indicating that a cached response should be stored without an expiration, explicitly skipping the + default expiration + """ + + +def default_cache_key_builder(request: Request[Any, Any, Any]) -> str: + """Given a request object, returns a cache key by combining + the request method and path with the sorted query params. + + Args: + request: request used to generate cache key. + + Returns: + A combination of url path and query parameters + """ + query_params: list[tuple[str, Any]] = list(request.query_params.dict().items()) + query_params.sort(key=lambda x: x[0]) + return request.method + request.url.path + urlencode(query_params, doseq=True) + + +def default_do_cache_predicate(_: HTTPScope, status_code: int) -> bool: + """Given a status code, returns a boolean indicating whether the response should be cached. + + Args: + _: ASGI scope. + status_code: status code of the response. + + Returns: + A boolean indicating whether the response should be cached. + """ + return HTTP_200_OK <= status_code < HTTP_300_MULTIPLE_CHOICES or status_code in ( + HTTP_301_MOVED_PERMANENTLY, + HTTP_308_PERMANENT_REDIRECT, + ) + + +@dataclass +class ResponseCacheConfig: + """Configuration for response caching. + + To enable response caching, pass an instance of this class to :class:`Litestar <.app.Litestar>` using the + ``response_cache_config`` key. + """ + + default_expiration: int | None = 60 + """Default cache expiration in seconds used when a route handler is configured with ``cache=True``.""" + key_builder: CacheKeyBuilder = field(default=default_cache_key_builder) + """:class:`CacheKeyBuilder <.types.CacheKeyBuilder>`. Defaults to :func:`default_cache_key_builder`.""" + store: str = "response_cache" + """Name of the :class:`Store <.stores.base.Store>` to use.""" + cache_response_filter: Callable[[HTTPScope, int], bool] = field(default=default_do_cache_predicate) + """A callable that receives connection scope and a status code, and returns a boolean indicating whether the + response should be cached.""" + + def get_store_from_app(self, app: Litestar) -> Store: + """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance.""" + return app.stores.get(self.store) diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__init__.py b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py new file mode 100644 index 0000000..6922e79 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py @@ -0,0 +1,37 @@ +"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/requests.py and +https://github.com/encode/starlette/blob/master/starlette/websockets.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from litestar.connection.base import ASGIConnection +from litestar.connection.request import Request +from litestar.connection.websocket import WebSocket + +__all__ = ("ASGIConnection", "Request", "WebSocket") diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..89cb4db --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f1eff74 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f047a18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..49e294c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/base.py b/venv/lib/python3.11/site-packages/litestar/connection/base.py new file mode 100644 index 0000000..d14c662 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/base.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast + +from litestar._parsers import parse_cookie_string, parse_query_string +from litestar.datastructures.headers import Headers +from litestar.datastructures.multi_dicts import MultiDict +from litestar.datastructures.state import State +from litestar.datastructures.url import URL, Address, make_absolute_url +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types.empty import Empty +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from typing import NoReturn + + from litestar.app import Litestar + from litestar.types import DataContainerType, EmptyType + from litestar.types.asgi_types import Message, Receive, Scope, Send + from litestar.types.protocols import Logger + +__all__ = ("ASGIConnection", "empty_receive", "empty_send") + +UserT = TypeVar("UserT") +AuthT = TypeVar("AuthT") +HandlerT = TypeVar("HandlerT") +StateT = TypeVar("StateT", bound=State) + + +async def empty_receive() -> NoReturn: # pragma: no cover + """Raise a ``RuntimeError``. + + Serves as a placeholder ``send`` function. + + Raises: + RuntimeError + """ + raise RuntimeError() + + +async def empty_send(_: Message) -> NoReturn: # pragma: no cover + """Raise a ``RuntimeError``. + + Serves as a placeholder ``send`` function. + + Args: + _: An ASGI message + + Raises: + RuntimeError + """ + raise RuntimeError() + + +class ASGIConnection(Generic[HandlerT, UserT, AuthT, StateT]): + """The base ASGI connection container.""" + + __slots__ = ( + "scope", + "receive", + "send", + "_base_url", + "_url", + "_parsed_query", + "_cookies", + "_server_extensions", + "_connection_state", + ) + + scope: Scope + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``ASGIConnection``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + self.scope = scope + self.receive = receive + self.send = send + self._connection_state = ScopeState.from_scope(scope) + self._base_url: URL | EmptyType = Empty + self._url: URL | EmptyType = Empty + self._parsed_query: tuple[tuple[str, str], ...] | EmptyType = Empty + self._cookies: dict[str, str] | EmptyType = Empty + self._server_extensions = scope.get("extensions") or {} # extensions may be None + + @property + def app(self) -> Litestar: + """Return the ``app`` for this connection. + + Returns: + The :class:`Litestar <litestar.app.Litestar>` application instance + """ + return self.scope["app"] + + @property + def route_handler(self) -> HandlerT: + """Return the ``route_handler`` for this connection. + + Returns: + The target route handler instance. + """ + return cast("HandlerT", self.scope["route_handler"]) + + @property + def state(self) -> StateT: + """Return the ``State`` of this connection. + + Returns: + A State instance constructed from the scope["state"] value. + """ + return cast("StateT", State(self.scope.get("state"))) + + @property + def url(self) -> URL: + """Return the URL of this connection's ``Scope``. + + Returns: + A URL instance constructed from the request's scope. + """ + if self._url is Empty: + if (url := self._connection_state.url) is not Empty: + self._url = url + else: + self._connection_state.url = self._url = URL.from_scope(self.scope) + + return self._url + + @property + def base_url(self) -> URL: + """Return the base URL of this connection's ``Scope``. + + Returns: + A URL instance constructed from the request's scope, representing only the base part + (host + domain + prefix) of the request. + """ + if self._base_url is Empty: + if (base_url := self._connection_state.base_url) is not Empty: + self._base_url = base_url + else: + scope = cast( + "Scope", + { + **self.scope, + "path": "/", + "query_string": b"", + "root_path": self.scope.get("app_root_path") or self.scope.get("root_path", ""), + }, + ) + self._connection_state.base_url = self._base_url = URL.from_scope(scope) + return self._base_url + + @property + def headers(self) -> Headers: + """Return the headers of this connection's ``Scope``. + + Returns: + A Headers instance with the request's scope["headers"] value. + """ + return Headers.from_scope(self.scope) + + @property + def query_params(self) -> MultiDict[Any]: + """Return the query parameters of this connection's ``Scope``. + + Returns: + A normalized dict of query parameters. Multiple values for the same key are returned as a list. + """ + if self._parsed_query is Empty: + if (parsed_query := self._connection_state.parsed_query) is not Empty: + self._parsed_query = parsed_query + else: + self._connection_state.parsed_query = self._parsed_query = parse_query_string( + self.scope.get("query_string", b"") + ) + return MultiDict(self._parsed_query) + + @property + def path_params(self) -> dict[str, Any]: + """Return the ``path_params`` of this connection's ``Scope``. + + Returns: + A string keyed dictionary of path parameter values. + """ + return self.scope["path_params"] + + @property + def cookies(self) -> dict[str, str]: + """Return the ``cookies`` of this connection's ``Scope``. + + Returns: + Returns any cookies stored in the header as a parsed dictionary. + """ + if self._cookies is Empty: + if (cookies := self._connection_state.cookies) is not Empty: + self._cookies = cookies + else: + self._connection_state.cookies = self._cookies = ( + parse_cookie_string(cookie_header) if (cookie_header := self.headers.get("cookie")) else {} + ) + return self._cookies + + @property + def client(self) -> Address | None: + """Return the ``client`` data of this connection's ``Scope``. + + Returns: + A two tuple of the host name and port number. + """ + client = self.scope.get("client") + return Address(*client) if client else None + + @property + def auth(self) -> AuthT: + """Return the ``auth`` data of this connection's ``Scope``. + + Raises: + ImproperlyConfiguredException: If ``auth`` is not set in scope via an ``AuthMiddleware``, raises an exception + + Returns: + A type correlating to the generic variable Auth. + """ + if "auth" not in self.scope: + raise ImproperlyConfiguredException("'auth' is not defined in scope, install an AuthMiddleware to set it") + + return cast("AuthT", self.scope["auth"]) + + @property + def user(self) -> UserT: + """Return the ``user`` data of this connection's ``Scope``. + + Raises: + ImproperlyConfiguredException: If ``user`` is not set in scope via an ``AuthMiddleware``, raises an exception + + Returns: + A type correlating to the generic variable User. + """ + if "user" not in self.scope: + raise ImproperlyConfiguredException("'user' is not defined in scope, install an AuthMiddleware to set it") + + return cast("UserT", self.scope["user"]) + + @property + def session(self) -> dict[str, Any]: + """Return the session for this connection if a session was previously set in the ``Scope`` + + Returns: + A dictionary representing the session value - if existing. + + Raises: + ImproperlyConfiguredException: if session is not set in scope. + """ + if "session" not in self.scope: + raise ImproperlyConfiguredException( + "'session' is not defined in scope, install a SessionMiddleware to set it" + ) + + return cast("dict[str, Any]", self.scope["session"]) + + @property + def logger(self) -> Logger: + """Return the ``Logger`` instance for this connection. + + Returns: + A ``Logger`` instance. + + Raises: + ImproperlyConfiguredException: if ``log_config`` has not been passed to the Litestar constructor. + """ + return self.app.get_logger() + + def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) -> None: + """Set the session in the connection's ``Scope``. + + If the :class:`SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, the session will be added + to the response as a cookie header. + + Args: + value: Dictionary or pydantic model instance for the session data. + + Returns: + None + """ + self.scope["session"] = value + + def clear_session(self) -> None: + """Remove the session from the connection's ``Scope``. + + If the :class:`Litestar SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, this will cause + the session data to be cleared. + + Returns: + None. + """ + self.scope["session"] = Empty + self._connection_state.session_id = Empty + + def get_session_id(self) -> str | None: + return value_or_default(value=self._connection_state.session_id, default=None) + + def url_for(self, name: str, **path_parameters: Any) -> str: + """Return the url for a given route handler name. + + Args: + name: The ``name`` of the request route handler. + **path_parameters: Values for path parameters in the route + + Raises: + NoRouteMatchFoundException: If route with ``name`` does not exist, path parameters are missing or have a + wrong type. + + Returns: + A string representing the absolute url of the route handler. + """ + litestar_instance = self.scope["app"] + url_path = litestar_instance.route_reverse(name, **path_parameters) + + return make_absolute_url(url_path, self.base_url) + + def url_for_static_asset(self, name: str, file_path: str) -> str: + """Receives a static files handler name, an asset file path and returns resolved absolute url to the asset. + + Args: + name: A static handler unique name. + file_path: a string containing path to an asset. + + Raises: + NoRouteMatchFoundException: If static files handler with ``name`` does not exist. + + Returns: + A string representing absolute url to the asset. + """ + litestar_instance = self.scope["app"] + url_path = litestar_instance.url_for_static_asset(name, file_path) + + return make_absolute_url(url_path, self.base_url) diff --git a/venv/lib/python3.11/site-packages/litestar/connection/request.py b/venv/lib/python3.11/site-packages/litestar/connection/request.py new file mode 100644 index 0000000..254c315 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/request.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic + +from litestar._multipart import parse_content_header, parse_multipart_form +from litestar._parsers import parse_url_encoded_form_data +from litestar.connection.base import ( + ASGIConnection, + AuthT, + StateT, + UserT, + empty_receive, + empty_send, +) +from litestar.datastructures.headers import Accept +from litestar.datastructures.multi_dicts import FormMultiDict +from litestar.enums import ASGIExtension, RequestEncodingType +from litestar.exceptions import ( + InternalServerException, + LitestarException, + LitestarWarning, +) +from litestar.serialization import decode_json, decode_msgpack +from litestar.types import Empty + +__all__ = ("Request",) + + +if TYPE_CHECKING: + from litestar.handlers.http_handlers import HTTPRouteHandler # noqa: F401 + from litestar.types.asgi_types import HTTPScope, Method, Receive, Scope, Send + from litestar.types.empty import EmptyType + + +SERVER_PUSH_HEADERS = { + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "user-agent", +} + + +class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]): + """The Litestar Request class.""" + + __slots__ = ( + "_json", + "_form", + "_body", + "_msgpack", + "_content_type", + "_accept", + "is_connected", + "supports_push_promise", + ) + + scope: HTTPScope # pyright: ignore + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``Request``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + super().__init__(scope, receive, send) + self.is_connected: bool = True + self._body: bytes | EmptyType = Empty + self._form: dict[str, str | list[str]] | EmptyType = Empty + self._json: Any = Empty + self._msgpack: Any = Empty + self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty + self._accept: Accept | EmptyType = Empty + self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions + + @property + def method(self) -> Method: + """Return the request method. + + Returns: + The request :class:`Method <litestar.types.Method>` + """ + return self.scope["method"] + + @property + def content_type(self) -> tuple[str, dict[str, str]]: + """Parse the request's 'Content-Type' header, returning the header value and any options as a dictionary. + + Returns: + A tuple with the parsed value and a dictionary containing any options send in it. + """ + if self._content_type is Empty: + if (content_type := self._connection_state.content_type) is not Empty: + self._content_type = content_type + else: + self._content_type = self._connection_state.content_type = parse_content_header( + self.headers.get("Content-Type", "") + ) + return self._content_type + + @property + def accept(self) -> Accept: + """Parse the request's 'Accept' header, returning an :class:`Accept <litestar.datastructures.headers.Accept>` instance. + + Returns: + An :class:`Accept <litestar.datastructures.headers.Accept>` instance, representing the list of acceptable media types. + """ + if self._accept is Empty: + if (accept := self._connection_state.accept) is not Empty: + self._accept = accept + else: + self._accept = self._connection_state.accept = Accept(self.headers.get("Accept", "*/*")) + return self._accept + + async def json(self) -> Any: + """Retrieve the json request body from the request. + + Returns: + An arbitrary value + """ + if self._json is Empty: + if (json_ := self._connection_state.json) is not Empty: + self._json = json_ + else: + body = await self.body() + self._json = self._connection_state.json = decode_json( + body or b"null", type_decoders=self.route_handler.resolve_type_decoders() + ) + return self._json + + async def msgpack(self) -> Any: + """Retrieve the MessagePack request body from the request. + + Returns: + An arbitrary value + """ + if self._msgpack is Empty: + if (msgpack := self._connection_state.msgpack) is not Empty: + self._msgpack = msgpack + else: + body = await self.body() + self._msgpack = self._connection_state.msgpack = decode_msgpack( + body or b"\xc0", type_decoders=self.route_handler.resolve_type_decoders() + ) + return self._msgpack + + async def stream(self) -> AsyncGenerator[bytes, None]: + """Return an async generator that streams chunks of bytes. + + Returns: + An async generator. + + Raises: + RuntimeError: if the stream is already consumed + """ + if self._body is Empty: + if not self.is_connected: + raise InternalServerException("stream consumed") + while event := await self.receive(): + if event["type"] == "http.request": + if event["body"]: + yield event["body"] + + if not event.get("more_body", False): + break + + if event["type"] == "http.disconnect": + raise InternalServerException("client disconnected prematurely") + + self.is_connected = False + yield b"" + + else: + yield self._body + yield b"" + return + + async def body(self) -> bytes: + """Return the body of the request. + + Returns: + A byte-string representing the body of the request. + """ + if self._body is Empty: + if (body := self._connection_state.body) is not Empty: + self._body = body + else: + self._body = self._connection_state.body = b"".join([c async for c in self.stream()]) + return self._body + + async def form(self) -> FormMultiDict: + """Retrieve form data from the request. If the request is either a 'multipart/form-data' or an + 'application/x-www-form- urlencoded', return a FormMultiDict instance populated with the values sent in the + request, otherwise, an empty instance. + + Returns: + A FormMultiDict instance + """ + if self._form is Empty: + if (form := self._connection_state.form) is not Empty: + self._form = form + else: + content_type, options = self.content_type + if content_type == RequestEncodingType.MULTI_PART: + self._form = parse_multipart_form( + body=await self.body(), + boundary=options.get("boundary", "").encode(), + multipart_form_part_limit=self.app.multipart_form_part_limit, + ) + elif content_type == RequestEncodingType.URL_ENCODED: + self._form = parse_url_encoded_form_data( + await self.body(), + ) + else: + self._form = {} + + self._connection_state.form = self._form + + return FormMultiDict(self._form) + + async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None: + """Send a push promise. + + This method requires the `http.response.push` extension to be sent from the ASGI server. + + Args: + path: Path to send the promise to. + raise_if_unavailable: Raise an exception if server push is not supported by + the server + + Returns: + None + """ + if not self.supports_push_promise: + if raise_if_unavailable: + raise LitestarException("Attempted to send a push promise but the server does not support it") + + warnings.warn( + "Attempted to send a push promise but the server does not support it. In a future version, this will " + "raise an exception. To enable this behaviour in the current version, set raise_if_unavailable=True. " + "To prevent this behaviour, make sure that the server you are using supports the 'http.response.push' " + "ASGI extension, or check this dynamically via " + ":attr:`~litestar.connection.Request.supports_push_promise`", + stacklevel=2, + category=LitestarWarning, + ) + + return + + raw_headers = [ + (header_name.encode("latin-1"), value.encode("latin-1")) + for header_name in (self.headers.keys() & SERVER_PUSH_HEADERS) + for value in self.headers.getall(header_name, []) + ] + await self.send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py new file mode 100644 index 0000000..0c7bc04 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload + +from litestar.connection.base import ( + ASGIConnection, + AuthT, + StateT, + UserT, + empty_receive, + empty_send, +) +from litestar.datastructures.headers import Headers +from litestar.exceptions import WebSocketDisconnect +from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack +from litestar.status_codes import WS_1000_NORMAL_CLOSURE + +__all__ = ("WebSocket",) + + +if TYPE_CHECKING: + from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401 + from litestar.types import Message, Serializer, WebSocketScope + from litestar.types.asgi_types import ( + Receive, + ReceiveMessage, + Scope, + Send, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketMode, + WebSocketReceiveEvent, + WebSocketSendEvent, + ) + +DISCONNECT_MESSAGE = "connection is disconnected" + + +class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]): + """The Litestar WebSocket class.""" + + __slots__ = ("connection_state",) + + scope: WebSocketScope # pyright: ignore + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``WebSocket``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send)) + self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init" + + def receive_wrapper(self, receive: Receive) -> Receive: + """Wrap ``receive`` to set 'self.connection_state' and validate events. + + Args: + receive: The ASGI receive function. + + Returns: + An ASGI receive function. + """ + + async def wrapped_receive() -> ReceiveMessage: + if self.connection_state == "disconnect": + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) + message = await receive() + if message["type"] == "websocket.connect": + self.connection_state = "connect" + elif message["type"] == "websocket.receive": + self.connection_state = "receive" + else: + self.connection_state = "disconnect" + return message + + return wrapped_receive + + def send_wrapper(self, send: Send) -> Send: + """Wrap ``send`` to ensure that state is not disconnected. + + Args: + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + + async def wrapped_send(message: Message) -> None: + if self.connection_state == "disconnect": + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover + await send(message) + + return wrapped_send + + async def accept( + self, + subprotocols: str | None = None, + headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None, + ) -> None: + """Accept the incoming connection. This method should be called before receiving data. + + Args: + subprotocols: Websocket sub-protocol to use. + headers: Headers to set on the data sent. + + Returns: + None + """ + if self.connection_state == "init": + await self.receive() + _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else [] + + if isinstance(headers, dict): + _headers = Headers(headers=headers).to_header_list() + + if isinstance(headers, Headers): + _headers = headers.to_header_list() + + event: WebSocketAcceptEvent = { + "type": "websocket.accept", + "subprotocol": subprotocols, + "headers": _headers, + } + await self.send(event) + + async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None: + """Send an 'websocket.close' event. + + Args: + code: Status code. + reason: Reason for closing the connection + + Returns: + None + """ + event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""} + await self.send(event) + + @overload + async def receive_data(self, mode: Literal["text"]) -> str: ... + + @overload + async def receive_data(self, mode: Literal["binary"]) -> bytes: ... + + async def receive_data(self, mode: WebSocketMode) -> str | bytes: + """Receive an 'websocket.receive' event and returns the data stored on it. + + Args: + mode: The respective event key to use. + + Returns: + The event's data. + """ + if self.connection_state == "init": + await self.accept() + event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive()) + if event["type"] == "websocket.disconnect": + raise WebSocketDisconnect(detail="disconnect event", code=event["code"]) + return event.get("text") or "" if mode == "text" else event.get("bytes") or b"" + + @overload + def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ... + + @overload + def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ... + + async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]: + """Continuously receive data and yield it + + Args: + mode: Socket mode to use. Either ``text`` or ``binary`` + """ + try: + while True: + yield await self.receive_data(mode) + except WebSocketDisconnect: + pass + + async def receive_text(self) -> str: + """Receive data as text. + + Returns: + A string. + """ + return await self.receive_data(mode="text") + + async def receive_bytes(self) -> bytes: + """Receive data as bytes. + + Returns: + A byte-string. + """ + return await self.receive_data(mode="binary") + + async def receive_json(self, mode: WebSocketMode = "text") -> Any: + """Receive data and decode it as JSON. + + Args: + mode: Either ``text`` or ``binary``. + + Returns: + An arbitrary value + """ + data = await self.receive_data(mode=mode) + return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def receive_msgpack(self) -> Any: + """Receive data and decode it as MessagePack. + + Note that since MessagePack is a binary format, this method will always receive + data in ``binary`` mode. + + Returns: + An arbitrary value + """ + data = await self.receive_data(mode="binary") + return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]: + """Continuously receive data and yield it, decoding it as JSON in the process. + + Args: + mode: Socket mode to use. Either ``text`` or ``binary`` + """ + async for data in self.iter_data(mode): + yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def iter_msgpack(self) -> AsyncGenerator[Any, None]: + """Continuously receive data and yield it, decoding it as MessagePack in the + process. + + Note that since MessagePack is a binary format, this method will always receive + data in ``binary`` mode. + + """ + async for data in self.iter_data(mode="binary"): + yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None: + """Send a 'websocket.send' event. + + Args: + data: Data to send. + mode: The respective event key to use. + encoding: Encoding to use when converting bytes / str. + + Returns: + None + """ + if self.connection_state == "init": # pragma: no cover + await self.accept() + event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None} + if mode == "binary": + event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding) + else: + event["text"] = data if isinstance(data, str) else data.decode(encoding) + await self.send(event) + + @overload + async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ... + + @overload + async def send_text(self, data: str) -> None: ... + + async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None: + """Send data using the ``text`` key of the send event. + + Args: + data: Data to send + encoding: Encoding to use for binary data. + + Returns: + None + """ + await self.send_data(data=data, encoding=encoding) + + @overload + async def send_bytes(self, data: bytes) -> None: ... + + @overload + async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ... + + async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None: + """Send data using the ``bytes`` key of the send event. + + Args: + data: Data to send + encoding: Encoding to use for binary data. + + Returns: + None + """ + await self.send_data(data=data, mode="binary", encoding=encoding) + + async def send_json( + self, + data: Any, + mode: WebSocketMode = "text", + encoding: str = "utf-8", + serializer: Serializer = default_serializer, + ) -> None: + """Send data as JSON. + + Args: + data: A value to serialize. + mode: Either ``text`` or ``binary``. + encoding: Encoding to use for binary data. + serializer: A serializer function. + + Returns: + None + """ + await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding) + + async def send_msgpack( + self, + data: Any, + encoding: str = "utf-8", + serializer: Serializer = default_serializer, + ) -> None: + """Send data as MessagePack. + + Note that since MessagePack is a binary format, this method will always send + data in ``binary`` mode. + + Args: + data: A value to serialize. + encoding: Encoding to use for binary data. + serializer: A serializer function. + + Returns: + None + """ + await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding) diff --git a/venv/lib/python3.11/site-packages/litestar/constants.py b/venv/lib/python3.11/site-packages/litestar/constants.py new file mode 100644 index 0000000..930296c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/constants.py @@ -0,0 +1,57 @@ +from dataclasses import MISSING +from inspect import Signature +from typing import Any, Final + +from msgspec import UnsetType + +from litestar.enums import MediaType +from litestar.types import Empty +from litestar.utils.deprecation import warn_deprecation + +DEFAULT_ALLOWED_CORS_HEADERS: Final = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} +DEFAULT_CHUNK_SIZE: Final = 1024 * 128 # 128KB +HTTP_DISCONNECT: Final = "http.disconnect" +HTTP_RESPONSE_BODY: Final = "http.response.body" +HTTP_RESPONSE_START: Final = "http.response.start" +ONE_MEGABYTE: Final = 1024 * 1024 +OPENAPI_NOT_INITIALIZED: Final = "Litestar has not been instantiated with OpenAPIConfig" +REDIRECT_STATUS_CODES: Final = {301, 302, 303, 307, 308} +REDIRECT_ALLOWED_MEDIA_TYPES: Final = {MediaType.TEXT, MediaType.HTML, MediaType.JSON} +RESERVED_KWARGS: Final = {"state", "headers", "cookies", "request", "socket", "data", "query", "scope", "body"} +SKIP_VALIDATION_NAMES: Final = {"request", "socket", "scope", "receive", "send"} +UNDEFINED_SENTINELS: Final = {Signature.empty, Empty, Ellipsis, MISSING, UnsetType} +WEBSOCKET_CLOSE: Final = "websocket.close" +WEBSOCKET_DISCONNECT: Final = "websocket.disconnect" + + +# deprecated constants +_SCOPE_STATE_CSRF_TOKEN_KEY = "csrf_token" # noqa: S105 # possible hardcoded password +_SCOPE_STATE_DEPENDENCY_CACHE: Final = "dependency_cache" +_SCOPE_STATE_NAMESPACE: Final = "__litestar__" +_SCOPE_STATE_RESPONSE_COMPRESSED: Final = "response_compressed" +_SCOPE_STATE_DO_CACHE: Final = "do_cache" +_SCOPE_STATE_IS_CACHED: Final = "is_cached" + +_deprecated_names = { + "SCOPE_STATE_CSRF_TOKEN_KEY": _SCOPE_STATE_CSRF_TOKEN_KEY, + "SCOPE_STATE_DEPENDENCY_CACHE": _SCOPE_STATE_DEPENDENCY_CACHE, + "SCOPE_STATE_NAMESPACE": _SCOPE_STATE_NAMESPACE, + "SCOPE_STATE_RESPONSE_COMPRESSED": _SCOPE_STATE_RESPONSE_COMPRESSED, + "SCOPE_STATE_DO_CACHE": _SCOPE_STATE_DO_CACHE, + "SCOPE_STATE_IS_CACHED": _SCOPE_STATE_IS_CACHED, +} + + +def __getattr__(name: str) -> Any: + if name in _deprecated_names: + warn_deprecation( + deprecated_name=f"litestar.constants.{name}", + version="2.4", + kind="import", + removal_in="3.0", + info=f"'{name}' from 'litestar.constants' is deprecated and will be removed in 3.0. " + "Direct access to Litestar scope state is not recommended.", + ) + + return globals()["_deprecated_names"][name] + raise AttributeError(f"module {__name__} has no attribute {name}") # pragma: no cover diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..fc2f5bc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/jinja.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/jinja.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f58d015 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/jinja.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/mako.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/mako.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..09bede9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/mako.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/minijinja.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/minijinja.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..05ada30 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/minijinja.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/minijnja.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/minijnja.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3006e82 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/minijnja.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/piccolo.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/piccolo.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..20ea290 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/__pycache__/piccolo.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__init__.py new file mode 100644 index 0000000..ddd2a3f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__init__.py @@ -0,0 +1,3 @@ +from .attrs_schema_plugin import AttrsSchemaPlugin + +__all__ = ("AttrsSchemaPlugin",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a224be6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__pycache__/attrs_schema_plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__pycache__/attrs_schema_plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..730252a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/__pycache__/attrs_schema_plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/attrs/attrs_schema_plugin.py b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/attrs_schema_plugin.py new file mode 100644 index 0000000..cf67fe4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/attrs/attrs_schema_plugin.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.exceptions import MissingDependencyException +from litestar.plugins import OpenAPISchemaPluginProtocol +from litestar.typing import FieldDefinition +from litestar.utils import is_attrs_class, is_optional_union + +try: + import attr + import attrs +except ImportError as e: + raise MissingDependencyException("attrs") from e + +if TYPE_CHECKING: + from litestar._openapi.schema_generation import SchemaCreator + from litestar.openapi.spec import Schema + + +class AttrsSchemaPlugin(OpenAPISchemaPluginProtocol): + @staticmethod + def is_plugin_supported_type(value: Any) -> bool: + return is_attrs_class(value) or is_attrs_class(type(value)) + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + """Given a type annotation, transform it into an OpenAPI schema class. + + Args: + field_definition: FieldDefinition instance. + schema_creator: An instance of the schema creator class + + Returns: + An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance. + """ + + type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True) + attr_fields = attr.fields_dict(field_definition.type_) + return schema_creator.create_component_schema( + field_definition, + required=sorted( + field_name + for field_name, attribute in attr_fields.items() + if attribute.default is attrs.NOTHING and not is_optional_union(type_hints[field_name]) + ), + property_fields={ + field_name: FieldDefinition.from_kwarg(type_hints[field_name], field_name) for field_name in attr_fields + }, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..31d4982 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d860774 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/request.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/request.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..65b99d9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/request.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/response.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/response.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0bb64b8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/response.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0af7128 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/_utils.py b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/_utils.py new file mode 100644 index 0000000..894fd25 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/_utils.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, cast +from urllib.parse import quote + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.serialization import encode_json + +__all__ = ( + "HTMXHeaders", + "get_headers", + "get_location_headers", + "get_push_url_header", + "get_redirect_header", + "get_refresh_header", + "get_replace_url_header", + "get_reswap_header", + "get_retarget_header", + "get_trigger_event_headers", +) + + +if TYPE_CHECKING: + from litestar.contrib.htmx.types import ( + EventAfterType, + HtmxHeaderType, + LocationType, + PushUrlType, + ReSwapMethod, + TriggerEventType, + ) + +HTMX_STOP_POLLING = 286 + + +class HTMXHeaders(str, Enum): + """Enum for HTMX Headers""" + + REDIRECT = "HX-Redirect" + REFRESH = "HX-Refresh" + PUSH_URL = "HX-Push-Url" + REPLACE_URL = "HX-Replace-Url" + RE_SWAP = "HX-Reswap" + RE_TARGET = "HX-Retarget" + LOCATION = "HX-Location" + + TRIGGER_EVENT = "HX-Trigger" + TRIGGER_AFTER_SETTLE = "HX-Trigger-After-Settle" + TRIGGER_AFTER_SWAP = "HX-Trigger-After-Swap" + + REQUEST = "HX-Request" + BOOSTED = "HX-Boosted" + CURRENT_URL = "HX-Current-URL" + HISTORY_RESTORE_REQUEST = "HX-History-Restore-Request" + PROMPT = "HX-Prompt" + TARGET = "HX-Target" + TRIGGER_ID = "HX-Trigger" # noqa: PIE796 + TRIGGER_NAME = "HX-Trigger-Name" + TRIGGERING_EVENT = "Triggering-Event" + + +def get_trigger_event_headers(trigger_event: TriggerEventType) -> dict[str, Any]: + """Return headers for trigger event response.""" + after_params: dict[EventAfterType, str] = { + "receive": HTMXHeaders.TRIGGER_EVENT.value, + "settle": HTMXHeaders.TRIGGER_AFTER_SETTLE.value, + "swap": HTMXHeaders.TRIGGER_AFTER_SWAP.value, + } + + if trigger_header := after_params.get(trigger_event["after"]): + return {trigger_header: encode_json({trigger_event["name"]: trigger_event["params"] or {}}).decode()} + + raise ImproperlyConfiguredException( + "invalid value for 'after' param- allowed values are 'receive', 'settle' or 'swap'." + ) + + +def get_redirect_header(url: str) -> dict[str, Any]: + """Return headers for redirect response.""" + return {HTMXHeaders.REDIRECT.value: quote(url, safe="/#%[]=:;$&()+,!?*@'~"), "Location": ""} + + +def get_push_url_header(url: PushUrlType) -> dict[str, Any]: + """Return headers for push url to browser history response.""" + if isinstance(url, str): + url = url if url != "False" else "false" + elif isinstance(url, bool): + url = "false" + + return {HTMXHeaders.PUSH_URL.value: url} + + +def get_replace_url_header(url: PushUrlType) -> dict[str, Any]: + """Return headers for replace url in browser tab response.""" + url = (url if url != "False" else "false") if isinstance(url, str) else "false" + return {HTMXHeaders.REPLACE_URL: url} + + +def get_refresh_header(refresh: bool) -> dict[str, Any]: + """Return headers for client refresh response.""" + return {HTMXHeaders.REFRESH.value: "true" if refresh else ""} + + +def get_reswap_header(method: ReSwapMethod) -> dict[str, Any]: + """Return headers for change swap method response.""" + return {HTMXHeaders.RE_SWAP.value: method} + + +def get_retarget_header(target: str) -> dict[str, Any]: + """Return headers for change target element response.""" + return {HTMXHeaders.RE_TARGET.value: target} + + +def get_location_headers(location: LocationType) -> dict[str, Any]: + """Return headers for redirect without page-reload response.""" + if spec := {key: value for key, value in location.items() if value}: + return {HTMXHeaders.LOCATION.value: encode_json(spec).decode()} + raise ValueError("redirect_to is required parameter.") + + +def get_headers(hx_headers: HtmxHeaderType) -> dict[str, Any]: + """Return headers for HTMX responses.""" + if not hx_headers: + raise ValueError("Value for hx_headers cannot be None.") + htmx_headers_dict: dict[str, Callable] = { + "redirect": get_redirect_header, + "refresh": get_refresh_header, + "push_url": get_push_url_header, + "replace_url": get_replace_url_header, + "re_swap": get_reswap_header, + "re_target": get_retarget_header, + "trigger_event": get_trigger_event_headers, + "location": get_location_headers, + } + + header: dict[str, Any] = {} + response: dict[str, Any] + key: str + value: Any + + for key, value in hx_headers.items(): + if key in ["redirect", "refresh", "location", "replace_url"]: + return cast("dict[str, Any]", htmx_headers_dict[key](value)) + if value is not None: + response = htmx_headers_dict[key](value) + header.update(response) + return header diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/request.py b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/request.py new file mode 100644 index 0000000..b4fad18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/request.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from contextlib import suppress +from functools import cached_property +from typing import TYPE_CHECKING, Any +from urllib.parse import unquote, urlsplit, urlunsplit + +from litestar import Request +from litestar.connection.base import empty_receive, empty_send +from litestar.contrib.htmx._utils import HTMXHeaders +from litestar.exceptions import SerializationException +from litestar.serialization import decode_json + +__all__ = ("HTMXDetails", "HTMXRequest") + + +if TYPE_CHECKING: + from litestar.types import Receive, Scope, Send + + +class HTMXDetails: + """HTMXDetails holds all the values sent by HTMX client in headers and provide convenient properties.""" + + def __init__(self, request: Request) -> None: + """Initialize :class:`HTMXDetails`""" + self.request = request + + def _get_header_value(self, name: HTMXHeaders) -> str | None: + """Parse request header + + Check for uri encoded header and unquotes it in readable format. + """ + + if value := self.request.headers.get(name.value.lower()): + is_uri_encoded = self.request.headers.get(f"{name.value.lower()}-uri-autoencoded") == "true" + return unquote(value) if is_uri_encoded else value + return None + + def __bool__(self) -> bool: + """Check if request is sent by an HTMX client.""" + return self._get_header_value(HTMXHeaders.REQUEST) == "true" + + @cached_property + def boosted(self) -> bool: + """Check if request is boosted.""" + return self._get_header_value(HTMXHeaders.BOOSTED) == "true" + + @cached_property + def current_url(self) -> str | None: + """Current url value sent by HTMX client.""" + return self._get_header_value(HTMXHeaders.CURRENT_URL) + + @cached_property + def current_url_abs_path(self) -> str | None: + """Current url abs path value, to get query and path parameter sent by HTMX client.""" + if self.current_url: + split = urlsplit(self.current_url) + if split.scheme == self.request.scope["scheme"] and split.netloc == self.request.headers.get("host"): + return str(urlunsplit(split._replace(scheme="", netloc=""))) + return None + return self.current_url + + @cached_property + def history_restore_request(self) -> bool: + """If True then, request is for history restoration after a miss in the local history cache.""" + return self._get_header_value(HTMXHeaders.HISTORY_RESTORE_REQUEST) == "true" + + @cached_property + def prompt(self) -> str | None: + """User Response to prompt. + + .. code-block:: html + + <button hx-delete="/account" hx-prompt="Enter your account name to confirm deletion">Delete My Account</button> + """ + return self._get_header_value(HTMXHeaders.PROMPT) + + @cached_property + def target(self) -> str | None: + """ID of the target element if provided on the element.""" + return self._get_header_value(HTMXHeaders.TARGET) + + @cached_property + def trigger(self) -> str | None: + """ID of the triggered element if provided on the element.""" + return self._get_header_value(HTMXHeaders.TRIGGER_ID) + + @cached_property + def trigger_name(self) -> str | None: + """Name of the triggered element if provided on the element.""" + return self._get_header_value(HTMXHeaders.TRIGGER_NAME) + + @cached_property + def triggering_event(self) -> Any: + """Name of the triggered event. + + This value is added by ``event-header`` extension of HTMX to the ``Triggering-Event`` header to requests. + """ + if value := self._get_header_value(HTMXHeaders.TRIGGERING_EVENT): + with suppress(SerializationException): + return decode_json(value=value, type_decoders=self.request.route_handler.resolve_type_decoders()) + return None + + +class HTMXRequest(Request): + """HTMX Request class to work with HTMX client.""" + + __slots__ = ("htmx",) + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize :class:`HTMXRequest`""" + super().__init__(scope=scope, receive=receive, send=send) + self.htmx = HTMXDetails(self) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/response.py b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/response.py new file mode 100644 index 0000000..0a56e1f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/response.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from typing import Any, Generic, TypeVar +from urllib.parse import quote + +from litestar import Response +from litestar.contrib.htmx._utils import HTMX_STOP_POLLING, get_headers +from litestar.contrib.htmx.types import ( + EventAfterType, + HtmxHeaderType, + LocationType, + PushUrlType, + ReSwapMethod, + TriggerEventType, +) +from litestar.response import Template +from litestar.status_codes import HTTP_200_OK + +__all__ = ( + "ClientRedirect", + "ClientRefresh", + "HTMXTemplate", + "HXLocation", + "HXStopPolling", + "PushUrl", + "ReplaceUrl", + "Reswap", + "Retarget", + "TriggerEvent", +) + + +# HTMX defined HTTP status code. +# Response carrying this status code will ask client to stop Polling. +T = TypeVar("T") + + +class HXStopPolling(Response): + """Stop HTMX client from Polling.""" + + def __init__(self) -> None: + """Initialize""" + super().__init__(content=None) + self.status_code = HTMX_STOP_POLLING + + +class ClientRedirect(Response): + """HTMX Response class to support client side redirect.""" + + def __init__(self, redirect_to: str) -> None: + """Set status code to 200 (required by HTMX), and pass redirect url.""" + super().__init__(content=None, headers=get_headers(hx_headers=HtmxHeaderType(redirect=redirect_to))) + del self.headers["Location"] + + +class ClientRefresh(Response): + """Response to support HTMX client page refresh""" + + def __init__(self) -> None: + """Set Status code to 200 and set headers.""" + super().__init__(content=None, headers=get_headers(hx_headers=HtmxHeaderType(refresh=True))) + + +class PushUrl(Generic[T], Response[T]): + """Response to push new url into the history stack.""" + + def __init__(self, content: T, push_url: PushUrlType, **kwargs: Any) -> None: + """Initialize PushUrl.""" + super().__init__( + content=content, + status_code=HTTP_200_OK, + headers=get_headers(hx_headers=HtmxHeaderType(push_url=push_url)), + **kwargs, + ) + + +class ReplaceUrl(Generic[T], Response[T]): + """Response to replace url in the Browser Location bar.""" + + def __init__(self, content: T, replace_url: PushUrlType, **kwargs: Any) -> None: + """Initialize ReplaceUrl.""" + super().__init__( + content=content, + status_code=HTTP_200_OK, + headers=get_headers(hx_headers=HtmxHeaderType(replace_url=replace_url)), + **kwargs, + ) + + +class Reswap(Generic[T], Response[T]): + """Response to specify how the response will be swapped.""" + + def __init__( + self, + content: T, + method: ReSwapMethod, + **kwargs: Any, + ) -> None: + """Initialize Reswap.""" + super().__init__(content=content, headers=get_headers(hx_headers=HtmxHeaderType(re_swap=method)), **kwargs) + + +class Retarget(Generic[T], Response[T]): + """Response to target different element on the page.""" + + def __init__(self, content: T, target: str, **kwargs: Any) -> None: + """Initialize Retarget.""" + super().__init__(content=content, headers=get_headers(hx_headers=HtmxHeaderType(re_target=target)), **kwargs) + + +class TriggerEvent(Generic[T], Response[T]): + """Trigger Client side event.""" + + def __init__( + self, + content: T, + name: str, + after: EventAfterType, + params: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize TriggerEvent.""" + event = TriggerEventType(name=name, params=params, after=after) + headers = get_headers(hx_headers=HtmxHeaderType(trigger_event=event)) + super().__init__(content=content, headers=headers, **kwargs) + + +class HXLocation(Response): + """Client side redirect without full page reload.""" + + def __init__( + self, + redirect_to: str, + source: str | None = None, + event: str | None = None, + target: str | None = None, + swap: ReSwapMethod | None = None, + hx_headers: dict[str, Any] | None = None, + values: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize HXLocation, Set status code to 200 (required by HTMX), + and pass redirect url. + """ + super().__init__( + content=None, + headers={"Location": quote(redirect_to, safe="/#%[]=:;$&()+,!?*@'~")}, + **kwargs, + ) + spec: dict[str, Any] = get_headers( + hx_headers=HtmxHeaderType( + location=LocationType( + path=str(self.headers.get("Location")), + source=source, + event=event, + target=target, + swap=swap, + values=values, + hx_headers=hx_headers, + ) + ) + ) + del self.headers["Location"] + self.headers.update(spec) + + +class HTMXTemplate(Template): + """HTMX template wrapper""" + + def __init__( + self, + push_url: PushUrlType | None = None, + re_swap: ReSwapMethod | None = None, + re_target: str | None = None, + trigger_event: str | None = None, + params: dict[str, Any] | None = None, + after: EventAfterType | None = None, + **kwargs: Any, + ) -> None: + """Create HTMXTemplate response. + + Args: + push_url: Either a string value specifying a URL to push to browser history or ``False`` to prevent HTMX client from + pushing a url to browser history. + re_swap: Method value to instruct HTMX which swapping method to use. + re_target: Value for 'id of target element' to apply changes to. + trigger_event: Event name to trigger. + params: Dictionary of parameters if any required with trigger event parameter. + after: Changes to apply after ``receive``, ``settle`` or ``swap`` event. + **kwargs: Additional arguments to pass to ``Template``. + """ + super().__init__(**kwargs) + + event: TriggerEventType | None = None + if trigger_event: + event = TriggerEventType(name=str(trigger_event), params=params, after=after) + + self.headers.update( + get_headers(HtmxHeaderType(push_url=push_url, re_swap=re_swap, re_target=re_target, trigger_event=event)) + ) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/htmx/types.py b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/types.py new file mode 100644 index 0000000..aa8f9cd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/htmx/types.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union + +__all__ = ( + "HtmxHeaderType", + "LocationType", + "TriggerEventType", +) + +if TYPE_CHECKING: + from typing_extensions import Required + + +EventAfterType = Literal["receive", "settle", "swap", None] + +PushUrlType = Union[str, bool] + +ReSwapMethod = Literal[ + "innerHTML", "outerHTML", "beforebegin", "afterbegin", "beforeend", "afterend", "delete", "none", None +] + + +class LocationType(TypedDict): + """Type for HX-Location header.""" + + path: Required[str] + source: str | None + event: str | None + target: str | None + swap: ReSwapMethod | None + values: dict[str, str] | None + hx_headers: dict[str, Any] | None + + +class TriggerEventType(TypedDict): + """Type for HX-Trigger header.""" + + name: Required[str] + params: dict[str, Any] | None + after: EventAfterType | None + + +class HtmxHeaderType(TypedDict, total=False): + """Type for hx_headers parameter in get_headers().""" + + location: LocationType | None + redirect: str | None + refresh: bool + push_url: PushUrlType | None + replace_url: PushUrlType | None + re_swap: ReSwapMethod | None + re_target: str | None + trigger_event: TriggerEventType | None diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jinja.py b/venv/lib/python3.11/site-packages/litestar/contrib/jinja.py new file mode 100644 index 0000000..4e8057b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jinja.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, TypeVar + +from typing_extensions import ParamSpec + +from litestar.exceptions import ImproperlyConfiguredException, MissingDependencyException, TemplateNotFoundException +from litestar.template.base import ( + TemplateCallableType, + TemplateEngineProtocol, + csrf_token, + url_for, + url_for_static_asset, +) + +try: + from jinja2 import Environment, FileSystemLoader, pass_context + from jinja2 import TemplateNotFound as JinjaTemplateNotFound +except ImportError as e: + raise MissingDependencyException("jinja2", extra="jinja") from e + +if TYPE_CHECKING: + from pathlib import Path + + from jinja2 import Template as JinjaTemplate + +__all__ = ("JinjaTemplateEngine",) + +P = ParamSpec("P") +T = TypeVar("T") + + +class JinjaTemplateEngine(TemplateEngineProtocol["JinjaTemplate", Mapping[str, Any]]): + """The engine instance.""" + + def __init__( + self, + directory: Path | list[Path] | None = None, + engine_instance: Environment | None = None, + ) -> None: + """Jinja-based TemplateEngine. + + Args: + directory: Direct path or list of directory paths from which to serve templates. + engine_instance: A jinja Environment instance. + """ + + super().__init__(directory, engine_instance) + if directory and engine_instance: + raise ImproperlyConfiguredException("You must provide either a directory or a jinja2 Environment instance.") + if directory: + loader = FileSystemLoader(searchpath=directory) + self.engine = Environment(loader=loader, autoescape=True) + elif engine_instance: + self.engine = engine_instance + self.register_template_callable(key="url_for_static_asset", template_callable=url_for_static_asset) + self.register_template_callable(key="csrf_token", template_callable=csrf_token) + self.register_template_callable(key="url_for", template_callable=url_for) + + def get_template(self, template_name: str) -> JinjaTemplate: + """Retrieve a template by matching its name (dotted path) with files in the directory or directories provided. + + Args: + template_name: A dotted path + + Returns: + JinjaTemplate instance + + Raises: + TemplateNotFoundException: if no template is found. + """ + try: + return self.engine.get_template(name=template_name) + except JinjaTemplateNotFound as exc: + raise TemplateNotFoundException(template_name=template_name) from exc + + def register_template_callable( + self, key: str, template_callable: TemplateCallableType[Mapping[str, Any], P, T] + ) -> None: + """Register a callable on the template engine. + + Args: + key: The callable key, i.e. the value to use inside the template to call the callable. + template_callable: A callable to register. + + Returns: + None + """ + self.engine.globals[key] = pass_context(template_callable) + + def render_string(self, template_string: str, context: Mapping[str, Any]) -> str: + """Render a template from a string with the given context. + + Args: + template_string: The template string to render. + context: A dictionary of variables to pass to the template. + + Returns: + The rendered template as a string. + """ + template = self.engine.from_string(template_string) + return template.render(context) + + @classmethod + def from_environment(cls, jinja_environment: Environment) -> JinjaTemplateEngine: + """Create a JinjaTemplateEngine from an existing jinja Environment instance. + + Args: + jinja_environment (jinja2.environment.Environment): A jinja Environment instance. + + Returns: + JinjaTemplateEngine instance + """ + return cls(directory=None, engine_instance=jinja_environment) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__init__.py new file mode 100644 index 0000000..70a4f75 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__init__.py @@ -0,0 +1,32 @@ +from litestar.contrib.jwt.jwt_auth import ( + BaseJWTAuth, + JWTAuth, + JWTCookieAuth, + OAuth2Login, + OAuth2PasswordBearerAuth, +) +from litestar.contrib.jwt.jwt_token import Token +from litestar.contrib.jwt.middleware import ( + JWTAuthenticationMiddleware, + JWTCookieAuthenticationMiddleware, +) +from litestar.utils import warn_deprecation + +__all__ = ( + "BaseJWTAuth", + "JWTAuth", + "JWTAuthenticationMiddleware", + "JWTCookieAuth", + "JWTCookieAuthenticationMiddleware", + "OAuth2Login", + "OAuth2PasswordBearerAuth", + "Token", +) + +warn_deprecation( + deprecated_name="litestar.contrib.jwt", + version="2.3.2", + kind="import", + removal_in="3.0", + info="importing from 'litestar.contrib.jwt' is deprecated, please import from 'litestar.security.jwt' instead", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1515862 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/jwt_auth.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/jwt_auth.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..44907f6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/jwt_auth.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/jwt_token.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/jwt_token.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2ffdb24 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/jwt_token.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..334e3fa --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/jwt_auth.py b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/jwt_auth.py new file mode 100644 index 0000000..e8ffb48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/jwt_auth.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from litestar.security.jwt.auth import BaseJWTAuth, JWTAuth, JWTCookieAuth, OAuth2Login, OAuth2PasswordBearerAuth + +__all__ = ("BaseJWTAuth", "JWTAuth", "JWTCookieAuth", "OAuth2Login", "OAuth2PasswordBearerAuth") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/jwt_token.py b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/jwt_token.py new file mode 100644 index 0000000..882373f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/jwt_token.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from litestar.security.jwt.token import Token + +__all__ = ("Token",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/jwt/middleware.py b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/middleware.py new file mode 100644 index 0000000..e0ad413 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/jwt/middleware.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from litestar.security.jwt.middleware import JWTAuthenticationMiddleware, JWTCookieAuthenticationMiddleware + +__all__ = ("JWTAuthenticationMiddleware", "JWTCookieAuthenticationMiddleware") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/mako.py b/venv/lib/python3.11/site-packages/litestar/contrib/mako.py new file mode 100644 index 0000000..9cb4c47 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/mako.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any, Mapping, TypeVar + +from typing_extensions import ParamSpec + +from litestar.exceptions import ImproperlyConfiguredException, MissingDependencyException, TemplateNotFoundException +from litestar.template.base import ( + TemplateCallableType, + TemplateEngineProtocol, + TemplateProtocol, + csrf_token, + url_for, + url_for_static_asset, +) + +try: + from mako.exceptions import TemplateLookupException as MakoTemplateNotFound # type: ignore[import-untyped] + from mako.lookup import TemplateLookup # type: ignore[import-untyped] + from mako.template import Template as _MakoTemplate # type: ignore[import-untyped] +except ImportError as e: + raise MissingDependencyException("mako") from e + +if TYPE_CHECKING: + from pathlib import Path + +__all__ = ("MakoTemplate", "MakoTemplateEngine") + +P = ParamSpec("P") +T = TypeVar("T") + + +class MakoTemplate(TemplateProtocol): + """Mako template, implementing ``TemplateProtocol``""" + + def __init__(self, template: _MakoTemplate, template_callables: list[tuple[str, TemplateCallableType]]) -> None: + """Initialize a template. + + Args: + template: Base ``MakoTemplate`` used by the underlying mako-engine + template_callables: List of callables passed to the template + """ + super().__init__() + self.template = template + self.template_callables = template_callables + + def render(self, *args: Any, **kwargs: Any) -> str: + """Render a template. + + Args: + args: Positional arguments passed to the engines ``render`` function + kwargs: Keyword arguments passed to the engines ``render`` function + + Returns: + Rendered template as a string + """ + for callable_key, template_callable in self.template_callables: + kwargs_copy = {**kwargs} + kwargs[callable_key] = partial(template_callable, kwargs_copy) + + return str(self.template.render(*args, **kwargs)) + + +class MakoTemplateEngine(TemplateEngineProtocol[MakoTemplate, Mapping[str, Any]]): + """Mako-based TemplateEngine.""" + + def __init__(self, directory: Path | list[Path] | None = None, engine_instance: Any | None = None) -> None: + """Initialize template engine. + + Args: + directory: Direct path or list of directory paths from which to serve templates. + engine_instance: A mako TemplateLookup instance. + """ + super().__init__(directory, engine_instance) + if directory and engine_instance: + raise ImproperlyConfiguredException("You must provide either a directory or a mako TemplateLookup.") + if directory: + self.engine = TemplateLookup( + directories=directory if isinstance(directory, (list, tuple)) else [directory], default_filters=["h"] + ) + elif engine_instance: + self.engine = engine_instance + + self._template_callables: list[tuple[str, TemplateCallableType]] = [] + self.register_template_callable(key="url_for_static_asset", template_callable=url_for_static_asset) + self.register_template_callable(key="csrf_token", template_callable=csrf_token) + self.register_template_callable(key="url_for", template_callable=url_for) + + def get_template(self, template_name: str) -> MakoTemplate: + """Retrieve a template by matching its name (dotted path) with files in the directory or directories provided. + + Args: + template_name: A dotted path + + Returns: + MakoTemplate instance + + Raises: + TemplateNotFoundException: if no template is found. + """ + try: + return MakoTemplate( + template=self.engine.get_template(template_name), template_callables=self._template_callables + ) + except MakoTemplateNotFound as exc: + raise TemplateNotFoundException(template_name=template_name) from exc + + def register_template_callable( + self, key: str, template_callable: TemplateCallableType[Mapping[str, Any], P, T] + ) -> None: + """Register a callable on the template engine. + + Args: + key: The callable key, i.e. the value to use inside the template to call the callable. + template_callable: A callable to register. + + Returns: + None + """ + self._template_callables.append((key, template_callable)) + + def render_string(self, template_string: str, context: Mapping[str, Any]) -> str: # pyright: ignore + """Render a template from a string with the given context. + + Args: + template_string: The template string to render. + context: A dictionary of variables to pass to the template. + + Returns: + The rendered template as a string. + """ + template = _MakoTemplate(template_string) # noqa: S702 + return template.render(**context) # type: ignore[no-any-return] + + @classmethod + def from_template_lookup(cls, template_lookup: TemplateLookup) -> MakoTemplateEngine: + """Create a template engine from an existing mako TemplateLookup instance. + + Args: + template_lookup: A mako TemplateLookup instance. + + Returns: + MakoTemplateEngine instance + """ + return cls(directory=None, engine_instance=template_lookup) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/minijinja.py b/venv/lib/python3.11/site-packages/litestar/contrib/minijinja.py new file mode 100644 index 0000000..6007a18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/minijinja.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import functools +from pathlib import Path +from typing import TYPE_CHECKING, Any, Mapping, Protocol, TypeVar, cast + +from typing_extensions import ParamSpec + +from litestar.exceptions import ImproperlyConfiguredException, MissingDependencyException, TemplateNotFoundException +from litestar.template.base import ( + TemplateCallableType, + TemplateEngineProtocol, + TemplateProtocol, + csrf_token, + url_for, + url_for_static_asset, +) +from litestar.utils.deprecation import warn_deprecation + +try: + from minijinja import Environment # type:ignore[import-untyped] + from minijinja import TemplateError as MiniJinjaTemplateNotFound +except ImportError as e: + raise MissingDependencyException("minijinja") from e + +if TYPE_CHECKING: + from typing import Callable + + C = TypeVar("C", bound="Callable") + + def pass_state(func: C) -> C: ... + +else: + from minijinja import pass_state + +__all__ = ( + "MiniJinjaTemplateEngine", + "StateProtocol", +) + +P = ParamSpec("P") +T = TypeVar("T") + + +class StateProtocol(Protocol): + auto_escape: str | None + current_block: str | None + env: Environment + name: str + + def lookup(self, key: str) -> Any | None: ... + + +def _transform_state(func: TemplateCallableType[Mapping[str, Any], P, T]) -> TemplateCallableType[StateProtocol, P, T]: + """Transform a template callable to receive a ``StateProtocol`` instance as first argument. + + This is for wrapping callables like ``url_for()`` that receive a mapping as first argument so they can be used + with minijinja which passes a ``StateProtocol`` instance as first argument. + """ + + @functools.wraps(func) + @pass_state + def wrapped(state: StateProtocol, /, *args: P.args, **kwargs: P.kwargs) -> T: + template_context = {"request": state.lookup("request"), "csrf_input": state.lookup("csrf_input")} + return func(template_context, *args, **kwargs) + + return wrapped + + +class MiniJinjaTemplate(TemplateProtocol): + """Initialize a template. + + Args: + template: Base ``MiniJinjaTemplate`` used by the underlying minijinja engine + """ + + def __init__(self, engine: Environment, template_name: str) -> None: + super().__init__() + self.engine = engine + self.template_name = template_name + + def render(self, *args: Any, **kwargs: Any) -> str: + """Render a template. + + Args: + args: Positional arguments passed to the engines ``render`` function + kwargs: Keyword arguments passed to the engines ``render`` function + + Returns: + Rendered template as a string + """ + try: + return str(self.engine.render_template(self.template_name, *args, **kwargs)) + except MiniJinjaTemplateNotFound as err: + raise TemplateNotFoundException(template_name=self.template_name) from err + + +class MiniJinjaTemplateEngine(TemplateEngineProtocol["MiniJinjaTemplate", StateProtocol]): + """The engine instance.""" + + def __init__(self, directory: Path | list[Path] | None = None, engine_instance: Environment | None = None) -> None: + """Minijinja based TemplateEngine. + + Args: + directory: Direct path or list of directory paths from which to serve templates. + engine_instance: A Minijinja Environment instance. + """ + super().__init__(directory, engine_instance) + if directory and engine_instance: + raise ImproperlyConfiguredException( + "You must provide either a directory or a minijinja Environment instance." + ) + if directory: + + def _loader(name: str) -> str: + """Load a template from a directory. + + Args: + name: The name of the template + + Returns: + The template as a string + + Raises: + TemplateNotFoundException: if no template is found. + """ + directories = directory if isinstance(directory, list) else [directory] + + for d in directories: + template_path = Path(d) / name # pyright: ignore[reportGeneralTypeIssues] + if template_path.exists(): + return template_path.read_text() + raise TemplateNotFoundException(template_name=name) + + self.engine = Environment(loader=_loader) + elif engine_instance: + self.engine = engine_instance + else: + raise ImproperlyConfiguredException( + "You must provide either a directory or a minijinja Environment instance." + ) + + self.register_template_callable("url_for", _transform_state(url_for)) + self.register_template_callable("csrf_token", _transform_state(csrf_token)) + self.register_template_callable("url_for_static_asset", _transform_state(url_for_static_asset)) + + def get_template(self, template_name: str) -> MiniJinjaTemplate: + """Retrieve a template by matching its name (dotted path) with files in the directory or directories provided. + + Args: + template_name: A dotted path + + Returns: + MiniJinjaTemplate instance + + Raises: + TemplateNotFoundException: if no template is found. + """ + return MiniJinjaTemplate(self.engine, template_name) + + def register_template_callable( + self, key: str, template_callable: TemplateCallableType[StateProtocol, P, T] + ) -> None: + """Register a callable on the template engine. + + Args: + key: The callable key, i.e. the value to use inside the template to call the callable. + template_callable: A callable to register. + + Returns: + None + """ + self.engine.add_global(key, pass_state(template_callable)) + + def render_string(self, template_string: str, context: Mapping[str, Any]) -> str: + """Render a template from a string with the given context. + + Args: + template_string: The template string to render. + context: A dictionary of variables to pass to the template. + + Returns: + The rendered template as a string. + """ + return self.engine.render_str(template_string, **context) # type: ignore[no-any-return] + + @classmethod + def from_environment(cls, minijinja_environment: Environment) -> MiniJinjaTemplateEngine: + """Create a MiniJinjaTemplateEngine from an existing minijinja Environment instance. + + Args: + minijinja_environment (Environment): A minijinja Environment instance. + + Returns: + MiniJinjaTemplateEngine instance + """ + return cls(directory=None, engine_instance=minijinja_environment) + + +@pass_state +def _minijinja_from_state(func: Callable, state: StateProtocol, *args: Any, **kwargs: Any) -> str: # pragma: no cover + template_context = {"request": state.lookup("request"), "csrf_input": state.lookup("csrf_input")} + return cast(str, func(template_context, *args, **kwargs)) + + +def __getattr__(name: str) -> Any: + if name == "minijinja_from_state": + warn_deprecation( + "2.3.0", + "minijinja_from_state", + "import", + removal_in="3.0.0", + alternative="Use a callable that receives the minijinja State object as first argument.", + ) + return _minijinja_from_state + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/minijnja.py b/venv/lib/python3.11/site-packages/litestar/contrib/minijnja.py new file mode 100644 index 0000000..13c295a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/minijnja.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Any + +from litestar.utils.deprecation import warn_deprecation + +from . import minijinja as _minijinja + + +def __getattr__(name: str) -> Any: + warn_deprecation( + "2.3.0", + "contrib.minijnja", + "import", + removal_in="3.0.0", + alternative="contrib.minijinja", + ) + return getattr(_minijinja, name) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__init__.py new file mode 100644 index 0000000..3f93611 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__init__.py @@ -0,0 +1,4 @@ +from .config import OpenTelemetryConfig +from .middleware import OpenTelemetryInstrumentationMiddleware + +__all__ = ("OpenTelemetryConfig", "OpenTelemetryInstrumentationMiddleware") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..799a915 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c401f9c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b41bc7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..291b510 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/_utils.py b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/_utils.py new file mode 100644 index 0000000..0ba7cb9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/_utils.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.exceptions import MissingDependencyException + +__all__ = ("get_route_details_from_scope",) + + +try: + import opentelemetry # noqa: F401 +except ImportError as e: + raise MissingDependencyException("opentelemetry") from e + +from opentelemetry.semconv.trace import SpanAttributes + +if TYPE_CHECKING: + from litestar.types import Scope + + +def get_route_details_from_scope(scope: Scope) -> tuple[str, dict[Any, str]]: + """Retrieve the span name and attributes from the ASGI scope. + + Args: + scope: The ASGI scope instance. + + Returns: + A tuple of the span name and a dict of attrs. + """ + route_handler_fn_name = scope["route_handler"].handler_name + return route_handler_fn_name, {SpanAttributes.HTTP_ROUTE: route_handler_fn_name} diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/config.py b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/config.py new file mode 100644 index 0000000..c0cce8a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/config.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable + +from litestar.contrib.opentelemetry._utils import get_route_details_from_scope +from litestar.contrib.opentelemetry.middleware import ( + OpenTelemetryInstrumentationMiddleware, +) +from litestar.exceptions import MissingDependencyException +from litestar.middleware.base import DefineMiddleware + +__all__ = ("OpenTelemetryConfig",) + + +try: + import opentelemetry # noqa: F401 +except ImportError as e: + raise MissingDependencyException("opentelemetry") from e + + +from opentelemetry.trace import Span, TracerProvider # pyright: ignore + +if TYPE_CHECKING: + from opentelemetry.metrics import Meter, MeterProvider + + from litestar.types import Scope, Scopes + +OpenTelemetryHookHandler = Callable[[Span, dict], None] + + +@dataclass +class OpenTelemetryConfig: + """Configuration class for the OpenTelemetry middleware. + + Consult the [OpenTelemetry ASGI documentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/asgi/asgi.html) for more info about the configuration options. + """ + + scope_span_details_extractor: Callable[[Scope], tuple[str, dict[str, Any]]] = field( + default=get_route_details_from_scope + ) + """Callback which should return a string and a tuple, representing the desired default span name and a dictionary + with any additional span attributes to set. + """ + server_request_hook_handler: OpenTelemetryHookHandler | None = field(default=None) + """Optional callback which is called with the server span and ASGI scope object for every incoming request.""" + client_request_hook_handler: OpenTelemetryHookHandler | None = field(default=None) + """Optional callback which is called with the internal span and an ASGI scope which is sent as a dictionary for when + the method receive is called. + """ + client_response_hook_handler: OpenTelemetryHookHandler | None = field(default=None) + """Optional callback which is called with the internal span and an ASGI event which is sent as a dictionary for when + the method send is called. + """ + meter_provider: MeterProvider | None = field(default=None) + """Optional meter provider to use. + + If omitted the current globally configured one is used. + """ + tracer_provider: TracerProvider | None = field(default=None) + """Optional tracer provider to use. + + If omitted the current globally configured one is used. + """ + meter: Meter | None = field(default=None) + """Optional meter to use. + + If omitted the provided meter provider or the global one will be used. + """ + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the Allowed Hosts middleware.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable hosts check for a particular route.""" + exclude_urls_env_key: str = "LITESTAR" + """Key to use when checking whether a list of excluded urls is passed via ENV. + + OpenTelemetry supports excluding urls by passing an env in the format '{exclude_urls_env_key}_EXCLUDED_URLS'. With + the default being ``LITESTAR_EXCLUDED_URLS``. + """ + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the middleware, if None both ``http`` and ``websocket`` will be processed.""" + middleware_class: type[OpenTelemetryInstrumentationMiddleware] = field( + default=OpenTelemetryInstrumentationMiddleware + ) + """The middleware class to use. + + Should be a subclass of OpenTelemetry + InstrumentationMiddleware][litestar.contrib.opentelemetry.OpenTelemetryInstrumentationMiddleware]. + """ + + @property + def middleware(self) -> DefineMiddleware: + """Create an instance of :class:`DefineMiddleware <litestar.middleware.base.DefineMiddleware>` that wraps with. + + [OpenTelemetry + InstrumentationMiddleware][litestar.contrib.opentelemetry.OpenTelemetryInstrumentationMiddleware] or a subclass + of this middleware. + + Returns: + An instance of ``DefineMiddleware``. + """ + return DefineMiddleware(self.middleware_class, config=self) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/middleware.py b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/middleware.py new file mode 100644 index 0000000..762bae9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/opentelemetry/middleware.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.exceptions import MissingDependencyException +from litestar.middleware.base import AbstractMiddleware + +__all__ = ("OpenTelemetryInstrumentationMiddleware",) + + +try: + import opentelemetry # noqa: F401 +except ImportError as e: + raise MissingDependencyException("opentelemetry") from e + +from opentelemetry.instrumentation.asgi import OpenTelemetryMiddleware +from opentelemetry.util.http import get_excluded_urls + +if TYPE_CHECKING: + from litestar.contrib.opentelemetry import OpenTelemetryConfig + from litestar.types import ASGIApp, Receive, Scope, Send + + +class OpenTelemetryInstrumentationMiddleware(AbstractMiddleware): + """OpenTelemetry Middleware.""" + + __slots__ = ("open_telemetry_middleware",) + + def __init__(self, app: ASGIApp, config: OpenTelemetryConfig) -> None: + """Middleware that adds OpenTelemetry instrumentation to the application. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of :class:`OpenTelemetryConfig <.contrib.opentelemetry.OpenTelemetryConfig>` + """ + super().__init__(app=app, scopes=config.scopes, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key) + self.open_telemetry_middleware = OpenTelemetryMiddleware( + app=app, + client_request_hook=config.client_request_hook_handler, + client_response_hook=config.client_response_hook_handler, + default_span_details=config.scope_span_details_extractor, + excluded_urls=get_excluded_urls(config.exclude_urls_env_key), + meter=config.meter, + meter_provider=config.meter_provider, + server_request_hook=config.server_request_hook_handler, + tracer_provider=config.tracer_provider, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + await self.open_telemetry_middleware(scope, receive, send) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/piccolo.py b/venv/lib/python3.11/site-packages/litestar/contrib/piccolo.py new file mode 100644 index 0000000..73bd271 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/piccolo.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import warnings +from dataclasses import replace +from decimal import Decimal +from typing import Any, Generator, Generic, List, Optional, TypeVar + +from msgspec import Meta +from typing_extensions import Annotated + +from litestar.dto import AbstractDTO, DTOField, Mark +from litestar.dto.data_structures import DTOFieldDefinition +from litestar.exceptions import LitestarWarning, MissingDependencyException +from litestar.types import Empty +from litestar.typing import FieldDefinition +from litestar.utils import warn_deprecation + +try: + from piccolo.columns import Column, column_types + from piccolo.table import Table +except ImportError as e: + raise MissingDependencyException("piccolo") from e + + +T = TypeVar("T", bound=Table) + +__all__ = ("PiccoloDTO",) + + +def __getattr__(name: str) -> Any: + warn_deprecation( + deprecated_name=f"litestar.contrib.piccolo.{name}", + version="2.3.2", + kind="import", + removal_in="3.0.0", + info="importing from 'litestar.contrib.piccolo' is deprecated and will be removed in 3.0, please import from 'litestar_piccolo' package directly instead", + ) + return getattr(name, name) + + +def _parse_piccolo_type(column: Column, extra: dict[str, Any]) -> FieldDefinition: + is_optional = not column._meta.required + + if isinstance(column, (column_types.Decimal, column_types.Numeric)): + column_type: Any = Decimal + meta = Meta(extra=extra) + elif isinstance(column, (column_types.Email, column_types.Varchar)): + column_type = str + if is_optional: + meta = Meta(extra=extra) + warnings.warn( + f"Dropping max_length constraint for column {column!r} because the " "column is optional", + category=LitestarWarning, + stacklevel=2, + ) + else: + meta = Meta(max_length=column.length, extra=extra) + elif isinstance(column, column_types.Array): + column_type = List[column.base_column.value_type] # type: ignore[name-defined] + meta = Meta(extra=extra) + elif isinstance(column, (column_types.JSON, column_types.JSONB)): + column_type = str + meta = Meta(extra={**extra, "format": "json"}) + elif isinstance(column, column_types.Text): + column_type = str + meta = Meta(extra={**extra, "format": "text-area"}) + else: + column_type = column.value_type + meta = Meta(extra=extra) + + if is_optional: + column_type = Optional[column_type] + + return FieldDefinition.from_annotation(Annotated[column_type, meta]) + + +def _create_column_extra(column: Column) -> dict[str, Any]: + extra: dict[str, Any] = {} + + if column._meta.help_text: + extra["description"] = column._meta.help_text + + if column._meta.get_choices_dict(): + extra["enum"] = column._meta.get_choices_dict() + + return extra + + +class PiccoloDTO(AbstractDTO[T], Generic[T]): + @classmethod + def generate_field_definitions(cls, model_type: type[Table]) -> Generator[DTOFieldDefinition, None, None]: + for column in model_type._meta.columns: + mark = Mark.WRITE_ONLY if column._meta.secret else Mark.READ_ONLY if column._meta.primary_key else None + yield replace( + DTOFieldDefinition.from_field_definition( + field_definition=_parse_piccolo_type(column, _create_column_extra(column)), + dto_field=DTOField(mark=mark), + model_name=model_type.__name__, + default_factory=None, + ), + default=Empty if column._meta.required else None, + name=column._meta.name, + ) + + @classmethod + def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: + return field_definition.is_subclass_of(Table) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__init__.py new file mode 100644 index 0000000..1ccb494 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__init__.py @@ -0,0 +1,5 @@ +from .config import PrometheusConfig +from .controller import PrometheusController +from .middleware import PrometheusMiddleware + +__all__ = ("PrometheusMiddleware", "PrometheusConfig", "PrometheusController") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c6c5558 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a998104 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/controller.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/controller.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..012b4be --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/controller.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2c08508 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/config.py b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/config.py new file mode 100644 index 0000000..b77dab0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/config.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, Mapping, Sequence + +from litestar.contrib.prometheus.middleware import ( + PrometheusMiddleware, +) +from litestar.exceptions import MissingDependencyException +from litestar.middleware.base import DefineMiddleware + +__all__ = ("PrometheusConfig",) + + +try: + import prometheus_client # noqa: F401 +except ImportError as e: + raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e + + +if TYPE_CHECKING: + from litestar.connection.request import Request + from litestar.types import Method, Scopes + + +@dataclass +class PrometheusConfig: + """Configuration class for the PrometheusConfig middleware.""" + + app_name: str = field(default="litestar") + """The name of the application to use in the metrics.""" + prefix: str = "litestar" + """The prefix to use for the metrics.""" + labels: Mapping[str, str | Callable] | None = field(default=None) + """A mapping of labels to add to the metrics. The values can be either a string or a callable that returns a string.""" + exemplars: Callable[[Request], dict] | None = field(default=None) + """A callable that returns a list of exemplars to add to the metrics. Only supported in opementrics-text exposition format.""" + buckets: list[str | float] | None = field(default=None) + """A list of buckets to use for the histogram.""" + excluded_http_methods: Method | Sequence[Method] | None = field(default=None) + """A list of http methods to exclude from the metrics.""" + exclude_unhandled_paths: bool = field(default=False) + """Whether to ignore requests for unhandled paths from the metrics.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns for routes to exclude from the metrics.""" + exclude_opt_key: str | None = field(default=None) + """A key or list of keys in ``opt`` with which a route handler can "opt-out" of the middleware.""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the middleware, if None both ``http`` and ``websocket`` will be processed.""" + middleware_class: type[PrometheusMiddleware] = field(default=PrometheusMiddleware) + """The middleware class to use. + """ + + @property + def middleware(self) -> DefineMiddleware: + """Create an instance of :class:`DefineMiddleware <litestar.middleware.base.DefineMiddleware>` that wraps with. + + [PrometheusMiddleware][litestar.contrib.prometheus.PrometheusMiddleware]. or a subclass + of this middleware. + + Returns: + An instance of ``DefineMiddleware``. + """ + return DefineMiddleware(self.middleware_class, config=self) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/controller.py b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/controller.py new file mode 100644 index 0000000..15f5bf1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/controller.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import os + +from litestar import Controller, get +from litestar.exceptions import MissingDependencyException +from litestar.response import Response + +try: + import prometheus_client # noqa: F401 +except ImportError as e: + raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e + +from prometheus_client import ( + CONTENT_TYPE_LATEST, + REGISTRY, + CollectorRegistry, + generate_latest, + multiprocess, +) +from prometheus_client.openmetrics.exposition import ( + CONTENT_TYPE_LATEST as OPENMETRICS_CONTENT_TYPE_LATEST, +) +from prometheus_client.openmetrics.exposition import ( + generate_latest as openmetrics_generate_latest, +) + +__all__ = [ + "PrometheusController", +] + + +class PrometheusController(Controller): + """Controller for Prometheus endpoints.""" + + path: str = "/metrics" + """The path to expose the metrics on.""" + openmetrics_format: bool = False + """Whether to expose the metrics in OpenMetrics format.""" + + @get() + async def get(self) -> Response: + registry = REGISTRY + if "prometheus_multiproc_dir" in os.environ or "PROMETHEUS_MULTIPROC_DIR" in os.environ: + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call] + + if self.openmetrics_format: + headers = {"Content-Type": OPENMETRICS_CONTENT_TYPE_LATEST} + return Response(openmetrics_generate_latest(registry), status_code=200, headers=headers) # type: ignore[no-untyped-call] + + headers = {"Content-Type": CONTENT_TYPE_LATEST} + return Response(generate_latest(registry), status_code=200, headers=headers) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/middleware.py b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/middleware.py new file mode 100644 index 0000000..50bc7cb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/middleware.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import time +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast + +from litestar.connection.request import Request +from litestar.enums import ScopeType +from litestar.exceptions import MissingDependencyException +from litestar.middleware.base import AbstractMiddleware + +__all__ = ("PrometheusMiddleware",) + +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR + +try: + import prometheus_client # noqa: F401 +except ImportError as e: + raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e + +from prometheus_client import Counter, Gauge, Histogram + +if TYPE_CHECKING: + from prometheus_client.metrics import MetricWrapperBase + + from litestar.contrib.prometheus import PrometheusConfig + from litestar.types import ASGIApp, Message, Receive, Scope, Send + + +class PrometheusMiddleware(AbstractMiddleware): + """Prometheus Middleware.""" + + _metrics: ClassVar[dict[str, MetricWrapperBase]] = {} + + def __init__(self, app: ASGIApp, config: PrometheusConfig) -> None: + """Middleware that adds Prometheus instrumentation to the application. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of :class:`PrometheusConfig <.contrib.prometheus.PrometheusConfig>` + """ + super().__init__(app=app, scopes=config.scopes, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key) + self._config = config + self._kwargs: dict[str, Any] = {} + + if self._config.buckets is not None: + self._kwargs["buckets"] = self._config.buckets + + def request_count(self, labels: dict[str, str | int | float]) -> Counter: + metric_name = f"{self._config.prefix}_requests_total" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Counter( + name=metric_name, + documentation="Total requests", + labelnames=[*labels.keys()], + ) + + return cast("Counter", PrometheusMiddleware._metrics[metric_name]) + + def request_time(self, labels: dict[str, str | int | float]) -> Histogram: + metric_name = f"{self._config.prefix}_request_duration_seconds" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Histogram( + name=metric_name, + documentation="Request duration, in seconds", + labelnames=[*labels.keys()], + **self._kwargs, + ) + return cast("Histogram", PrometheusMiddleware._metrics[metric_name]) + + def requests_in_progress(self, labels: dict[str, str | int | float]) -> Gauge: + metric_name = f"{self._config.prefix}_requests_in_progress" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Gauge( + name=metric_name, + documentation="Total requests currently in progress", + labelnames=[*labels.keys()], + multiprocess_mode="livesum", + ) + return cast("Gauge", PrometheusMiddleware._metrics[metric_name]) + + def requests_error_count(self, labels: dict[str, str | int | float]) -> Counter: + metric_name = f"{self._config.prefix}_requests_error_total" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Counter( + name=metric_name, + documentation="Total errors in requests", + labelnames=[*labels.keys()], + ) + return cast("Counter", PrometheusMiddleware._metrics[metric_name]) + + def _get_extra_labels(self, request: Request[Any, Any, Any]) -> dict[str, str]: + """Get extra labels provided by the config and if they are callable, parse them. + + Args: + request: The request object. + + Returns: + A dictionary of extra labels. + """ + + return {k: str(v(request) if callable(v) else v) for k, v in (self._config.labels or {}).items()} + + def _get_default_labels(self, request: Request[Any, Any, Any]) -> dict[str, str | int | float]: + """Get default label values from the request. + + Args: + request: The request object. + + Returns: + A dictionary of default labels. + """ + + return { + "method": request.method if request.scope["type"] == ScopeType.HTTP else request.scope["type"], + "path": request.url.path, + "status_code": 200, + "app_name": self._config.app_name, + } + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + request = Request[Any, Any, Any](scope, receive) + + if self._config.excluded_http_methods and request.method in self._config.excluded_http_methods: + await self.app(scope, receive, send) + return + + labels = {**self._get_default_labels(request), **self._get_extra_labels(request)} + + request_span = {"start_time": time.perf_counter(), "end_time": 0, "duration": 0, "status_code": 200} + + wrapped_send = self._get_wrapped_send(send, request_span) + + self.requests_in_progress(labels).labels(*labels.values()).inc() + + try: + await self.app(scope, receive, wrapped_send) + finally: + extra: dict[str, Any] = {} + if self._config.exemplars: + extra["exemplar"] = self._config.exemplars(request) + + self.requests_in_progress(labels).labels(*labels.values()).dec() + + labels["status_code"] = request_span["status_code"] + label_values = [*labels.values()] + + if request_span["status_code"] >= HTTP_500_INTERNAL_SERVER_ERROR: + self.requests_error_count(labels).labels(*label_values).inc(**extra) + + self.request_count(labels).labels(*label_values).inc(**extra) + self.request_time(labels).labels(*label_values).observe(request_span["duration"], **extra) + + def _get_wrapped_send(self, send: Send, request_span: dict[str, float]) -> Callable: + @wraps(send) + async def wrapped_send(message: Message) -> None: + if message["type"] == "http.response.start": + request_span["status_code"] = message["status"] + + if message["type"] == "http.response.body": + end = time.perf_counter() + request_span["duration"] = end - request_span["start_time"] + request_span["end_time"] = end + await send(message) + + return wrapped_send diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__init__.py new file mode 100644 index 0000000..9bab707 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__init__.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.plugins import InitPluginProtocol + +from .pydantic_di_plugin import PydanticDIPlugin +from .pydantic_dto_factory import PydanticDTO +from .pydantic_init_plugin import PydanticInitPlugin +from .pydantic_schema_plugin import PydanticSchemaPlugin + +if TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.v1 import BaseModel as BaseModelV1 + + from litestar.config.app import AppConfig + +__all__ = ( + "PydanticDTO", + "PydanticInitPlugin", + "PydanticSchemaPlugin", + "PydanticPlugin", + "PydanticDIPlugin", +) + + +def _model_dump(model: BaseModel | BaseModelV1, *, by_alias: bool = False) -> dict[str, Any]: + return ( + model.model_dump(mode="json", by_alias=by_alias) # pyright: ignore + if hasattr(model, "model_dump") + else {k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=by_alias).items()} + ) + + +def _model_dump_json(model: BaseModel | BaseModelV1, by_alias: bool = False) -> str: + return ( + model.model_dump_json(by_alias=by_alias) # pyright: ignore + if hasattr(model, "model_dump_json") + else model.json(by_alias=by_alias) # pyright: ignore + ) + + +class PydanticPlugin(InitPluginProtocol): + """A plugin that provides Pydantic integration.""" + + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + """Initialize ``PydanticPlugin``. + + Args: + prefer_alias: OpenAPI and ``type_encoders`` will export by alias + """ + self.prefer_alias = prefer_alias + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Configure application for use with Pydantic. + + Args: + app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. + """ + app_config.plugins.extend( + [ + PydanticInitPlugin(prefer_alias=self.prefer_alias), + PydanticSchemaPlugin(prefer_alias=self.prefer_alias), + PydanticDIPlugin(), + ] + ) + return app_config diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a62eb0f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5515df6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_di_plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_di_plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f8c4e00 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_di_plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_dto_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_dto_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..490108e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_dto_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_init_plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_init_plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..13788d4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_init_plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_schema_plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_schema_plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8b0946b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/pydantic_schema_plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..dbb0e54 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/config.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/config.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/config.py diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_di_plugin.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_di_plugin.py new file mode 100644 index 0000000..2096fd4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_di_plugin.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import inspect +from inspect import Signature +from typing import Any + +from litestar.contrib.pydantic.utils import is_pydantic_model_class +from litestar.plugins import DIPlugin + + +class PydanticDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return is_pydantic_model_class(type_) + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + try: + model_fields = dict(type_.model_fields) + except AttributeError: + model_fields = {k: model_field.field_info for k, model_field in type_.__fields__.items()} + + parameters = [ + inspect.Parameter(name=field_name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=Any) + for field_name in model_fields + ] + type_hints = {field_name: Any for field_name in model_fields} + return Signature(parameters), type_hints diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_dto_factory.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_dto_factory.py new file mode 100644 index 0000000..d61f95d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_dto_factory.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Collection, Generic, TypeVar + +from typing_extensions import TypeAlias, override + +from litestar.contrib.pydantic.utils import is_pydantic_undefined +from litestar.dto.base_dto import AbstractDTO +from litestar.dto.data_structures import DTOFieldDefinition +from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.exceptions import MissingDependencyException, ValidationException +from litestar.types.empty import Empty + +if TYPE_CHECKING: + from typing import Any, Generator + + from litestar.typing import FieldDefinition + +try: + import pydantic as _ # noqa: F401 +except ImportError as e: + raise MissingDependencyException("pydantic") from e + + +try: + import pydantic as pydantic_v2 + from pydantic import ValidationError as ValidationErrorV2 + from pydantic import v1 as pydantic_v1 + from pydantic.v1 import ValidationError as ValidationErrorV1 + + ModelType: TypeAlias = "pydantic_v1.BaseModel | pydantic_v2.BaseModel" + +except ImportError: + import pydantic as pydantic_v1 # type: ignore[no-redef] + + pydantic_v2 = Empty # type: ignore[assignment] + from pydantic import ValidationError as ValidationErrorV1 # type: ignore[assignment] + + ValidationErrorV2 = ValidationErrorV1 # type: ignore[assignment, misc] + ModelType = "pydantic_v1.BaseModel" # type: ignore[misc] + + +T = TypeVar("T", bound="ModelType | Collection[ModelType]") + + +__all__ = ("PydanticDTO",) + + +class PydanticDTO(AbstractDTO[T], Generic[T]): + """Support for domain modelling with Pydantic.""" + + @override + def decode_builtins(self, value: dict[str, Any]) -> Any: + try: + return super().decode_builtins(value) + except (ValidationErrorV2, ValidationErrorV1) as ex: + raise ValidationException(extra=ex.errors()) from ex + + @override + def decode_bytes(self, value: bytes) -> Any: + try: + return super().decode_bytes(value) + except (ValidationErrorV2, ValidationErrorV1) as ex: + raise ValidationException(extra=ex.errors()) from ex + + @classmethod + def generate_field_definitions( + cls, model_type: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel] + ) -> Generator[DTOFieldDefinition, None, None]: + model_field_definitions = cls.get_model_type_hints(model_type) + + model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] + try: + model_fields = dict(model_type.model_fields) # type: ignore[union-attr] + except AttributeError: + model_fields = { + k: model_field.field_info + for k, model_field in model_type.__fields__.items() # type: ignore[union-attr] + } + + for field_name, field_info in model_fields.items(): + field_definition = model_field_definitions[field_name] + dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField()) + + if not is_pydantic_undefined(field_info.default): + default = field_info.default + elif field_definition.is_optional: + default = None + else: + default = Empty + + yield replace( + DTOFieldDefinition.from_field_definition( + field_definition=field_definition, + dto_field=dto_field, + model_name=model_type.__name__, + default_factory=field_info.default_factory + if field_info.default_factory and not is_pydantic_undefined(field_info.default_factory) + else None, + ), + default=default, + name=field_name, + ) + + @classmethod + def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: + if pydantic_v2 is not Empty: # type: ignore[comparison-overlap] + return field_definition.is_subclass_of((pydantic_v1.BaseModel, pydantic_v2.BaseModel)) + return field_definition.is_subclass_of(pydantic_v1.BaseModel) # type: ignore[unreachable] diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_init_plugin.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_init_plugin.py new file mode 100644 index 0000000..1261cd8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_init_plugin.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast +from uuid import UUID + +from msgspec import ValidationError +from typing_extensions import Buffer, TypeGuard + +from litestar._signature.types import ExtendedMsgSpecValidationError +from litestar.contrib.pydantic.utils import is_pydantic_constrained_field +from litestar.exceptions import MissingDependencyException +from litestar.plugins import InitPluginProtocol +from litestar.typing import _KWARG_META_EXTRACTORS +from litestar.utils import is_class_and_subclass + +try: + # check if we have pydantic v2 installed, and try to import both versions + import pydantic as pydantic_v2 + from pydantic import v1 as pydantic_v1 +except ImportError: + # check if pydantic 1 is installed and import it + try: + import pydantic as pydantic_v1 # type: ignore[no-redef] + + pydantic_v2 = None # type: ignore[assignment] + except ImportError as e: + raise MissingDependencyException("pydantic") from e + + +if TYPE_CHECKING: + from litestar.config.app import AppConfig + + +T = TypeVar("T") + + +def _dec_pydantic_v1(model_type: type[pydantic_v1.BaseModel], value: Any) -> pydantic_v1.BaseModel: + try: + return model_type.parse_obj(value) + except pydantic_v1.ValidationError as e: + raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e + + +def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any) -> pydantic_v2.BaseModel: + try: + return model_type.model_validate(value, strict=False) + except pydantic_v2.ValidationError as e: + raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e + + +def _dec_pydantic_uuid( + uuid_type: type[pydantic_v1.UUID1] | type[pydantic_v1.UUID3] | type[pydantic_v1.UUID4] | type[pydantic_v1.UUID5], + value: Any, +) -> ( + type[pydantic_v1.UUID1] | type[pydantic_v1.UUID3] | type[pydantic_v1.UUID4] | type[pydantic_v1.UUID5] +): # pragma: no cover + if isinstance(value, str): + value = uuid_type(value) + + elif isinstance(value, Buffer): + value = bytes(value) + try: + value = uuid_type(value.decode()) + except ValueError: + # 16 bytes in big-endian order as the bytes argument fail + # the above check + value = uuid_type(bytes=value) + elif isinstance(value, UUID): + value = uuid_type(str(value)) + + if not isinstance(value, uuid_type): + raise ValidationError(f"Invalid UUID: {value!r}") + + if value._required_version != value.version: + raise ValidationError(f"Invalid UUID version: {value!r}") + + return cast( + "type[pydantic_v1.UUID1] | type[pydantic_v1.UUID3] | type[pydantic_v1.UUID4] | type[pydantic_v1.UUID5]", value + ) + + +def _is_pydantic_v1_uuid(value: Any) -> bool: # pragma: no cover + return is_class_and_subclass(value, (pydantic_v1.UUID1, pydantic_v1.UUID3, pydantic_v1.UUID4, pydantic_v1.UUID5)) + + +_base_encoders: dict[Any, Callable[[Any], Any]] = { + pydantic_v1.EmailStr: str, + pydantic_v1.NameEmail: str, + pydantic_v1.ByteSize: lambda val: val.real, +} + +if pydantic_v2 is not None: # pragma: no cover + _base_encoders.update( + { + pydantic_v2.EmailStr: str, + pydantic_v2.NameEmail: str, + pydantic_v2.ByteSize: lambda val: val.real, + } + ) + + +def is_pydantic_v1_model_class(annotation: Any) -> TypeGuard[type[pydantic_v1.BaseModel]]: + return is_class_and_subclass(annotation, pydantic_v1.BaseModel) + + +def is_pydantic_v2_model_class(annotation: Any) -> TypeGuard[type[pydantic_v2.BaseModel]]: + return is_class_and_subclass(annotation, pydantic_v2.BaseModel) + + +class ConstrainedFieldMetaExtractor: + @staticmethod + def matches(annotation: Any, name: str | None, default: Any) -> bool: + return is_pydantic_constrained_field(annotation) + + @staticmethod + def extract(annotation: Any, default: Any) -> Any: + return [annotation] + + +class PydanticInitPlugin(InitPluginProtocol): + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + self.prefer_alias = prefer_alias + + @classmethod + def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: + encoders = {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)} + if pydantic_v2 is not None: # pragma: no cover + encoders.update(cls._create_pydantic_v2_encoders(prefer_alias)) + return encoders + + @classmethod + def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]: + decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [ + (is_pydantic_v1_model_class, _dec_pydantic_v1) + ] + + if pydantic_v2 is not None: # pragma: no cover + decoders.append((is_pydantic_v2_model_class, _dec_pydantic_v2)) + + decoders.append((_is_pydantic_v1_uuid, _dec_pydantic_uuid)) + + return decoders + + @staticmethod + def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover + return { + pydantic_v1.BaseModel: lambda model: { + k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items() + }, + pydantic_v1.SecretField: str, + pydantic_v1.StrictBool: int, + pydantic_v1.color.Color: str, + pydantic_v1.ConstrainedBytes: lambda val: val.decode("utf-8"), + pydantic_v1.ConstrainedDate: lambda val: val.isoformat(), + pydantic_v1.AnyUrl: str, + } + + @staticmethod + def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: + encoders: dict[Any, Callable[[Any], Any]] = { + pydantic_v2.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias), + pydantic_v2.types.SecretStr: lambda val: "**********" if val else "", + pydantic_v2.types.SecretBytes: lambda val: "**********" if val else "", + pydantic_v2.AnyUrl: str, + } + + with suppress(ImportError): + from pydantic_extra_types import color + + encoders[color.Color] = str + + return encoders + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})} + app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])] + + _KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor) + return app_config diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_schema_plugin.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_schema_plugin.py new file mode 100644 index 0000000..2c189e4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/pydantic_schema_plugin.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from typing_extensions import Annotated + +from litestar.contrib.pydantic.utils import ( + create_field_definitions_for_computed_fields, + is_pydantic_2_model, + is_pydantic_constrained_field, + is_pydantic_model_class, + is_pydantic_undefined, + pydantic_get_type_hints_with_generics_resolved, + pydantic_unwrap_and_get_origin, +) +from litestar.exceptions import MissingDependencyException +from litestar.openapi.spec import OpenAPIFormat, OpenAPIType, Schema +from litestar.plugins import OpenAPISchemaPlugin +from litestar.types import Empty +from litestar.typing import FieldDefinition +from litestar.utils import is_class_and_subclass, is_generic + +try: + # check if we have pydantic v2 installed, and try to import both versions + import pydantic as pydantic_v2 + from pydantic import v1 as pydantic_v1 +except ImportError: + # check if pydantic 1 is installed and import it + try: + import pydantic as pydantic_v1 # type: ignore[no-redef] + + pydantic_v2 = None # type: ignore[assignment] + except ImportError as e: + raise MissingDependencyException("pydantic") from e + +if TYPE_CHECKING: + from litestar._openapi.schema_generation.schema import SchemaCreator + +PYDANTIC_TYPE_MAP: dict[type[Any] | None | Any, Schema] = { + pydantic_v1.ByteSize: Schema(type=OpenAPIType.INTEGER), + pydantic_v1.EmailStr: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL), + pydantic_v1.IPvAnyAddress: Schema( + one_of=[ + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV4, + description="IPv4 address", + ), + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV6, + description="IPv6 address", + ), + ] + ), + pydantic_v1.IPvAnyInterface: Schema( + one_of=[ + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV4, + description="IPv4 interface", + ), + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV6, + description="IPv6 interface", + ), + ] + ), + pydantic_v1.IPvAnyNetwork: Schema( + one_of=[ + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV4, + description="IPv4 network", + ), + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV6, + description="IPv6 network", + ), + ] + ), + pydantic_v1.Json: Schema(type=OpenAPIType.OBJECT, format=OpenAPIFormat.JSON_POINTER), + pydantic_v1.NameEmail: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL, description="Name and email"), + # removed in v2 + pydantic_v1.PyObject: Schema( + type=OpenAPIType.STRING, + description="dot separated path identifying a python object, e.g. 'decimal.Decimal'", + ), + # annotated in v2 + pydantic_v1.UUID1: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID1 string", + ), + pydantic_v1.UUID3: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID3 string", + ), + pydantic_v1.UUID4: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID4 string", + ), + pydantic_v1.UUID5: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID5 string", + ), + pydantic_v1.DirectoryPath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE), + pydantic_v1.AnyUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL), + pydantic_v1.AnyHttpUrl: Schema( + type=OpenAPIType.STRING, format=OpenAPIFormat.URL, description="must be a valid HTTP based URL" + ), + pydantic_v1.FilePath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE), + pydantic_v1.HttpUrl: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.URL, + description="must be a valid HTTP based URL", + max_length=2083, + ), + pydantic_v1.RedisDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="redis DSN"), + pydantic_v1.PostgresDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="postgres DSN"), + pydantic_v1.SecretBytes: Schema(type=OpenAPIType.STRING), + pydantic_v1.SecretStr: Schema(type=OpenAPIType.STRING), + pydantic_v1.StrictBool: Schema(type=OpenAPIType.BOOLEAN), + pydantic_v1.StrictBytes: Schema(type=OpenAPIType.STRING), + pydantic_v1.StrictFloat: Schema(type=OpenAPIType.NUMBER), + pydantic_v1.StrictInt: Schema(type=OpenAPIType.INTEGER), + pydantic_v1.StrictStr: Schema(type=OpenAPIType.STRING), + pydantic_v1.NegativeFloat: Schema(type=OpenAPIType.NUMBER, exclusive_maximum=0.0), + pydantic_v1.NegativeInt: Schema(type=OpenAPIType.INTEGER, exclusive_maximum=0), + pydantic_v1.NonNegativeInt: Schema(type=OpenAPIType.INTEGER, minimum=0), + pydantic_v1.NonPositiveFloat: Schema(type=OpenAPIType.NUMBER, maximum=0.0), + pydantic_v1.PaymentCardNumber: Schema(type=OpenAPIType.STRING, min_length=12, max_length=19), + pydantic_v1.PositiveFloat: Schema(type=OpenAPIType.NUMBER, exclusive_minimum=0.0), + pydantic_v1.PositiveInt: Schema(type=OpenAPIType.INTEGER, exclusive_minimum=0), +} + +if pydantic_v2 is not None: # pragma: no cover + PYDANTIC_TYPE_MAP.update( + { + pydantic_v2.SecretStr: Schema(type=OpenAPIType.STRING), + pydantic_v2.SecretBytes: Schema(type=OpenAPIType.STRING), + pydantic_v2.ByteSize: Schema(type=OpenAPIType.INTEGER), + pydantic_v2.EmailStr: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL), + pydantic_v2.IPvAnyAddress: Schema( + one_of=[ + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV4, + description="IPv4 address", + ), + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV6, + description="IPv6 address", + ), + ] + ), + pydantic_v2.IPvAnyInterface: Schema( + one_of=[ + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV4, + description="IPv4 interface", + ), + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV6, + description="IPv6 interface", + ), + ] + ), + pydantic_v2.IPvAnyNetwork: Schema( + one_of=[ + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV4, + description="IPv4 network", + ), + Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.IPV6, + description="IPv6 network", + ), + ] + ), + pydantic_v2.Json: Schema(type=OpenAPIType.OBJECT, format=OpenAPIFormat.JSON_POINTER), + pydantic_v2.NameEmail: Schema( + type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL, description="Name and email" + ), + pydantic_v2.AnyUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL), + } + ) + + +_supported_types = (pydantic_v1.BaseModel, *PYDANTIC_TYPE_MAP.keys()) +if pydantic_v2 is not None: # pragma: no cover + _supported_types = (pydantic_v2.BaseModel, *_supported_types) + + +class PydanticSchemaPlugin(OpenAPISchemaPlugin): + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + self.prefer_alias = prefer_alias + + @staticmethod + def is_plugin_supported_type(value: Any) -> bool: + return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types) # type: ignore[arg-type] + + @staticmethod + def is_undefined_sentinel(value: Any) -> bool: + return is_pydantic_undefined(value) + + @staticmethod + def is_constrained_field(field_definition: FieldDefinition) -> bool: + return is_pydantic_constrained_field(field_definition.annotation) + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + """Given a type annotation, transform it into an OpenAPI schema class. + + Args: + field_definition: FieldDefinition instance. + schema_creator: An instance of the schema creator class + + Returns: + An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance. + """ + if schema_creator.prefer_alias != self.prefer_alias: + schema_creator.prefer_alias = True + if is_pydantic_model_class(field_definition.annotation): + return self.for_pydantic_model(field_definition=field_definition, schema_creator=schema_creator) + return PYDANTIC_TYPE_MAP[field_definition.annotation] # pragma: no cover + + @classmethod + def for_pydantic_model(cls, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: # pyright: ignore + """Create a schema object for a given pydantic model class. + + Args: + field_definition: FieldDefinition instance. + schema_creator: An instance of the schema creator class + + Returns: + A schema instance. + """ + + annotation = field_definition.annotation + if is_generic(annotation): + is_generic_model = True + model = pydantic_unwrap_and_get_origin(annotation) or annotation + else: + is_generic_model = False + model = annotation + + if is_pydantic_2_model(model): + model_config = model.model_config + model_field_info = model.model_fields + title = model_config.get("title") + example = model_config.get("example") + is_v2_model = True + else: + model_config = annotation.__config__ + model_field_info = model.__fields__ + title = getattr(model_config, "title", None) + example = getattr(model_config, "example", None) + is_v2_model = False + + model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] = { # pyright: ignore + k: getattr(f, "field_info", f) for k, f in model_field_info.items() + } + + if is_v2_model: + # extract the annotations from the FieldInfo. This allows us to skip fields + # which have been marked as private + model_annotations = {k: field_info.annotation for k, field_info in model_fields.items()} # type: ignore[union-attr] + + else: + # pydantic v1 requires some workarounds here + model_annotations = { + k: f.outer_type_ if f.required else Optional[f.outer_type_] for k, f in model.__fields__.items() + } + + if is_generic_model: + # if the model is generic, resolve the type variables. We pass in the + # already extracted annotations, to keep the logic of respecting private + # fields consistent with the above + model_annotations = pydantic_get_type_hints_with_generics_resolved( + annotation, model_annotations=model_annotations, include_extras=True + ) + + property_fields = { + field_info.alias if field_info.alias and schema_creator.prefer_alias else k: FieldDefinition.from_kwarg( + annotation=Annotated[model_annotations[k], field_info, field_info.metadata] # type: ignore[union-attr] + if is_v2_model + else Annotated[model_annotations[k], field_info], # pyright: ignore + name=field_info.alias if field_info.alias and schema_creator.prefer_alias else k, + default=Empty if schema_creator.is_undefined(field_info.default) else field_info.default, + ) + for k, field_info in model_fields.items() + } + + computed_field_definitions = create_field_definitions_for_computed_fields( + annotation, schema_creator.prefer_alias + ) + property_fields.update(computed_field_definitions) + + return schema_creator.create_component_schema( + field_definition, + required=sorted(f.name for f in property_fields.values() if f.is_required), + property_fields=property_fields, + title=title, + examples=None if example is None else [example], + ) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/utils.py b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/utils.py new file mode 100644 index 0000000..6aee322 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/pydantic/utils.py @@ -0,0 +1,214 @@ +# mypy: strict-equality=False +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from typing_extensions import Annotated, get_type_hints + +from litestar.params import KwargDefinition +from litestar.types import Empty +from litestar.typing import FieldDefinition +from litestar.utils import deprecated, is_class_and_subclass +from litestar.utils.predicates import is_generic +from litestar.utils.typing import ( + _substitute_typevars, + get_origin_or_inner_type, + get_type_hints_with_generics_resolved, + normalize_type_annotation, +) + +# isort: off +try: + from pydantic import v1 as pydantic_v1 + import pydantic as pydantic_v2 + from pydantic.fields import PydanticUndefined as Pydantic2Undefined # type: ignore[attr-defined] + from pydantic.v1.fields import Undefined as Pydantic1Undefined + + PYDANTIC_UNDEFINED_SENTINELS = {Pydantic1Undefined, Pydantic2Undefined} +except ImportError: + try: + import pydantic as pydantic_v1 # type: ignore[no-redef] + from pydantic.fields import Undefined as Pydantic1Undefined # type: ignore[attr-defined, no-redef] + + pydantic_v2 = Empty # type: ignore[assignment] + PYDANTIC_UNDEFINED_SENTINELS = {Pydantic1Undefined} + + except ImportError: # pyright: ignore + pydantic_v1 = Empty # type: ignore[assignment] + pydantic_v2 = Empty # type: ignore[assignment] + PYDANTIC_UNDEFINED_SENTINELS = set() +# isort: on + + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + + +def is_pydantic_model_class( + annotation: Any, +) -> TypeGuard[type[pydantic_v1.BaseModel | pydantic_v2.BaseModel]]: # pyright: ignore + """Given a type annotation determine if the annotation is a subclass of pydantic's BaseModel. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type is :data:`BaseModel pydantic.BaseModel>`. + """ + tests: list[bool] = [] + + if pydantic_v1 is not Empty: # pragma: no cover + tests.append(is_class_and_subclass(annotation, pydantic_v1.BaseModel)) + + if pydantic_v2 is not Empty: # pragma: no cover + tests.append(is_class_and_subclass(annotation, pydantic_v2.BaseModel)) + + return any(tests) + + +def is_pydantic_model_instance( + annotation: Any, +) -> TypeGuard[pydantic_v1.BaseModel | pydantic_v2.BaseModel]: # pyright: ignore + """Given a type annotation determine if the annotation is an instance of pydantic's BaseModel. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type is :data:`BaseModel pydantic.BaseModel>`. + """ + tests: list[bool] = [] + + if pydantic_v1 is not Empty: # pragma: no cover + tests.append(isinstance(annotation, pydantic_v1.BaseModel)) + + if pydantic_v2 is not Empty: # pragma: no cover + tests.append(isinstance(annotation, pydantic_v2.BaseModel)) + + return any(tests) + + +def is_pydantic_constrained_field(annotation: Any) -> bool: + """Check if the given annotation is a constrained pydantic type. + + Args: + annotation: A type annotation + + Returns: + True if pydantic is installed and the type is a constrained type, otherwise False. + """ + if pydantic_v1 is Empty: # pragma: no cover + return False # type: ignore[unreachable] + + return any( + is_class_and_subclass(annotation, constrained_type) # pyright: ignore + for constrained_type in ( + pydantic_v1.ConstrainedBytes, + pydantic_v1.ConstrainedDate, + pydantic_v1.ConstrainedDecimal, + pydantic_v1.ConstrainedFloat, + pydantic_v1.ConstrainedFrozenSet, + pydantic_v1.ConstrainedInt, + pydantic_v1.ConstrainedList, + pydantic_v1.ConstrainedSet, + pydantic_v1.ConstrainedStr, + ) + ) + + +def pydantic_unwrap_and_get_origin(annotation: Any) -> Any | None: + if pydantic_v2 is Empty or (pydantic_v1 is not Empty and is_class_and_subclass(annotation, pydantic_v1.BaseModel)): + return get_origin_or_inner_type(annotation) + + origin = annotation.__pydantic_generic_metadata__["origin"] + return normalize_type_annotation(origin) + + +def pydantic_get_type_hints_with_generics_resolved( + annotation: Any, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, + include_extras: bool = False, + model_annotations: dict[str, Any] | None = None, +) -> dict[str, Any]: + if pydantic_v2 is Empty or (pydantic_v1 is not Empty and is_class_and_subclass(annotation, pydantic_v1.BaseModel)): + return get_type_hints_with_generics_resolved(annotation, type_hints=model_annotations) + + origin = pydantic_unwrap_and_get_origin(annotation) + if origin is None: + if model_annotations is None: # pragma: no cover + model_annotations = get_type_hints( + annotation, globalns=globalns, localns=localns, include_extras=include_extras + ) + typevar_map = {p: p for p in annotation.__pydantic_generic_metadata__["parameters"]} + else: + if model_annotations is None: + model_annotations = get_type_hints( + origin, globalns=globalns, localns=localns, include_extras=include_extras + ) + args = annotation.__pydantic_generic_metadata__["args"] + parameters = origin.__pydantic_generic_metadata__["parameters"] + typevar_map = dict(zip(parameters, args)) + + return {n: _substitute_typevars(type_, typevar_map) for n, type_ in model_annotations.items()} + + +@deprecated(version="2.6.2") +def pydantic_get_unwrapped_annotation_and_type_hints(annotation: Any) -> tuple[Any, dict[str, Any]]: # pragma: pver + """Get the unwrapped annotation and the type hints after resolving generics. + + Args: + annotation: A type annotation. + + Returns: + A tuple containing the unwrapped annotation and the type hints. + """ + + if is_generic(annotation): + origin = pydantic_unwrap_and_get_origin(annotation) + return origin or annotation, pydantic_get_type_hints_with_generics_resolved(annotation, include_extras=True) + return annotation, get_type_hints(annotation, include_extras=True) + + +def is_pydantic_2_model( + obj: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel], # pyright: ignore +) -> TypeGuard[pydantic_v2.BaseModel]: # pyright: ignore + return pydantic_v2 is not Empty and issubclass(obj, pydantic_v2.BaseModel) + + +def is_pydantic_undefined(value: Any) -> bool: + return any(v is value for v in PYDANTIC_UNDEFINED_SENTINELS) + + +def create_field_definitions_for_computed_fields( + model: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel], # pyright: ignore + prefer_alias: bool, +) -> dict[str, FieldDefinition]: + """Create field definitions for computed fields. + + Args: + model: A pydantic model. + prefer_alias: Whether to prefer the alias or the name of the field. + + Returns: + A dictionary containing the field definitions for the computed fields. + """ + pydantic_decorators = getattr(model, "__pydantic_decorators__", None) + if pydantic_decorators is None: + return {} + + def get_name(k: str, dec: Any) -> str: + if not dec.info.alias: + return k + return dec.info.alias if prefer_alias else k # type: ignore[no-any-return] + + return { + (name := get_name(k, dec)): FieldDefinition.from_annotation( + Annotated[ + dec.info.return_type, + KwargDefinition(title=dec.info.title, description=dec.info.description, read_only=True), + ], + name=name, + ) + for k, dec in pydantic_decorators.computed_fields.items() + } diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__init__.py new file mode 100644 index 0000000..3a329c9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__init__.py @@ -0,0 +1,20 @@ +from litestar.utils import warn_deprecation + + +def __getattr__(attr_name: str) -> object: + from litestar import repository + + if attr_name in repository.__all__: + warn_deprecation( + deprecated_name=f"litestar.contrib.repository.{attr_name}", + version="2.1", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.repository' is deprecated, please" + f"import it from 'litestar.repository.{attr_name}' instead", + ) + + value = globals()[attr_name] = getattr(repository, attr_name) + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1b3c01a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..29868c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/filters.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/filters.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a69925c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/filters.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/handlers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/handlers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5d82074 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/handlers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/testing.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/testing.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..474bdcf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/__pycache__/testing.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/abc/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/repository/abc/__init__.py new file mode 100644 index 0000000..17efc1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/abc/__init__.py @@ -0,0 +1,20 @@ +from litestar.utils import warn_deprecation + + +def __getattr__(attr_name: str) -> object: + from litestar.repository import abc + + if attr_name in abc.__all__: + warn_deprecation( + deprecated_name=f"litestar.contrib.repository.abc.{attr_name}", + version="2.1", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.repository.abc' is deprecated, please" + f"import it from 'litestar.repository.abc.{attr_name}' instead", + ) + + value = globals()[attr_name] = getattr(abc, attr_name) + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/abc/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/repository/abc/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..666c136 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/abc/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/exceptions.py b/venv/lib/python3.11/site-packages/litestar/contrib/repository/exceptions.py new file mode 100644 index 0000000..1e7e738 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/exceptions.py @@ -0,0 +1,20 @@ +from litestar.utils import warn_deprecation + + +def __getattr__(attr_name: str) -> object: + from litestar.repository import exceptions + + if attr_name in exceptions.__all__: + warn_deprecation( + deprecated_name=f"litestar.repository.contrib.exceptions.{attr_name}", + version="2.1", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.repository.exceptions' is deprecated, please" + f"import it from 'litestar.repository.exceptions.{attr_name}' instead", + ) + + value = globals()[attr_name] = getattr(exceptions, attr_name) + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/filters.py b/venv/lib/python3.11/site-packages/litestar/contrib/repository/filters.py new file mode 100644 index 0000000..3736a76 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/filters.py @@ -0,0 +1,20 @@ +from litestar.utils import warn_deprecation + + +def __getattr__(attr_name: str) -> object: + from litestar.repository import filters + + if attr_name in filters.__all__: + warn_deprecation( + deprecated_name=f"litestar.repository.contrib.filters.{attr_name}", + version="2.1", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.repository.filters' is deprecated, please" + f"import it from 'litestar.repository.filters.{attr_name}' instead", + ) + + value = globals()[attr_name] = getattr(filters, attr_name) + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/handlers.py b/venv/lib/python3.11/site-packages/litestar/contrib/repository/handlers.py new file mode 100644 index 0000000..b1174e4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/handlers.py @@ -0,0 +1,20 @@ +from litestar.utils import warn_deprecation + + +def __getattr__(attr_name: str) -> object: + from litestar.repository import handlers + + if attr_name in handlers.__all__: + warn_deprecation( + deprecated_name=f"litestar.repository.contrib.handlers.{attr_name}", + version="2.1", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.repository.handlers' is deprecated, please" + f"import it from 'litestar.repository.handlers.{attr_name}' instead", + ) + + value = globals()[attr_name] = getattr(handlers, attr_name) + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/repository/testing.py b/venv/lib/python3.11/site-packages/litestar/contrib/repository/testing.py new file mode 100644 index 0000000..b78fea8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/repository/testing.py @@ -0,0 +1,20 @@ +from litestar.utils import warn_deprecation + + +def __getattr__(attr_name: str) -> object: + from litestar.repository.testing import generic_mock_repository + + if attr_name in generic_mock_repository.__all__: + warn_deprecation( + deprecated_name=f"litestar.repository.contrib.testing.{attr_name}", + version="2.1", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.repository.testing' is deprecated, please" + f"import it from 'litestar.repository.testing.{attr_name}' instead", + ) + + value = globals()[attr_name] = getattr(generic_mock_repository, attr_name) + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1b0f40a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..199862b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/dto.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/dto.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5467843 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/dto.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..09819d1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/base.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/base.py new file mode 100644 index 0000000..9ce9608 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/base.py @@ -0,0 +1,38 @@ +"""Application ORM configuration.""" + +from __future__ import annotations + +try: + # v0.6.0+ + from advanced_alchemy._listeners import touch_updated_timestamp # pyright: ignore +except ImportError: + from advanced_alchemy.base import touch_updated_timestamp # type: ignore[no-redef,attr-defined] + +from advanced_alchemy.base import ( + AuditColumns, + BigIntAuditBase, + BigIntBase, + BigIntPrimaryKey, + CommonTableAttributes, + ModelProtocol, + UUIDAuditBase, + UUIDBase, + UUIDPrimaryKey, + create_registry, + orm_registry, +) + +__all__ = ( + "AuditColumns", + "BigIntAuditBase", + "BigIntBase", + "BigIntPrimaryKey", + "CommonTableAttributes", + "create_registry", + "ModelProtocol", + "touch_updated_timestamp", + "UUIDAuditBase", + "UUIDBase", + "UUIDPrimaryKey", + "orm_registry", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/dto.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/dto.py new file mode 100644 index 0000000..beea75d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/dto.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig + +__all__ = ("SQLAlchemyDTO", "SQLAlchemyDTOConfig") diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__init__.py new file mode 100644 index 0000000..5bc913c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__init__.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyPlugin + +from .init import ( + AsyncSessionConfig, + EngineConfig, + GenericSessionConfig, + GenericSQLAlchemyConfig, + SQLAlchemyAsyncConfig, + SQLAlchemyInitPlugin, + SQLAlchemySyncConfig, + SyncSessionConfig, +) +from .serialization import SQLAlchemySerializationPlugin + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemyInitPlugin", + "SQLAlchemyPlugin", + "SQLAlchemySerializationPlugin", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..019ff72 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__pycache__/serialization.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__pycache__/serialization.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7bd0360 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/__pycache__/serialization.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__init__.py new file mode 100644 index 0000000..2e507c1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__init__.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from .config import ( + AsyncSessionConfig, + EngineConfig, + GenericSessionConfig, + GenericSQLAlchemyConfig, + SQLAlchemyAsyncConfig, + SQLAlchemySyncConfig, + SyncSessionConfig, +) +from .plugin import SQLAlchemyInitPlugin + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemyInitPlugin", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..43da1aa --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__pycache__/plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__pycache__/plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e06ec42 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/__pycache__/plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__init__.py new file mode 100644 index 0000000..f2e39da --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from .asyncio import AsyncSessionConfig, SQLAlchemyAsyncConfig +from .common import GenericSessionConfig, GenericSQLAlchemyConfig +from .engine import EngineConfig +from .sync import SQLAlchemySyncConfig, SyncSessionConfig + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2a316ef --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/asyncio.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/asyncio.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..23fe455 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/asyncio.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/common.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/common.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6e83a95 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/common.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/compat.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/compat.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4ba72bb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/compat.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/engine.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/engine.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1d4553b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/engine.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/sync.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/sync.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d777bb9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/__pycache__/sync.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py new file mode 100644 index 0000000..434c761 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from advanced_alchemy.config.asyncio import AlembicAsyncConfig, AsyncSessionConfig +from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import ( + SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig, +) +from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import ( + autocommit_before_send_handler, + default_before_send_handler, +) +from sqlalchemy.ext.asyncio import AsyncEngine + +from litestar.contrib.sqlalchemy.plugins.init.config.compat import _CreateEngineMixin + +__all__ = ( + "SQLAlchemyAsyncConfig", + "AlembicAsyncConfig", + "AsyncSessionConfig", + "default_before_send_handler", + "autocommit_before_send_handler", +) + + +class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig, _CreateEngineMixin[AsyncEngine]): ... diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/common.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/common.py new file mode 100644 index 0000000..9afc48c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/common.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from advanced_alchemy.config.common import GenericAlembicConfig, GenericSessionConfig, GenericSQLAlchemyConfig +from advanced_alchemy.extensions.litestar.plugins.init.config.common import ( + SESSION_SCOPE_KEY, + SESSION_TERMINUS_ASGI_EVENTS, +) + +__all__ = ( + "SESSION_SCOPE_KEY", + "SESSION_TERMINUS_ASGI_EVENTS", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", + "GenericAlembicConfig", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/compat.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/compat.py new file mode 100644 index 0000000..d76dea7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/compat.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar + +from litestar.utils.deprecation import deprecated + +if TYPE_CHECKING: + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine + + +EngineT_co = TypeVar("EngineT_co", bound="Engine | AsyncEngine", covariant=True) + + +class HasGetEngine(Protocol[EngineT_co]): + def get_engine(self) -> EngineT_co: ... + + +class _CreateEngineMixin(Generic[EngineT_co]): + @deprecated(version="2.1.1", removal_in="3.0.0", alternative="get_engine()") + def create_engine(self: HasGetEngine[EngineT_co]) -> EngineT_co: + return self.get_engine() diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/engine.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/engine.py new file mode 100644 index 0000000..31c3f5e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/engine.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from advanced_alchemy.config.engine import EngineConfig + +__all__ = ("EngineConfig",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/sync.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/sync.py new file mode 100644 index 0000000..48a029b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/config/sync.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from advanced_alchemy.config.sync import AlembicSyncConfig, SyncSessionConfig +from advanced_alchemy.extensions.litestar.plugins.init.config.sync import ( + SQLAlchemySyncConfig as _SQLAlchemySyncConfig, +) +from advanced_alchemy.extensions.litestar.plugins.init.config.sync import ( + autocommit_before_send_handler, + default_before_send_handler, +) +from sqlalchemy import Engine + +from litestar.contrib.sqlalchemy.plugins.init.config.compat import _CreateEngineMixin + +__all__ = ( + "SQLAlchemySyncConfig", + "AlembicSyncConfig", + "SyncSessionConfig", + "default_before_send_handler", + "autocommit_before_send_handler", +) + + +class SQLAlchemySyncConfig(_SQLAlchemySyncConfig, _CreateEngineMixin[Engine]): ... diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/plugin.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/plugin.py new file mode 100644 index 0000000..dbf814b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/init/plugin.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin + +__all__ = ("SQLAlchemyInitPlugin",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/serialization.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/serialization.py new file mode 100644 index 0000000..539b194 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/plugins/serialization.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from advanced_alchemy.extensions.litestar.plugins import SQLAlchemySerializationPlugin + +__all__ = ("SQLAlchemySerializationPlugin",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__init__.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__init__.py new file mode 100644 index 0000000..64a8359 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__init__.py @@ -0,0 +1,11 @@ +from ._async import SQLAlchemyAsyncRepository +from ._sync import SQLAlchemySyncRepository +from ._util import wrap_sqlalchemy_exception +from .types import ModelT + +__all__ = ( + "SQLAlchemyAsyncRepository", + "SQLAlchemySyncRepository", + "ModelT", + "wrap_sqlalchemy_exception", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3e6dacf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_async.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_async.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..461fbad --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_async.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_sync.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_sync.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bd6d80f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_sync.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_util.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_util.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..561fd53 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/_util.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..fa8191e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_async.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_async.py new file mode 100644 index 0000000..417ec35 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_async.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from advanced_alchemy.repository import SQLAlchemyAsyncRepository + +__all__ = ("SQLAlchemyAsyncRepository",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_sync.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_sync.py new file mode 100644 index 0000000..58ccbb8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_sync.py @@ -0,0 +1,7 @@ +# Do not edit this file directly. It has been autogenerated from +# litestar/contrib/sqlalchemy/repository/_async.py +from __future__ import annotations + +from advanced_alchemy.repository import SQLAlchemySyncRepository + +__all__ = ("SQLAlchemySyncRepository",) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_util.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_util.py new file mode 100644 index 0000000..c0ce747 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/_util.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception + +__all__ = ( + "wrap_sqlalchemy_exception", + "get_instrumented_attr", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/types.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/types.py new file mode 100644 index 0000000..2a4204c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/repository/types.py @@ -0,0 +1,15 @@ +from advanced_alchemy.repository.typing import ( + ModelT, + RowT, + SelectT, + SQLAlchemyAsyncRepositoryT, + SQLAlchemySyncRepositoryT, +) + +__all__ = ( + "ModelT", + "SelectT", + "RowT", + "SQLAlchemySyncRepositoryT", + "SQLAlchemyAsyncRepositoryT", +) diff --git a/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/types.py b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/types.py new file mode 100644 index 0000000..61fb75a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/contrib/sqlalchemy/types.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from advanced_alchemy.types import GUID, ORA_JSONB, BigIntIdentity, DateTimeUTC, JsonB + +__all__ = ( + "GUID", + "ORA_JSONB", + "DateTimeUTC", + "BigIntIdentity", + "JsonB", +) diff --git a/venv/lib/python3.11/site-packages/litestar/controller.py b/venv/lib/python3.11/site-packages/litestar/controller.py new file mode 100644 index 0000000..967454b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/controller.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import types +from collections import defaultdict +from copy import deepcopy +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast + +from litestar._layers.utils import narrow_response_cookies, narrow_response_headers +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers.base import BaseRouteHandler +from litestar.handlers.http_handlers import HTTPRouteHandler +from litestar.handlers.websocket_handlers import WebsocketRouteHandler +from litestar.types.empty import Empty +from litestar.utils import ensure_async_callable, normalize_path +from litestar.utils.signature import add_types_to_signature_namespace + +__all__ = ("Controller",) + + +if TYPE_CHECKING: + from litestar.connection import Request, WebSocket + from litestar.datastructures import CacheControlHeader, ETag + from litestar.dto import AbstractDTO + from litestar.openapi.spec import SecurityRequirement + from litestar.response import Response + from litestar.router import Router + from litestar.types import ( + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeRequestHookHandler, + Dependencies, + ExceptionHandlersMap, + Guard, + Middleware, + ParametersMap, + ResponseCookies, + TypeEncodersMap, + ) + from litestar.types.composite_types import ResponseHeaders, TypeDecodersSequence + from litestar.types.empty import EmptyType + + +class Controller: + """The Litestar Controller class. + + Subclass this class to create 'view' like components and utilize OOP. + """ + + __slots__ = ( + "after_request", + "after_response", + "before_request", + "dependencies", + "dto", + "etag", + "exception_handlers", + "guards", + "include_in_schema", + "middleware", + "opt", + "owner", + "parameters", + "path", + "request_class", + "response_class", + "response_cookies", + "response_headers", + "return_dto", + "security", + "signature_namespace", + "tags", + "type_encoders", + "type_decoders", + "websocket_class", + ) + + after_request: AfterRequestHookHandler | None + """A sync or async function executed before a :class:`Request <.connection.Request>` is passed to any route handler. + + If this function returns a value, the request will not reach the route handler, and instead this value will be used. + """ + after_response: AfterResponseHookHandler | None + """A sync or async function called after the response has been awaited. + + It receives the :class:`Request <.connection.Request>` instance and should not return any values. + """ + before_request: BeforeRequestHookHandler | None + """A sync or async function called immediately before calling the route handler. + + It receives the :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + """ + cache_control: CacheControlHeader | None + """A :class:`CacheControlHeader <.datastructures.CacheControlHeader>` header to add to route handlers of this + controller. + + Can be overridden by route handlers. + """ + dependencies: Dependencies | None + """A string keyed dictionary of dependency :class:`Provider <.di.Provide>` instances.""" + dto: type[AbstractDTO] | None | EmptyType + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and validation of request data.""" + etag: ETag | None + """An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this controller. + + Can be overridden by route handlers. + """ + exception_handlers: ExceptionHandlersMap | None + """A map of handler functions to status codes and/or exception types.""" + guards: Sequence[Guard] | None + """A sequence of :class:`Guard <.types.Guard>` callables.""" + include_in_schema: bool | EmptyType + """A boolean flag dictating whether the route handler should be documented in the OpenAPI schema""" + middleware: Sequence[Middleware] | None + """A sequence of :class:`Middleware <.types.Middleware>`.""" + opt: Mapping[str, Any] | None + """A string key mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you + have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + """ + owner: Router + """The :class:`Router <.router.Router>` or :class:`Litestar <litestar.app.Litestar>` app that owns the controller. + + This value is set internally by Litestar and it should not be set when subclassing the controller. + """ + parameters: ParametersMap | None + """A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application paths.""" + path: str + """A path fragment for the controller. + + All route handlers under the controller will have the fragment appended to them. If not set it defaults to ``/``. + """ + request_class: type[Request] | None + """A custom subclass of :class:`Request <.connection.Request>` to be used as the default request for all route + handlers under the controller. + """ + response_class: type[Response] | None + """A custom subclass of :class:`Response <.response.Response>` to be used as the default response for all route + handlers under the controller. + """ + response_cookies: ResponseCookies | None + """A list of :class:`Cookie <.datastructures.Cookie>` instances.""" + response_headers: ResponseHeaders | None + """A string keyed dictionary mapping :class:`ResponseHeader <.datastructures.ResponseHeader>` instances.""" + return_dto: type[AbstractDTO] | None | EmptyType + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response + data. + """ + tags: Sequence[str] | None + """A sequence of string tags that will be appended to the schema of all route handlers under the controller.""" + security: Sequence[SecurityRequirement] | None + """A sequence of dictionaries that to the schema of all route handlers under the controller.""" + signature_namespace: dict[str, Any] + """A mapping of names to types for use in forward reference resolution during signature modeling.""" + signature_types: Sequence[Any] + """A sequence of types for use in forward reference resolution during signature modeling. + + These types will be added to the signature namespace using their ``__name__`` attribute. + """ + type_decoders: TypeDecodersSequence | None + """A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.""" + type_encoders: TypeEncodersMap | None + """A mapping of types to callables that transform them into types supported for serialization.""" + websocket_class: type[WebSocket] | None + """A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default websocket for all route + handlers under the controller. + """ + + def __init__(self, owner: Router) -> None: + """Initialize a controller. + + Should only be called by routers as part of controller registration. + + Args: + owner: An instance of :class:`Router <.router.Router>` + """ + # Since functions set on classes are bound, we need replace the bound instance with the class version and wrap + # it to ensure it does not get bound. + for key in ("after_request", "after_response", "before_request"): + cls_value = getattr(type(self), key, None) + if callable(cls_value): + setattr(self, key, ensure_async_callable(cls_value)) + + if not hasattr(self, "dto"): + self.dto = Empty + + if not hasattr(self, "return_dto"): + self.return_dto = Empty + + if not hasattr(self, "include_in_schema"): + self.include_in_schema = Empty + + self.signature_namespace = add_types_to_signature_namespace( + getattr(self, "signature_types", []), getattr(self, "signature_namespace", {}) + ) + + for key in self.__slots__: + if not hasattr(self, key): + setattr(self, key, None) + + self.response_cookies = narrow_response_cookies(self.response_cookies) + self.response_headers = narrow_response_headers(self.response_headers) + self.path = normalize_path(self.path or "/") + self.owner = owner + + def get_route_handlers(self) -> list[BaseRouteHandler]: + """Get a controller's route handlers and set the controller as the handlers' owner. + + Returns: + A list containing a copy of the route handlers defined on the controller + """ + + route_handlers: list[BaseRouteHandler] = [] + controller_names = set(dir(Controller)) + self_handlers = [ + getattr(self, name) + for name in dir(self) + if name not in controller_names and isinstance(getattr(self, name), BaseRouteHandler) + ] + self_handlers.sort(key=attrgetter("handler_id")) + for self_handler in self_handlers: + route_handler = deepcopy(self_handler) + # at the point we get a reference to the handler function, it's unbound, so + # we replace it with a regular bound method here + route_handler._fn = types.MethodType(route_handler._fn, self) + route_handler.owner = self + route_handlers.append(route_handler) + + self.validate_route_handlers(route_handlers=route_handlers) + + return route_handlers + + def validate_route_handlers(self, route_handlers: list[BaseRouteHandler]) -> None: + """Validate that the combination of path and decorator method or type are unique on the controller. + + Args: + route_handlers: The controller's route handlers. + + Raises: + ImproperlyConfiguredException + + Returns: + None + """ + paths: defaultdict[str, set[str]] = defaultdict(set) + + for route_handler in route_handlers: + if isinstance(route_handler, HTTPRouteHandler): + methods: set[str] = cast("set[str]", route_handler.http_methods) + elif isinstance(route_handler, WebsocketRouteHandler): + methods = {"websocket"} + else: + methods = {"asgi"} + + for path in route_handler.paths: + if (entry := paths[path]) and (intersection := entry.intersection(methods)): + raise ImproperlyConfiguredException( + f"the combination of path and method must be unique in a controller - " + f"the following methods {''.join(m.lower() for m in intersection)} for {type(self).__name__} " + f"controller path {path} are not unique" + ) + paths[path].update(methods) diff --git a/venv/lib/python3.11/site-packages/litestar/data_extractors.py b/venv/lib/python3.11/site-packages/litestar/data_extractors.py new file mode 100644 index 0000000..61993b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/data_extractors.py @@ -0,0 +1,443 @@ +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, Literal, TypedDict, cast + +from litestar._parsers import parse_cookie_string +from litestar.connection.request import Request +from litestar.datastructures.upload_file import UploadFile +from litestar.enums import HttpMethod, RequestEncodingType + +__all__ = ( + "ConnectionDataExtractor", + "ExtractedRequestData", + "ExtractedResponseData", + "ResponseDataExtractor", + "RequestExtractorField", + "ResponseExtractorField", +) + + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.types import Method + from litestar.types.asgi_types import HTTPResponseBodyEvent, HTTPResponseStartEvent + + +def _obfuscate(values: dict[str, Any], fields_to_obfuscate: set[str]) -> dict[str, Any]: + """Obfuscate values in a dictionary, replacing values with `******` + + Args: + values: A dictionary of strings + fields_to_obfuscate: keys to obfuscate + + Returns: + A dictionary with obfuscated strings + """ + return {key: "*****" if key.lower() in fields_to_obfuscate else value for key, value in values.items()} + + +RequestExtractorField = Literal[ + "path", "method", "content_type", "headers", "cookies", "query", "path_params", "body", "scheme", "client" +] + +ResponseExtractorField = Literal["status_code", "headers", "body", "cookies"] + + +class ExtractedRequestData(TypedDict, total=False): + """Dictionary representing extracted request data.""" + + body: Coroutine[Any, Any, Any] + client: tuple[str, int] + content_type: tuple[str, dict[str, str]] + cookies: dict[str, str] + headers: dict[str, str] + method: Method + path: str + path_params: dict[str, Any] + query: bytes | dict[str, Any] + scheme: str + + +class ConnectionDataExtractor: + """Utility class to extract data from an :class:`ASGIConnection <litestar.connection.ASGIConnection>`, + :class:`Request <litestar.connection.Request>` or :class:`WebSocket <litestar.connection.WebSocket>` instance. + """ + + __slots__ = ( + "connection_extractors", + "request_extractors", + "parse_body", + "parse_query", + "obfuscate_headers", + "obfuscate_cookies", + "skip_parse_malformed_body", + ) + + def __init__( + self, + extract_body: bool = True, + extract_client: bool = True, + extract_content_type: bool = True, + extract_cookies: bool = True, + extract_headers: bool = True, + extract_method: bool = True, + extract_path: bool = True, + extract_path_params: bool = True, + extract_query: bool = True, + extract_scheme: bool = True, + obfuscate_cookies: set[str] | None = None, + obfuscate_headers: set[str] | None = None, + parse_body: bool = False, + parse_query: bool = False, + skip_parse_malformed_body: bool = False, + ) -> None: + """Initialize ``ConnectionDataExtractor`` + + Args: + extract_body: Whether to extract body, (for requests only). + extract_client: Whether to extract the client (host, port) mapping. + extract_content_type: Whether to extract the content type and any options. + extract_cookies: Whether to extract cookies. + extract_headers: Whether to extract headers. + extract_method: Whether to extract the HTTP method, (for requests only). + extract_path: Whether to extract the path. + extract_path_params: Whether to extract path parameters. + extract_query: Whether to extract query parameters. + extract_scheme: Whether to extract the http scheme. + obfuscate_headers: headers keys to obfuscate. Obfuscated values are replaced with '*****'. + obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'. + parse_body: Whether to parse the body value or return the raw byte string, (for requests only). + parse_query: Whether to parse query parameters or return the raw byte string. + skip_parse_malformed_body: Whether to skip parsing the body if it is malformed + """ + self.parse_body = parse_body + self.parse_query = parse_query + self.skip_parse_malformed_body = skip_parse_malformed_body + self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())} + self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())} + self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {} + self.request_extractors: dict[RequestExtractorField, Callable[[Request[Any, Any, Any]], Any]] = {} + if extract_scheme: + self.connection_extractors["scheme"] = self.extract_scheme + if extract_client: + self.connection_extractors["client"] = self.extract_client + if extract_path: + self.connection_extractors["path"] = self.extract_path + if extract_headers: + self.connection_extractors["headers"] = self.extract_headers + if extract_cookies: + self.connection_extractors["cookies"] = self.extract_cookies + if extract_query: + self.connection_extractors["query"] = self.extract_query + if extract_path_params: + self.connection_extractors["path_params"] = self.extract_path_params + if extract_method: + self.request_extractors["method"] = self.extract_method + if extract_content_type: + self.request_extractors["content_type"] = self.extract_content_type + if extract_body: + self.request_extractors["body"] = self.extract_body + + def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedRequestData: + """Extract data from the connection, returning a dictionary of values. + + Notes: + - The value for ``body`` - if present - is an unresolved Coroutine and as such should be awaited by the receiver. + + Args: + connection: An ASGI connection or its subclasses. + + Returns: + A string keyed dictionary of extracted values. + """ + extractors = ( + {**self.connection_extractors, **self.request_extractors} # type: ignore[misc] + if isinstance(connection, Request) + else self.connection_extractors + ) + return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()}) + + async def extract( + self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str] + ) -> ExtractedRequestData: + extractors = ( + {**self.connection_extractors, **self.request_extractors} # type: ignore[misc] + if isinstance(connection, Request) + else self.connection_extractors + ) + data = {} + for key, extractor in extractors.items(): + if key not in fields: + continue + if inspect.iscoroutinefunction(extractor): + value = await extractor(connection) + else: + value = extractor(connection) + data[key] = value + return cast("ExtractedRequestData", data) + + @staticmethod + def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str: + """Extract the scheme from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + The connection's scope["scheme"] value + """ + return connection.scope["scheme"] + + @staticmethod + def extract_client(connection: ASGIConnection[Any, Any, Any, Any]) -> tuple[str, int]: + """Extract the client from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + The connection's scope["client"] value or a default value. + """ + return connection.scope.get("client") or ("", 0) + + @staticmethod + def extract_path(connection: ASGIConnection[Any, Any, Any, Any]) -> str: + """Extract the path from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + The connection's scope["path"] value + """ + return connection.scope["path"] + + def extract_headers(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, str]: + """Extract headers from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + A dictionary with the connection's headers. + """ + headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in connection.scope["headers"]} + return _obfuscate(headers, self.obfuscate_headers) if self.obfuscate_headers else headers + + def extract_cookies(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, str]: + """Extract cookies from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + A dictionary with the connection's cookies. + """ + return _obfuscate(connection.cookies, self.obfuscate_cookies) if self.obfuscate_cookies else connection.cookies + + def extract_query(self, connection: ASGIConnection[Any, Any, Any, Any]) -> Any: + """Extract query from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + Either a dictionary with the connection's parsed query string or the raw query byte-string. + """ + return connection.query_params.dict() if self.parse_query else connection.scope.get("query_string", b"") + + @staticmethod + def extract_path_params(connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, Any]: + """Extract the path parameters from an ``ASGIConnection`` + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Returns: + A dictionary with the connection's path parameters. + """ + return connection.path_params + + @staticmethod + def extract_method(request: Request[Any, Any, Any]) -> Method: + """Extract the method from an ``ASGIConnection`` + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + + Returns: + The request's scope["method"] value. + """ + return request.scope["method"] + + @staticmethod + def extract_content_type(request: Request[Any, Any, Any]) -> tuple[str, dict[str, str]]: + """Extract the content-type from an ``ASGIConnection`` + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + + Returns: + A tuple containing the request's parsed 'Content-Type' header. + """ + return request.content_type + + async def extract_body(self, request: Request[Any, Any, Any]) -> Any: + """Extract the body from an ``ASGIConnection`` + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + + Returns: + Either the parsed request body or the raw byte-string. + """ + if request.method == HttpMethod.GET: + return None + if not self.parse_body: + return await request.body() + try: + request_encoding_type = request.content_type[0] + if request_encoding_type == RequestEncodingType.JSON: + return await request.json() + form_data = await request.form() + if request_encoding_type == RequestEncodingType.URL_ENCODED: + return dict(form_data) + return { + key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items() + } + except Exception as exc: + if self.skip_parse_malformed_body: + return await request.body() + raise exc + + +class ExtractedResponseData(TypedDict, total=False): + """Dictionary representing extracted response data.""" + + body: bytes + status_code: int + headers: dict[str, str] + cookies: dict[str, str] + + +class ResponseDataExtractor: + """Utility class to extract data from a ``Message``""" + + __slots__ = ("extractors", "parse_headers", "obfuscate_headers", "obfuscate_cookies") + + def __init__( + self, + extract_body: bool = True, + extract_cookies: bool = True, + extract_headers: bool = True, + extract_status_code: bool = True, + obfuscate_cookies: set[str] | None = None, + obfuscate_headers: set[str] | None = None, + ) -> None: + """Initialize ``ResponseDataExtractor`` with options. + + Args: + extract_body: Whether to extract the body. + extract_cookies: Whether to extract the cookies. + extract_headers: Whether to extract the headers. + extract_status_code: Whether to extract the status code. + obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'. + obfuscate_headers: headers keys to obfuscate. Obfuscated values are replaced with '*****'. + """ + self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())} + self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())} + self.extractors: dict[ + ResponseExtractorField, Callable[[tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]], Any] + ] = {} + if extract_body: + self.extractors["body"] = self.extract_response_body + if extract_status_code: + self.extractors["status_code"] = self.extract_status_code + if extract_headers: + self.extractors["headers"] = self.extract_headers + if extract_cookies: + self.extractors["cookies"] = self.extract_cookies + + def __call__(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> ExtractedResponseData: + """Extract data from the response, returning a dictionary of values. + + Args: + messages: A tuple containing + :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>` + and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`. + + Returns: + A string keyed dictionary of extracted values. + """ + return cast("ExtractedResponseData", {key: extractor(messages) for key, extractor in self.extractors.items()}) + + @staticmethod + def extract_response_body(messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> bytes: + """Extract the response body from a ``Message`` + + Args: + messages: A tuple containing + :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>` + and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`. + + Returns: + The Response's body as a byte-string. + """ + return messages[1]["body"] + + @staticmethod + def extract_status_code(messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> int: + """Extract a status code from a ``Message`` + + Args: + messages: A tuple containing + :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>` + and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`. + + Returns: + The Response's status-code. + """ + return messages[0]["status"] + + def extract_headers(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> dict[str, str]: + """Extract headers from a ``Message`` + + Args: + messages: A tuple containing + :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>` + and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`. + + Returns: + The Response's headers dict. + """ + headers = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in filter(lambda x: x[0].lower() != b"set-cookie", messages[0]["headers"]) + } + return ( + _obfuscate( + headers, + self.obfuscate_headers, + ) + if self.obfuscate_headers + else headers + ) + + def extract_cookies(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> dict[str, str]: + """Extract cookies from a ``Message`` + + Args: + messages: A tuple containing + :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>` + and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`. + + Returns: + The Response's cookies dict. + """ + if cookie_string := ";".join( + [x[1].decode("latin-1") for x in filter(lambda x: x[0].lower() == b"set-cookie", messages[0]["headers"])] + ): + parsed_cookies = parse_cookie_string(cookie_string) + return _obfuscate(parsed_cookies, self.obfuscate_cookies) if self.obfuscate_cookies else parsed_cookies + return {} diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__init__.py b/venv/lib/python3.11/site-packages/litestar/datastructures/__init__.py new file mode 100644 index 0000000..74fc25b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__init__.py @@ -0,0 +1,39 @@ +from litestar.datastructures.cookie import Cookie +from litestar.datastructures.headers import ( + Accept, + CacheControlHeader, + ETag, + Header, + Headers, + MutableScopeHeaders, +) +from litestar.datastructures.multi_dicts import ( + FormMultiDict, + ImmutableMultiDict, + MultiDict, + MultiMixin, +) +from litestar.datastructures.response_header import ResponseHeader +from litestar.datastructures.state import ImmutableState, State +from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.url import URL, Address + +__all__ = ( + "Accept", + "Address", + "CacheControlHeader", + "Cookie", + "ETag", + "FormMultiDict", + "Header", + "Headers", + "ImmutableMultiDict", + "ImmutableState", + "MultiDict", + "MultiMixin", + "MutableScopeHeaders", + "ResponseHeader", + "State", + "UploadFile", + "URL", +) diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..085a180 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/cookie.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/cookie.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..87a2646 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/cookie.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/headers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/headers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4a8e1ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/headers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/multi_dicts.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/multi_dicts.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c5a3dd4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/multi_dicts.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/response_header.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/response_header.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0c83cae --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/response_header.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/state.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/state.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6cd1f18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/state.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/upload_file.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/upload_file.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4db0e27 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/upload_file.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/url.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/url.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a9c2138 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/__pycache__/url.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/cookie.py b/venv/lib/python3.11/site-packages/litestar/datastructures/cookie.py new file mode 100644 index 0000000..21cedc3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/cookie.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from http.cookies import SimpleCookie +from typing import Any, Literal + +__all__ = ("Cookie",) + + +@dataclass +class Cookie: + """Container class for defining a cookie using the ``Set-Cookie`` header. + + See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie for more details regarding this header. + """ + + key: str + """Key for the cookie.""" + path: str = "/" + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``/``. + """ + value: str | None = field(default=None) + """Value for the cookie, if none given defaults to empty string.""" + max_age: int | None = field(default=None) + """Maximal age of the cookie before its invalidated.""" + expires: int | None = field(default=None) + """Seconds from now until the cookie expires.""" + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool | None = field(default=None) + """Https is required for the cookie.""" + httponly: bool | None = field(default=None) + """Forbids javascript to access the cookie via ``document.cookie``.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. + + Defaults to 'lax'. + """ + description: str | None = field(default=None) + """Description of the response cookie header for OpenAPI documentation.""" + documentation_only: bool = field(default=False) + """Defines the Cookie instance as for OpenAPI documentation purpose only.""" + + @property + def simple_cookie(self) -> SimpleCookie: + """Get a simple cookie object from the values. + + Returns: + A :class:`SimpleCookie <http.cookies.SimpleCookie>` + """ + simple_cookie: SimpleCookie = SimpleCookie() + simple_cookie[self.key] = self.value or "" + + namespace = simple_cookie[self.key] + for key, value in self.dict.items(): + if key in {"key", "value"}: + continue + if value is not None: + updated_key = key + if updated_key == "max_age": + updated_key = "max-age" + namespace[updated_key] = value + + return simple_cookie + + def to_header(self, **kwargs: Any) -> str: + """Return a string representation suitable to be sent as HTTP headers. + + Args: + **kwargs: Any kwargs to pass to the simple cookie output method. + """ + return self.simple_cookie.output(**kwargs).strip() + + def to_encoded_header(self) -> tuple[bytes, bytes]: + """Create encoded header for ASGI ``send``. + + Returns: + A two tuple of bytes. + """ + return b"set-cookie", self.to_header(header="").strip().encode("latin-1") + + @property + def dict(self) -> dict[str, Any]: + """Get the cookie as a dict. + + Returns: + A dict of values + """ + return { + k: v + for k, v in asdict(self).items() + if k not in {"documentation_only", "description", "__pydantic_initialised__"} + } + + def __hash__(self) -> int: + return hash((self.key, self.path, self.domain)) + + def __eq__(self, other: Any) -> bool: + """Determine whether two cookie instances are equal according to the cookie spec, i.e. hey have a similar path, + domain and key. + + Args: + other: An arbitrary value + + Returns: + A boolean + """ + if isinstance(other, Cookie): + return other.key == self.key and other.path == self.path and other.domain == self.domain + return False diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py new file mode 100644 index 0000000..f3e9bd7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py @@ -0,0 +1,534 @@ +import re +from abc import ABC, abstractmethod +from contextlib import suppress +from copy import copy +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Pattern, + Tuple, + Union, + cast, +) + +from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping +from typing_extensions import get_type_hints + +from litestar._multipart import parse_content_header +from litestar.datastructures.multi_dicts import MultiMixin +from litestar.dto.base_dto import AbstractDTO +from litestar.exceptions import ImproperlyConfiguredException, ValidationException +from litestar.types.empty import Empty +from litestar.typing import FieldDefinition +from litestar.utils.dataclass import simple_asdict +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.types.asgi_types import ( + HeaderScope, + Message, + RawHeaders, + RawHeadersList, + Scope, + ) + +__all__ = ("Accept", "CacheControlHeader", "ETag", "Header", "Headers", "MutableScopeHeaders") + +ETAG_RE = re.compile(r'([Ww]/)?"(.+)"') +PRINTABLE_ASCII_RE: Pattern[str] = re.compile(r"^[ -~]+$") + + +def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList": + return [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers] + + +class Headers(CIMultiDictProxy[str], MultiMixin[str]): + """An immutable, case-insensitive multi dict for HTTP headers.""" + + def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeaders", MultiMapping]] = None) -> None: + """Initialize ``Headers``. + + Args: + headers: Initial value. + """ + if not isinstance(headers, MultiMapping): + headers_: Union[Mapping[str, str], List[Tuple[str, str]]] = {} + if headers: + if isinstance(headers, Mapping): + headers_ = headers # pyright: ignore + else: + headers_ = [(key.decode("latin-1"), value.decode("latin-1")) for key, value in headers] + + super().__init__(CIMultiDict(headers_)) + else: + super().__init__(headers) + self._header_list: Optional[RawHeadersList] = None + + @classmethod + def from_scope(cls, scope: "Scope") -> "Headers": + """Create headers from a send-message. + + Args: + scope: The ASGI connection scope. + + Returns: + Headers + + Raises: + ValueError: If the message does not have a ``headers`` key + """ + connection_state = ScopeState.from_scope(scope) + if (headers := connection_state.headers) is Empty: + headers = connection_state.headers = cls(scope["headers"]) + return headers + + def to_header_list(self) -> "RawHeadersList": + """Raw header value. + + Returns: + A list of tuples contain the header and header-value as bytes + """ + # Since ``Headers`` are immutable, this can be cached + if not self._header_list: + self._header_list = _encode_headers((key, value) for key in set(self) for value in self.getall(key)) + return self._header_list + + +class MutableScopeHeaders(MutableMapping): + """A case-insensitive, multidict-like structure that can be used to mutate headers within a + :class:`Scope <.types.Scope>` + """ + + def __init__(self, scope: Optional["HeaderScope"] = None) -> None: + """Initialize ``MutableScopeHeaders`` from a ``HeaderScope``. + + Args: + scope: The ASGI connection scope. + """ + self.headers: RawHeadersList + if scope is not None: + if not isinstance(scope["headers"], list): + scope["headers"] = list(scope["headers"]) + + self.headers = cast("RawHeadersList", scope["headers"]) + else: + self.headers = [] + + @classmethod + def from_message(cls, message: "Message") -> "MutableScopeHeaders": + """Construct a header from a message object. + + Args: + message: :class:`Message <.types.Message>`. + + Returns: + MutableScopeHeaders. + + Raises: + ValueError: If the message does not have a ``headers`` key. + """ + if "headers" not in message: + raise ValueError(f"Invalid message type: {message['type']!r}") + + return cls(cast("HeaderScope", message)) + + def add(self, key: str, value: str) -> None: + """Add a header to the scope. + + Notes: + - This method keeps duplicates. + + Args: + key: Header key. + value: Header value. + + Returns: + None. + """ + self.headers.append((key.lower().encode("latin-1"), value.encode("latin-1"))) + + def getall(self, key: str, default: Optional[List[str]] = None) -> List[str]: + """Get all values of a header. + + Args: + key: Header key. + default: Default value to return if ``name`` is not found. + + Returns: + A list of strings. + + Raises: + KeyError: if no header for ``name`` was found and ``default`` is not given. + """ + name = key.lower() + values = [ + header_value.decode("latin-1") + for header_name, header_value in self.headers + if header_name.decode("latin-1").lower() == name + ] + if not values: + if default: + return default + raise KeyError + return values + + def extend_header_value(self, key: str, value: str) -> None: + """Extend a multivalued header. + + Notes: + - A multivalues header is a header that can take a comma separated list. + - If the header previously did not exist, it will be added. + + Args: + key: Header key. + value: Header value to add, + + Returns: + None + """ + existing = self.get(key) + if existing is not None: + value = ",".join([*existing.split(","), value]) + self[key] = value + + def __getitem__(self, key: str) -> str: + """Get the first header matching ``name``""" + name = key.lower() + for header in self.headers: + if header[0].decode("latin-1").lower() == name: + return header[1].decode("latin-1") + raise KeyError + + def _find_indices(self, key: str) -> List[int]: + name = key.lower() + return [i for i, (name_, _) in enumerate(self.headers) if name_.decode("latin-1").lower() == name] + + def __setitem__(self, key: str, value: str) -> None: + """Set a header in the scope, overwriting duplicates.""" + name_encoded = key.lower().encode("latin-1") + value_encoded = value.encode("latin-1") + if indices := self._find_indices(key): + for i in indices[1:]: + del self.headers[i] + self.headers[indices[0]] = (name_encoded, value_encoded) + else: + self.headers.append((name_encoded, value_encoded)) + + def __delitem__(self, key: str) -> None: + """Delete all headers matching ``name``""" + indices = self._find_indices(key) + for i in indices[::-1]: + del self.headers[i] + + def __len__(self) -> int: + """Return the length of the internally stored headers, including duplicates.""" + return len(self.headers) + + def __iter__(self) -> Iterator[str]: + """Create an iterator of header names including duplicates.""" + return iter(h[0].decode("latin-1") for h in self.headers) + + +@dataclass +class Header(ABC): + """An abstract type for HTTP headers.""" + + HEADER_NAME: ClassVar[str] = "" + + documentation_only: bool = False + """Defines the header instance as for OpenAPI documentation purpose only.""" + + @abstractmethod + def _get_header_value(self) -> str: + """Get the header value as string.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_header(cls, header_value: str) -> "Header": + """Construct a header from its string representation.""" + + def to_header(self, include_header_name: bool = False) -> str: + """Get the header as string. + + Args: + include_header_name: should include the header name in the return value. If set to false + the return value will only include the header value. if set to true the return value + will be: ``<header name>: <header value>``. Defaults to false. + """ + + if not self.HEADER_NAME: + raise ImproperlyConfiguredException("Missing header name") + + return (f"{self.HEADER_NAME}: " if include_header_name else "") + self._get_header_value() + + +@dataclass +class CacheControlHeader(Header): + """A ``cache-control`` header.""" + + HEADER_NAME: ClassVar[str] = "cache-control" + + max_age: Optional[int] = None + """Accessor for the ``max-age`` directive.""" + s_maxage: Optional[int] = None + """Accessor for the ``s-maxage`` directive.""" + no_cache: Optional[bool] = None + """Accessor for the ``no-cache`` directive.""" + no_store: Optional[bool] = None + """Accessor for the ``no-store`` directive.""" + private: Optional[bool] = None + """Accessor for the ``private`` directive.""" + public: Optional[bool] = None + """Accessor for the ``public`` directive.""" + no_transform: Optional[bool] = None + """Accessor for the ``no-transform`` directive.""" + must_revalidate: Optional[bool] = None + """Accessor for the ``must-revalidate`` directive.""" + proxy_revalidate: Optional[bool] = None + """Accessor for the ``proxy-revalidate`` directive.""" + must_understand: Optional[bool] = None + """Accessor for the ``must-understand`` directive.""" + immutable: Optional[bool] = None + """Accessor for the ``immutable`` directive.""" + stale_while_revalidate: Optional[int] = None + """Accessor for the ``stale-while-revalidate`` directive.""" + + _field_definitions: ClassVar[Optional[Dict[str, FieldDefinition]]] = None + + def _get_header_value(self) -> str: + """Get the header value as string.""" + + cc_items = [ + key.replace("_", "-") if isinstance(value, bool) else f"{key.replace('_', '-')}={value}" + for key, value in simple_asdict(self, exclude_none=True, exclude={"documentation_only"}).items() + ] + return ", ".join(cc_items) + + @classmethod + def from_header(cls, header_value: str) -> "CacheControlHeader": + """Create a ``CacheControlHeader`` instance from the header value. + + Args: + header_value: the header value as string + + Returns: + An instance of ``CacheControlHeader`` + """ + + cc_items = [v.strip() for v in header_value.split(",")] + kwargs: Dict[str, Any] = {} + field_definitions = cls._get_field_definitions() + for cc_item in cc_items: + key_value = cc_item.split("=") + key_value[0] = key_value[0].replace("-", "_") + if len(key_value) == 1: + kwargs[key_value[0]] = True + elif len(key_value) == 2: + key, value = key_value + if key not in field_definitions: + raise ImproperlyConfiguredException("Invalid cache-control header") + kwargs[key] = cls._convert_to_type(value, field_definition=field_definitions[key]) + else: + raise ImproperlyConfiguredException("Invalid cache-control header value") + + try: + return CacheControlHeader(**kwargs) + except TypeError as exc: + raise ImproperlyConfiguredException from exc + + @classmethod + def prevent_storing(cls) -> "CacheControlHeader": + """Create a ``cache-control`` header with the ``no-store`` directive which indicates that any caches of any kind + (private or shared) should not store this response. + """ + + return cls(no_store=True) + + @classmethod + def _get_field_definitions(cls) -> Dict[str, FieldDefinition]: + """Get the type annotations for the ``CacheControlHeader`` class properties. + + This is needed due to the conversion from pydantic models to dataclasses. Dataclasses do not support + automatic conversion of types like pydantic models do. + + Returns: + A dictionary of type annotations + + """ + + if cls._field_definitions is None: + cls._field_definitions = {} + for key, value in get_type_hints(cls, include_extras=True).items(): + definition = FieldDefinition.from_kwarg(annotation=value, name=key) + # resolve_model_type so that field_definition.raw has the real raw type e.g. <class 'bool'> + cls._field_definitions[key] = AbstractDTO.resolve_model_type(definition) + return cls._field_definitions + + @classmethod + def _convert_to_type(cls, value: str, field_definition: FieldDefinition) -> Any: + """Convert the value to the expected type. + + Args: + value: the value of the cache-control directive + field_definition: the field definition for the value to convert + + Returns: + The value converted to the expected type + """ + # bool values shouldn't be initiated since they should have been caught earlier in the from_header method and + # set with a value of True + expected_type = field_definition.raw + if expected_type is bool: + raise ImproperlyConfiguredException("Invalid cache-control header value") + return expected_type(value) + + +@dataclass +class ETag(Header): + """An ``etag`` header.""" + + HEADER_NAME: ClassVar[str] = "etag" + + weak: bool = False + value: Optional[str] = None # only ASCII characters + + def _get_header_value(self) -> str: + value = f'"{self.value}"' + return f"W/{value}" if self.weak else value + + @classmethod + def from_header(cls, header_value: str) -> "ETag": + """Construct an ``etag`` header from its string representation. + + Note that this will unquote etag-values + """ + match = ETAG_RE.match(header_value) + if not match: + raise ImproperlyConfiguredException + weak, value = match.group(1, 2) + try: + return cls(weak=bool(weak), value=value) + except ValueError as exc: + raise ImproperlyConfiguredException from exc + + def __post_init__(self) -> None: + if self.documentation_only is False and self.value is None: + raise ValidationException("value must be set if documentation_only is false") + if self.value and not PRINTABLE_ASCII_RE.fullmatch(self.value): + raise ValidationException("value must only contain ASCII printable characters") + + +class MediaTypeHeader: + """A helper class for ``Accept`` header parsing.""" + + __slots__ = ("maintype", "subtype", "params", "_params_str") + + def __init__(self, type_str: str) -> None: + # preserve the original parameters, because the order might be + # changed in the dict + self._params_str = "".join(type_str.partition(";")[1:]) + + full_type, self.params = parse_content_header(type_str) + self.maintype, _, self.subtype = full_type.partition("/") + + def __str__(self) -> str: + return f"{self.maintype}/{self.subtype}{self._params_str}" + + @property + def priority(self) -> Tuple[int, int]: + # Use fixed point values with two decimals to avoid problems + # when comparing float values + quality = 100 + if "q" in self.params: + with suppress(ValueError): + quality = int(100 * float(self.params["q"])) + + if self.maintype == "*": + specificity = 0 + elif self.subtype == "*": + specificity = 1 + elif not self.params or ("q" in self.params and len(self.params) == 1): + # no params or 'q' is the only one which we ignore + specificity = 2 + else: + specificity = 3 + + return quality, specificity + + def match(self, other: "MediaTypeHeader") -> bool: + return next( + (False for key, value in self.params.items() if key != "q" and value != other.params.get(key)), + False + if self.subtype != "*" and other.subtype != "*" and self.subtype != other.subtype + else self.maintype == "*" or other.maintype == "*" or self.maintype == other.maintype, + ) + + +class Accept: + """An ``Accept`` header.""" + + __slots__ = ("_accepted_types",) + + def __init__(self, accept_value: str) -> None: + self._accepted_types = [MediaTypeHeader(t) for t in accept_value.split(",")] + self._accepted_types.sort(key=lambda t: t.priority, reverse=True) + + def __len__(self) -> int: + return len(self._accepted_types) + + def __getitem__(self, key: int) -> str: + return str(self._accepted_types[key]) + + def __iter__(self) -> Iterator[str]: + return map(str, self._accepted_types) + + def best_match(self, provided_types: List[str], default: Optional[str] = None) -> Optional[str]: + """Find the best matching media type for the request. + + Args: + provided_types: A list of media types that can be provided as a response. These types + can contain a wildcard ``*`` character in the main- or subtype part. + default: The media type that is returned if none of the provided types match. + + Returns: + The best matching media type. If the matching provided type contains wildcard characters, + they are replaced with the corresponding part of the accepted type. Otherwise the + provided type is returned as-is. + """ + types = [MediaTypeHeader(t) for t in provided_types] + + for accepted in self._accepted_types: + for provided in types: + if provided.match(accepted): + # Return the accepted type with wildcards replaced + # by concrete parts from the provided type + result = copy(provided) + if result.subtype == "*": + result.subtype = accepted.subtype + if result.maintype == "*": + result.maintype = accepted.maintype + return str(result) + return default + + def accepts(self, media_type: str) -> bool: + """Check if the request accepts the specified media type. + + If multiple media types can be provided, it is better to use :func:`best_match`. + + Args: + media_type: The media type to check for. + + Returns: + True if the request accepts ``media_type``. + """ + return self.best_match([media_type]) == media_type diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/multi_dicts.py b/venv/lib/python3.11/site-packages/litestar/datastructures/multi_dicts.py new file mode 100644 index 0000000..7702e1a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/multi_dicts.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Generator, Generic, Iterable, Mapping, TypeVar + +from multidict import MultiDict as BaseMultiDict +from multidict import MultiDictProxy, MultiMapping + +from litestar.datastructures.upload_file import UploadFile + +if TYPE_CHECKING: + from typing_extensions import Self + + +__all__ = ("FormMultiDict", "ImmutableMultiDict", "MultiDict", "MultiMixin") + + +T = TypeVar("T") + + +class MultiMixin(Generic[T], MultiMapping[T], ABC): + """Mixin providing common methods for multi dicts, used by :class:`ImmutableMultiDict` and :class:`MultiDict`""" + + def dict(self) -> dict[str, list[Any]]: + """Return the multi-dict as a dict of lists. + + Returns: + A dict of lists + """ + return {k: self.getall(k) for k in set(self.keys())} + + def multi_items(self) -> Generator[tuple[str, T], None, None]: + """Get all keys and values, including duplicates. + + Returns: + A list of tuples containing key-value pairs + """ + for key in set(self): + for value in self.getall(key): + yield key, value + + +class MultiDict(BaseMultiDict[T], MultiMixin[T], Generic[T]): + """MultiDict, using :class:`MultiDict <multidict.MultiDictProxy>`.""" + + def __init__(self, args: MultiMapping | Mapping[str, T] | Iterable[tuple[str, T]] | None = None) -> None: + """Initialize ``MultiDict`` from a`MultiMapping``, + :class:`Mapping <typing.Mapping>` or an iterable of tuples. + + Args: + args: Mapping-like structure to create the ``MultiDict`` from + """ + super().__init__(args or {}) + + def immutable(self) -> ImmutableMultiDict[T]: + """Create an. + + :class:`ImmutableMultiDict` view. + + Returns: + An immutable multi dict + """ + return ImmutableMultiDict[T](self) # pyright: ignore + + def copy(self) -> Self: + """Return a shallow copy""" + return type(self)(list(self.multi_items())) + + +class ImmutableMultiDict(MultiDictProxy[T], MultiMixin[T], Generic[T]): + """Immutable MultiDict, using class:`MultiDictProxy <multidict.MultiDictProxy>`.""" + + def __init__(self, args: MultiMapping | Mapping[str, Any] | Iterable[tuple[str, Any]] | None = None) -> None: + """Initialize ``ImmutableMultiDict`` from a `MultiMapping``, + :class:`Mapping <typing.Mapping>` or an iterable of tuples. + + Args: + args: Mapping-like structure to create the ``ImmutableMultiDict`` from + """ + super().__init__(BaseMultiDict(args or {})) + + def mutable_copy(self) -> MultiDict[T]: + """Create a mutable copy as a :class:`MultiDict` + + Returns: + A mutable multi dict + """ + return MultiDict(list(self.multi_items())) + + def copy(self) -> Self: # type: ignore[override] + """Return a shallow copy""" + return type(self)(self.items()) + + +class FormMultiDict(ImmutableMultiDict[Any]): + """MultiDict for form data.""" + + async def close(self) -> None: + """Close all files in the multi-dict. + + Returns: + None + """ + for _, value in self.multi_items(): + if isinstance(value, UploadFile): + await value.close() diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/response_header.py b/venv/lib/python3.11/site-packages/litestar/datastructures/response_header.py new file mode 100644 index 0000000..f781d0c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/response_header.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from litestar.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from litestar.openapi.spec import Example + +__all__ = ("ResponseHeader",) + + +@dataclass +class ResponseHeader: + """Container type for a response header.""" + + name: str + """Header name""" + + documentation_only: bool = False + """Defines the ResponseHeader instance as for OpenAPI documentation purpose only.""" + + value: str | None = None + """Value to set for the response header.""" + + description: str | None = None + """A brief description of the parameter. This could contain examples of + use. + + [CommonMark syntax](https://spec.commonmark.org/) MAY be used for + rich text representation. + """ + + required: bool = False + """Determines whether this parameter is mandatory. + + If the [parameter location](https://spec.openapis.org/oas/v3.1.0#parameterIn) is `"path"`, this property is **REQUIRED** and its value MUST be `true`. + Otherwise, the property MAY be included and its default value is `false`. + """ + + deprecated: bool = False + """Specifies that a parameter is deprecated and SHOULD be transitioned out + of usage. + + Default value is `false`. + """ + + allow_empty_value: bool = False + """Sets the ability to pass empty-valued parameters. This is valid only for + `query` parameters and allows sending a parameter with an empty value. + Default value is `false`. If. + + [style](https://spec.openapis.org/oas/v3.1.0#parameterStyle) is used, and if behavior is `n/a` (cannot be + serialized), the value of `allowEmptyValue` SHALL be ignored. Use of this property is NOT RECOMMENDED, as it is + likely to be removed in a later revision. + + The rules for serialization of the parameter are specified in one of two ways. + For simpler scenarios, a [schema](https://spec.openapis.org/oas/v3.1.0#parameterSchema) and [style](https://spec.openapis.org/oas/v3.1.0#parameterStyle) + can describe the structure and syntax of the parameter. + """ + + style: str | None = None + """Describes how the parameter value will be serialized depending on the + type of the parameter value. Default values (based on value of `in`): + + - for `query` - `form`; + - for `path` - `simple`; + - for `header` - `simple`; + - for `cookie` - `form`. + """ + + explode: bool | None = None + """When this is true, parameter values of type `array` or `object` generate + separate parameters for each value of the array or key-value pair of the + map. + + For other types of parameters this property has no effect. + When [style](https://spec.openapis.org/oas/v3.1.0#parameterStyle) is `form`, the default value is `true`. + For all other styles, the default value is `false`. + """ + + allow_reserved: bool = False + """Determines whether the parameter value SHOULD allow reserved characters, + as defined by. + + [RFC3986](https://tools.ietf.org/html/rfc3986#section-2.2) `:/?#[]@!$&'()*+,;=` to be included without percent- + encoding. + + This property only applies to parameters with an `in` value of `query`. The default value is `false`. + """ + + example: Any | None = None + """Example of the parameter's potential value. + + The example SHOULD match the specified schema and encoding + properties if present. The `example` field is mutually exclusive of + the `examples` field. Furthermore, if referencing a `schema` that + contains an example, the `example` value SHALL _override_ the + example provided by the schema. To represent examples of media types + that cannot naturally be represented in JSON or YAML, a string value + can contain the example with escaping where necessary. + """ + + examples: dict[str, Example] | None = None + """Examples of the parameter's potential value. Each example SHOULD contain + a value in the correct format as specified in the parameter encoding. The + `examples` field is mutually exclusive of the `example` field. Furthermore, + if referencing a `schema` that contains an example, the `examples` value + SHALL _override_ the example provided by the schema. + + For more complex scenarios, the [content](https://spec.openapis.org/oas/v3.1.0#parameterContent) property + can define the media type and schema of the parameter. + A parameter MUST contain either a `schema` property, or a `content` property, but not both. + When `example` or `examples` are provided in conjunction with the `schema` object, + the example MUST follow the prescribed serialization strategy for the parameter. + """ + + def __post_init__(self) -> None: + """Ensure that either value is set or the instance is for documentation_only.""" + if not self.documentation_only and self.value is None: + raise ImproperlyConfiguredException("value must be set if documentation_only is false") + + def __hash__(self) -> int: + return hash(self.name) diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/state.py b/venv/lib/python3.11/site-packages/litestar/datastructures/state.py new file mode 100644 index 0000000..71980e0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/state.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +from copy import copy, deepcopy +from threading import RLock +from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Iterator, Mapping, MutableMapping + +if TYPE_CHECKING: + from typing_extensions import Self + +__all__ = ("ImmutableState", "State") + + +class ImmutableState(Mapping[str, Any]): + """An object meant to store arbitrary state. + + It can be accessed using dot notation while exposing dict like functionalities. + """ + + __slots__ = ( + "_state", + "_deep_copy", + ) + + _state: dict[str, Any] + + def __init__( + self, state: ImmutableState | Mapping[str, Any] | Iterable[tuple[str, Any]], deep_copy: bool = True + ) -> None: + """Initialize an ``ImmutableState`` instance. + + Args: + state: An object to initialize the state from. Can be a dict, an instance of :class:`ImmutableState`, or a tuple + of key value paris. + deep_copy: Whether to 'deepcopy' the passed in state. + + Examples: + .. code-block:: python + + from litestar.datastructures import ImmutableState + + state_dict = {"first": 1, "second": 2, "third": 3, "fourth": 4} + state = ImmutableState(state_dict) + + # state implements the Mapping type: + assert len(state) == 3 + assert "first" in state + assert not "fourth" in state + assert state["first"] == 1 + assert [(k, v) for k, v in state.items()] == [("first", 1), ("second", 2), ("third", 3)] + + # state implements __bool__ + assert state # state is true when it has values. + assert not State() # state is empty when it has no values. + + # it has a 'dict' method to retrieve a shallow copy of the underlying dict + inner_dict = state.dict() + assert inner_dict == state_dict + + # you can also retrieve a mutable State by calling 'mutable_copy' + mutable_state = state.mutable_copy() + del state["first"] + assert "first" not in state + + """ + if isinstance(state, ImmutableState): + state = state._state + + if not isinstance(state, dict) and isinstance(state, Iterable): + state = dict(state) + + super().__setattr__("_deep_copy", deep_copy) + super().__setattr__("_state", deepcopy(state) if deep_copy else state) + + def __bool__(self) -> bool: + """Return a boolean indicating whether the wrapped dict instance has values.""" + return bool(self._state) + + def __getitem__(self, key: str) -> Any: + """Get the value for the corresponding key from the wrapped state object using subscription notation. + + Args: + key: Key to access. + + Raises: + KeyError + + Returns: + A value from the wrapped state instance. + """ + return self._state[key] + + def __iter__(self) -> Iterator[str]: + """Return an iterator iterating the wrapped state dict. + + Returns: + An iterator of strings + """ + return iter(self._state) + + def __len__(self) -> int: + """Return length of the wrapped state dict. + + Returns: + An integer + """ + return len(self._state) + + def __getattr__(self, key: str) -> Any: + """Get the value for the corresponding key from the wrapped state object using attribute notation. + + Args: + key: Key to retrieve + + Raises: + AttributeError: if the given attribute is not set. + + Returns: + The retrieved value + """ + try: + return self._state[key] + except KeyError as e: + raise AttributeError from e + + def __copy__(self) -> Self: + """Return a shallow copy of the given state object. + + Customizes how the builtin "copy" function will work. + """ + return self.__class__(self._state, deep_copy=self._deep_copy) # pyright: ignore + + def mutable_copy(self) -> State: + """Return a mutable copy of the state object. + + Returns: + A ``State`` + """ + return State(self._state, deep_copy=self._deep_copy) + + def dict(self) -> dict[str, Any]: + """Return a shallow copy of the wrapped dict. + + Returns: + A dict + """ + return copy(self._state) + + @classmethod + def __get_validators__( + cls, + ) -> Generator[Callable[[ImmutableState | dict[str, Any] | Iterable[tuple[str, Any]]], ImmutableState], None, None]: # type: ignore[valid-type] + """Pydantic compatible method to allow custom parsing of state instances in a SignatureModel.""" + yield cls.validate + + @classmethod + def validate(cls, value: ImmutableState | dict[str, Any] | Iterable[tuple[str, Any]]) -> Self: # type: ignore[valid-type] + """Parse a value and instantiate state inside a SignatureModel. This allows us to use custom subclasses of + state, as well as allows users to decide whether state is mutable or immutable. + + Args: + value: The value from which to initialize the state instance. + + Returns: + An ImmutableState instance + """ + deep_copy = value._deep_copy if isinstance(value, ImmutableState) else False + return cls(value, deep_copy=deep_copy) + + +class State(ImmutableState, MutableMapping[str, Any]): + """An object meant to store arbitrary state. + + It can be accessed using dot notation while exposing dict like functionalities. + """ + + __slots__ = ("_lock",) + + _lock: RLock + + def __init__( + self, + state: ImmutableState | Mapping[str, Any] | Iterable[tuple[str, Any]] | None = None, + deep_copy: bool = False, + ) -> None: + """Initialize a ``State`` instance with an optional value. + + Args: + state: An object to initialize the state from. Can be a dict, an instance of 'ImmutableState', or a tuple of key value paris. + deep_copy: Whether to 'deepcopy' the passed in state. + + .. code-block:: python + :caption: Examples + + from litestar.datastructures import State + + state_dict = {"first": 1, "second": 2, "third": 3, "fourth": 4} + state = State(state_dict) + + # state can be accessed using '.' notation + assert state.fourth == 4 + del state.fourth + + # state implements the Mapping type: + assert len(state) == 3 + assert "first" in state + assert not "fourth" in state + assert state["first"] == 1 + assert [(k, v) for k, v in state.items()] == [("first", 1), ("second", 2), ("third", 3)] + + state["fourth"] = 4 + assert "fourth" in state + del state["fourth"] + + # state implements __bool__ + assert state # state is true when it has values. + assert not State() # state is empty when it has no values. + + # it has shallow copy + copied_state = state.copy() + del copied_state.first + assert state.first + + # it has a 'dict' method to retrieve a shallow copy of the underlying dict + inner_dict = state.dict() + assert inner_dict == state_dict + + # you can get an immutable copy of the state by calling 'immutable_immutable_copy' + immutable_copy = state.immutable_copy() + del immutable_copy.first # raises AttributeError + + """ + + super().__init__(state if state is not None else {}, deep_copy=deep_copy) + super().__setattr__("_lock", RLock()) + + def __delitem__(self, key: str) -> None: + """Delete the value from the key from the wrapped state object using subscription notation. + + Args: + key: Key to delete + + Raises: + KeyError: if the given attribute is not set. + + Returns: + None + """ + + with self._lock: + del self._state[key] + + def __setitem__(self, key: str, value: Any) -> None: + """Set an item in the state using subscription notation. + + Args: + key: Key to set. + value: Value to set. + + Returns: + None + """ + + with self._lock: + self._state[key] = value + + def __setattr__(self, key: str, value: Any) -> None: + """Set an item in the state using attribute notation. + + Args: + key: Key to set. + value: Value to set. + + Returns: + None + """ + + with self._lock: + self._state[key] = value + + def __delattr__(self, key: str) -> None: + """Delete the value from the key from the wrapped state object using attribute notation. + + Args: + key: Key to delete + + Raises: + AttributeError: if the given attribute is not set. + + Returns: + None + """ + + try: + with self._lock: + del self._state[key] + except KeyError as e: + raise AttributeError from e + + def copy(self) -> Self: + """Return a shallow copy of the state object. + + Returns: + A ``State`` + """ + return self.__class__(self.dict(), deep_copy=self._deep_copy) # pyright: ignore + + def immutable_copy(self) -> ImmutableState: + """Return a shallow copy of the state object, setting it to be frozen. + + Returns: + A ``State`` + """ + return ImmutableState(self, deep_copy=self._deep_copy) diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/upload_file.py b/venv/lib/python3.11/site-packages/litestar/datastructures/upload_file.py new file mode 100644 index 0000000..09ad2d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/upload_file.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from tempfile import SpooledTemporaryFile + +from litestar.concurrency import sync_to_thread +from litestar.constants import ONE_MEGABYTE + +__all__ = ("UploadFile",) + + +class UploadFile: + """Representation of a file upload""" + + __slots__ = ("filename", "file", "content_type", "headers") + + def __init__( + self, + content_type: str, + filename: str, + file_data: bytes | None = None, + headers: dict[str, str] | None = None, + max_spool_size: int = ONE_MEGABYTE, + ) -> None: + """Upload file in-memory container. + + Args: + content_type: Content type for the file. + filename: The filename. + file_data: File data. + headers: Any attached headers. + max_spool_size: The size above which the temporary file will be rolled to disk. + """ + self.filename = filename + self.content_type = content_type + self.file = SpooledTemporaryFile(max_size=max_spool_size) + self.headers = headers or {} + + if file_data: + self.file.write(file_data) + self.file.seek(0) + + @property + def rolled_to_disk(self) -> bool: + """Determine whether the spooled file exceeded the rolled-to-disk threshold and is no longer in memory. + + Returns: + A boolean flag + """ + return getattr(self.file, "_rolled", False) + + async def write(self, data: bytes) -> int: + """Proxy for data writing. + + Args: + data: Byte string to write. + + Returns: + None + """ + if self.rolled_to_disk: + return await sync_to_thread(self.file.write, data) + return self.file.write(data) + + async def read(self, size: int = -1) -> bytes: + """Proxy for data reading. + + Args: + size: position from which to read. + + Returns: + Byte string. + """ + if self.rolled_to_disk: + return await sync_to_thread(self.file.read, size) + return self.file.read(size) + + async def seek(self, offset: int) -> int: + """Async proxy for file seek. + + Args: + offset: start position.. + + Returns: + None. + """ + if self.rolled_to_disk: + return await sync_to_thread(self.file.seek, offset) + return self.file.seek(offset) + + async def close(self) -> None: + """Async proxy for file close. + + Returns: + None. + """ + if self.rolled_to_disk: + return await sync_to_thread(self.file.close) + return self.file.close() + + def __repr__(self) -> str: + return f"{self.filename} - {self.content_type}" diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/url.py b/venv/lib/python3.11/site-packages/litestar/datastructures/url.py new file mode 100644 index 0000000..f3441d0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/url.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any, NamedTuple +from urllib.parse import SplitResult, urlencode, urlsplit, urlunsplit + +from litestar._parsers import parse_query_string +from litestar.datastructures import MultiDict +from litestar.types import Empty + +if TYPE_CHECKING: + from typing_extensions import Self + + from litestar.types import EmptyType, Scope + +__all__ = ("Address", "URL") + +_DEFAULT_SCHEME_PORTS = {"http": 80, "https": 443, "ftp": 21, "ws": 80, "wss": 443} + + +class Address(NamedTuple): + """Just a network address.""" + + host: str + """Address host.""" + port: int + """Address port.""" + + +def make_absolute_url(path: str | URL, base: str | URL) -> str: + """Create an absolute URL. + + Args: + path: URL path to make absolute + base: URL to use as a base + + Returns: + A string representing the new, absolute URL + """ + url = base if isinstance(base, URL) else URL(base) + netloc = url.netloc + path = url.path.rstrip("/") + str(path) + return str(URL.from_components(scheme=url.scheme, netloc=netloc, path=path)) + + +class URL: + """Representation and modification utilities of a URL.""" + + __slots__ = ( + "_query_params", + "_parsed_url", + "fragment", + "hostname", + "netloc", + "password", + "path", + "port", + "query", + "scheme", + "username", + ) + + _query_params: EmptyType | MultiDict + _parsed_url: str | None + + scheme: str + """URL scheme.""" + netloc: str + """Network location.""" + path: str + """Hierarchical path.""" + fragment: str + """Fragment component.""" + query: str + """Query string.""" + username: str | None + """Username if specified.""" + password: str | None + """Password if specified.""" + port: int | None + """Port if specified.""" + hostname: str | None + """Hostname if specified.""" + + def __new__(cls, url: str | SplitResult) -> URL: + """Create a new instance. + + Args: + url: url string or split result to represent. + """ + return cls._new(url=url) + + @classmethod + @lru_cache + def _new(cls, url: str | SplitResult) -> URL: + instance = super().__new__(cls) + instance._parsed_url = None + + if isinstance(url, str): + result = urlsplit(url) + instance._parsed_url = url + else: + result = url + + instance.scheme = result.scheme + instance.netloc = result.netloc + instance.path = result.path + instance.fragment = result.fragment + instance.query = result.query + instance.username = result.username + instance.password = result.password + instance.port = result.port + instance.hostname = result.hostname + instance._query_params = Empty + + return instance + + @property + def _url(self) -> str: + if not self._parsed_url: + self._parsed_url = str( + urlunsplit( + SplitResult( + scheme=self.scheme, + netloc=self.netloc, + path=self.path, + fragment=self.fragment, + query=self.query, + ) + ) + ) + return self._parsed_url + + @classmethod + @lru_cache + def from_components( + cls, + scheme: str = "", + netloc: str = "", + path: str = "", + fragment: str = "", + query: str = "", + ) -> Self: + """Create a new URL from components. + + Args: + scheme: URL scheme + netloc: Network location + path: Hierarchical path + query: Query component + fragment: Fragment identifier + + Returns: + A new URL with the given components + """ + return cls( + SplitResult( + scheme=scheme, + netloc=netloc, + path=path, + fragment=fragment, + query=query, + ) + ) + + @classmethod + def from_scope(cls, scope: Scope) -> Self: + """Construct a URL from a :class:`Scope <.types.Scope>` + + Args: + scope: A scope + + Returns: + A URL + """ + scheme = scope.get("scheme", "http") + server = scope.get("server") + path = scope.get("root_path", "") + scope["path"] + query_string = scope.get("query_string", b"") + + # we use iteration here because it's faster, and headers might not yet be cached + host = next( + ( + header_value.decode("latin-1") + for header_name, header_value in scope.get("headers", []) + if header_name == b"host" + ), + "", + ) + if server and not host: + host, port = server + default_port = _DEFAULT_SCHEME_PORTS[scheme] + if port != default_port: + host = f"{host}:{port}" + + return cls.from_components( + scheme=scheme if server else "", + query=query_string.decode(), + netloc=host, + path=path, + ) + + def with_replacements( + self, + scheme: str = "", + netloc: str = "", + path: str = "", + query: str | MultiDict | None | EmptyType = Empty, + fragment: str = "", + ) -> Self: + """Create a new URL, replacing the given components. + + Args: + scheme: URL scheme + netloc: Network location + path: Hierarchical path + query: Raw query string + fragment: Fragment identifier + + Returns: + A new URL with the given components replaced + """ + if isinstance(query, MultiDict): + query = urlencode(query=query) + + query = (query if query is not Empty else self.query) or "" + + return type(self).from_components( + scheme=scheme or self.scheme, + netloc=netloc or self.netloc, + path=path or self.path, + query=query, + fragment=fragment or self.fragment, + ) + + @property + def query_params(self) -> MultiDict: + """Query parameters of a URL as a :class:`MultiDict <.datastructures.multi_dicts.MultiDict>` + + Returns: + A :class:`MultiDict <.datastructures.multi_dicts.MultiDict>` with query parameters + + Notes: + - The returned ``MultiDict`` is mutable, :class:`URL` itself is *immutable*, + therefore mutating the query parameters will not directly mutate the ``URL``. + If you want to modify query parameters, make modifications in the + multidict and pass them back to :meth:`with_replacements` + """ + if self._query_params is Empty: + self._query_params = MultiDict(parse_query_string(query_string=self.query.encode())) + return self._query_params + + def __str__(self) -> str: + return self._url + + def __eq__(self, other: Any) -> bool: + if isinstance(other, (str, URL)): + return str(self) == str(other) + return NotImplemented # pragma: no cover + + def __repr__(self) -> str: + return f"{type(self).__name__}({self._url!r})" diff --git a/venv/lib/python3.11/site-packages/litestar/di.py b/venv/lib/python3.11/site-packages/litestar/di.py new file mode 100644 index 0000000..066a128 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/di.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from inspect import isasyncgenfunction, isclass, isgeneratorfunction +from typing import TYPE_CHECKING, Any + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types import Empty +from litestar.utils import ensure_async_callable +from litestar.utils.predicates import is_async_callable +from litestar.utils.warnings import ( + warn_implicit_sync_to_thread, + warn_sync_to_thread_with_async_callable, + warn_sync_to_thread_with_generator, +) + +if TYPE_CHECKING: + from litestar._signature import SignatureModel + from litestar.types import AnyCallable + from litestar.utils.signature import ParsedSignature + +__all__ = ("Provide",) + + +class Provide: + """Wrapper class for dependency injection""" + + __slots__ = ( + "dependency", + "has_sync_callable", + "has_sync_generator_dependency", + "has_async_generator_dependency", + "parsed_fn_signature", + "signature_model", + "sync_to_thread", + "use_cache", + "value", + ) + + parsed_fn_signature: ParsedSignature + signature_model: type[SignatureModel] + dependency: AnyCallable + + def __init__( + self, + dependency: AnyCallable | type[Any], + use_cache: bool = False, + sync_to_thread: bool | None = None, + ) -> None: + """Initialize ``Provide`` + + Args: + dependency: Callable to call or class to instantiate. The result is then injected as a dependency. + use_cache: Cache the dependency return value. Defaults to False. + sync_to_thread: Run sync code in an async thread. Defaults to False. + """ + if not callable(dependency): + raise ImproperlyConfiguredException("Provider dependency must a callable value") + + is_class_dependency = isclass(dependency) + self.has_sync_generator_dependency = isgeneratorfunction( + dependency if not is_class_dependency else dependency.__call__ # type: ignore[operator] + ) + self.has_async_generator_dependency = isasyncgenfunction( + dependency if not is_class_dependency else dependency.__call__ # type: ignore[operator] + ) + has_generator_dependency = self.has_sync_generator_dependency or self.has_async_generator_dependency + + if has_generator_dependency and use_cache: + raise ImproperlyConfiguredException( + "Cannot cache generator dependency, consider using Lifespan Context instead." + ) + + has_sync_callable = is_class_dependency or not is_async_callable(dependency) # pyright: ignore + + if sync_to_thread is not None: + if has_generator_dependency: + warn_sync_to_thread_with_generator(dependency, stacklevel=3) # type: ignore[arg-type] + elif not has_sync_callable: + warn_sync_to_thread_with_async_callable(dependency, stacklevel=3) # pyright: ignore + elif has_sync_callable and not has_generator_dependency: + warn_implicit_sync_to_thread(dependency, stacklevel=3) # pyright: ignore + + if sync_to_thread and has_sync_callable: + self.dependency = ensure_async_callable(dependency) # pyright: ignore + self.has_sync_callable = False + else: + self.dependency = dependency # pyright: ignore + self.has_sync_callable = has_sync_callable + + self.sync_to_thread = bool(sync_to_thread) + self.use_cache = use_cache + self.value: Any = Empty + + async def __call__(self, **kwargs: Any) -> Any: + """Call the provider's dependency.""" + + if self.use_cache and self.value is not Empty: + return self.value + + if self.has_sync_callable: + value = self.dependency(**kwargs) + else: + value = await self.dependency(**kwargs) + + if self.use_cache: + self.value = value + + return value + + def __eq__(self, other: Any) -> bool: + # check if memory address is identical, otherwise compare attributes + return other is self or ( + isinstance(other, self.__class__) + and other.dependency == self.dependency + and other.use_cache == self.use_cache + and other.value == self.value + ) diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__init__.py b/venv/lib/python3.11/site-packages/litestar/dto/__init__.py new file mode 100644 index 0000000..052e6a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__init__.py @@ -0,0 +1,20 @@ +from .base_dto import AbstractDTO +from .config import DTOConfig +from .data_structures import DTOData, DTOFieldDefinition +from .dataclass_dto import DataclassDTO +from .field import DTOField, Mark, dto_field +from .msgspec_dto import MsgspecDTO +from .types import RenameStrategy + +__all__ = ( + "AbstractDTO", + "DTOConfig", + "DTOData", + "DTOField", + "DTOFieldDefinition", + "DataclassDTO", + "Mark", + "MsgspecDTO", + "RenameStrategy", + "dto_field", +) diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7551946 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_backend.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_backend.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d39aea9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_backend.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_codegen_backend.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_codegen_backend.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..610f4b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_codegen_backend.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..53877b3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/base_dto.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/base_dto.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..eb9af7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/base_dto.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ba3f03d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/data_structures.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/data_structures.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..785f3b0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/data_structures.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/dataclass_dto.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/dataclass_dto.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2388d4a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/dataclass_dto.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/field.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/field.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a6d1cbb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/field.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/msgspec_dto.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/msgspec_dto.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0d15d6b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/msgspec_dto.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ec01cd4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/dto/_backend.py b/venv/lib/python3.11/site-packages/litestar/dto/_backend.py new file mode 100644 index 0000000..1c48dc0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/_backend.py @@ -0,0 +1,911 @@ +"""DTO backends do the heavy lifting of decoding and validating raw bytes into domain models, and +back again, to bytes. +""" + +from __future__ import annotations + +from dataclasses import replace +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + ClassVar, + Collection, + Final, + Mapping, + Protocol, + Union, + cast, +) + +import msgspec +from msgspec import UNSET, Struct, UnsetType, convert, defstruct, field +from typing_extensions import Annotated + +from litestar.dto._types import ( + CollectionType, + CompositeType, + MappingType, + NestedFieldInfo, + SimpleType, + TransferDTOFieldDefinition, + TransferType, + TupleType, + UnionType, +) +from litestar.dto.data_structures import DTOData, DTOFieldDefinition +from litestar.dto.field import Mark +from litestar.enums import RequestEncodingType +from litestar.params import KwargDefinition +from litestar.serialization import decode_json, decode_msgpack +from litestar.types import Empty +from litestar.typing import FieldDefinition +from litestar.utils import unique_name_for_scope + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.dto import AbstractDTO, RenameStrategy + from litestar.types.serialization import LitestarEncodableType + +__all__ = ("DTOBackend",) + + +class CompositeTypeHandler(Protocol): + def __call__( + self, + field_definition: FieldDefinition, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + unique_name: str, + nested_depth: int, + ) -> CompositeType: ... + + +class DTOBackend: + __slots__ = ( + "annotation", + "dto_data_type", + "dto_factory", + "field_definition", + "handler_id", + "is_data_field", + "model_type", + "parsed_field_definitions", + "reverse_name_map", + "transfer_model_type", + "wrapper_attribute_name", + ) + + _seen_model_names: ClassVar[set[str]] = set() + + def __init__( + self, + dto_factory: type[AbstractDTO], + field_definition: FieldDefinition, + handler_id: str, + is_data_field: bool, + model_type: type[Any], + wrapper_attribute_name: str | None, + ) -> None: + """Create dto backend instance. + + Args: + dto_factory: The DTO factory class calling this backend. + field_definition: Parsed type. + handler_id: The name of the handler that this backend is for. + is_data_field: Whether the field is a subclass of DTOData. + model_type: Model type. + wrapper_attribute_name: If the data that DTO should operate upon is wrapped in a generic datastructure, this is the name of the attribute that the data is stored in. + """ + self.dto_factory: Final[type[AbstractDTO]] = dto_factory + self.field_definition: Final[FieldDefinition] = field_definition + self.is_data_field: Final[bool] = is_data_field + self.handler_id: Final[str] = handler_id + self.model_type: Final[type[Any]] = model_type + self.wrapper_attribute_name: Final[str | None] = wrapper_attribute_name + + self.parsed_field_definitions = self.parse_model( + model_type=model_type, + exclude=self.dto_factory.config.exclude, + include=self.dto_factory.config.include, + rename_fields=self.dto_factory.config.rename_fields, + ) + self.transfer_model_type = self.create_transfer_model_type( + model_name=model_type.__name__, field_definitions=self.parsed_field_definitions + ) + self.dto_data_type: type[DTOData] | None = None + + if field_definition.is_subclass_of(DTOData): + self.dto_data_type = field_definition.annotation + field_definition = self.field_definition.inner_types[0] + + self.annotation = build_annotation_for_backend(model_type, field_definition, self.transfer_model_type) + + def parse_model( + self, + model_type: Any, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + nested_depth: int = 0, + ) -> tuple[TransferDTOFieldDefinition, ...]: + """Reduce :attr:`model_type` to a tuple :class:`TransferDTOFieldDefinition` instances. + + Returns: + Fields for data transfer. + """ + defined_fields = [] + generic_field_definitons = list(FieldDefinition.from_annotation(model_type).generic_types or ()) + for field_definition in self.dto_factory.generate_field_definitions(model_type): + if field_definition.is_type_var: + base_arg_field = generic_field_definitons.pop() + field_definition = replace( + field_definition, annotation=base_arg_field.annotation, raw=base_arg_field.raw + ) + + if _should_mark_private(field_definition, self.dto_factory.config.underscore_fields_private): + field_definition.dto_field.mark = Mark.PRIVATE + + try: + transfer_type = self._create_transfer_type( + field_definition=field_definition, + exclude=exclude, + include=include, + rename_fields=rename_fields, + field_name=field_definition.name, + unique_name=field_definition.model_name, + nested_depth=nested_depth, + ) + except RecursionError: + continue + + transfer_field_definition = TransferDTOFieldDefinition.from_dto_field_definition( + field_definition=field_definition, + serialization_name=rename_fields.get(field_definition.name), + transfer_type=transfer_type, + is_partial=self.dto_factory.config.partial, + is_excluded=_should_exclude_field( + field_definition=field_definition, + exclude=exclude, + include=include, + is_data_field=self.is_data_field, + ), + ) + defined_fields.append(transfer_field_definition) + return tuple(defined_fields) + + def _create_transfer_model_name(self, model_name: str) -> str: + long_name_prefix = self.handler_id.split("::")[0] + short_name_prefix = _camelize(long_name_prefix.split(".")[-1], True) + + name_suffix = "RequestBody" if self.is_data_field else "ResponseBody" + + if (short_name := f"{short_name_prefix}{model_name}{name_suffix}") not in self._seen_model_names: + name = short_name + elif (long_name := f"{long_name_prefix}{model_name}{name_suffix}") not in self._seen_model_names: + name = long_name + else: + name = unique_name_for_scope(long_name, self._seen_model_names) + + self._seen_model_names.add(name) + + return name + + def create_transfer_model_type( + self, + model_name: str, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + ) -> type[Struct]: + """Create a model for data transfer. + + Args: + model_name: name for the type that should be unique across all transfer types. + field_definitions: field definitions for the container type. + + Returns: + A ``BackendT`` class. + """ + struct_name = self._create_transfer_model_name(model_name) + + struct = _create_struct_for_field_definitions( + struct_name, field_definitions, self.dto_factory.config.rename_strategy + ) + setattr(struct, "__schema_name__", struct_name) + return struct + + def parse_raw(self, raw: bytes, asgi_connection: ASGIConnection) -> Struct | Collection[Struct]: + """Parse raw bytes into transfer model type. + + Args: + raw: bytes + asgi_connection: The current ASGI Connection + + Returns: + The raw bytes parsed into transfer model type. + """ + request_encoding = RequestEncodingType.JSON + + if (content_type := getattr(asgi_connection, "content_type", None)) and (media_type := content_type[0]): + request_encoding = media_type + + type_decoders = asgi_connection.route_handler.resolve_type_decoders() + + if request_encoding == RequestEncodingType.MESSAGEPACK: + result = decode_msgpack(value=raw, target_type=self.annotation, type_decoders=type_decoders) + else: + result = decode_json(value=raw, target_type=self.annotation, type_decoders=type_decoders) + + return cast("Struct | Collection[Struct]", result) + + def parse_builtins(self, builtins: Any, asgi_connection: ASGIConnection) -> Any: + """Parse builtin types into transfer model type. + + Args: + builtins: Builtin type. + asgi_connection: The current ASGI Connection + + Returns: + The builtin type parsed into transfer model type. + """ + return convert( + obj=builtins, + type=self.annotation, + dec_hook=asgi_connection.route_handler.default_deserializer, + strict=False, + str_keys=True, + ) + + def populate_data_from_builtins(self, builtins: Any, asgi_connection: ASGIConnection) -> Any: + """Populate model instance from builtin types. + + Args: + builtins: Builtin type. + asgi_connection: The current ASGI Connection + + Returns: + Instance or collection of ``model_type`` instances. + """ + if self.dto_data_type: + return self.dto_data_type( + backend=self, + data_as_builtins=_transfer_data( + destination_type=dict, + source_data=self.parse_builtins(builtins, asgi_connection), + field_definitions=self.parsed_field_definitions, + field_definition=self.field_definition, + is_data_field=self.is_data_field, + ), + ) + return self.transfer_data_from_builtins(self.parse_builtins(builtins, asgi_connection)) + + def transfer_data_from_builtins(self, builtins: Any) -> Any: + """Populate model instance from builtin types. + + Args: + builtins: Builtin type. + + Returns: + Instance or collection of ``model_type`` instances. + """ + return _transfer_data( + destination_type=self.model_type, + source_data=builtins, + field_definitions=self.parsed_field_definitions, + field_definition=self.field_definition, + is_data_field=self.is_data_field, + ) + + def populate_data_from_raw(self, raw: bytes, asgi_connection: ASGIConnection) -> Any: + """Parse raw bytes into instance of `model_type`. + + Args: + raw: bytes + asgi_connection: The current ASGI Connection + + Returns: + Instance or collection of ``model_type`` instances. + """ + if self.dto_data_type: + return self.dto_data_type( + backend=self, + data_as_builtins=_transfer_data( + destination_type=dict, + source_data=self.parse_raw(raw, asgi_connection), + field_definitions=self.parsed_field_definitions, + field_definition=self.field_definition, + is_data_field=self.is_data_field, + ), + ) + return _transfer_data( + destination_type=self.model_type, + source_data=self.parse_raw(raw, asgi_connection), + field_definitions=self.parsed_field_definitions, + field_definition=self.field_definition, + is_data_field=self.is_data_field, + ) + + def encode_data(self, data: Any) -> LitestarEncodableType: + """Encode data into a ``LitestarEncodableType``. + + Args: + data: Data to encode. + + Returns: + Encoded data. + """ + if self.wrapper_attribute_name: + wrapped_transfer = _transfer_data( + destination_type=self.transfer_model_type, + source_data=getattr(data, self.wrapper_attribute_name), + field_definitions=self.parsed_field_definitions, + field_definition=self.field_definition, + is_data_field=self.is_data_field, + ) + setattr( + data, + self.wrapper_attribute_name, + wrapped_transfer, + ) + return cast("LitestarEncodableType", data) + + return cast( + "LitestarEncodableType", + _transfer_data( + destination_type=self.transfer_model_type, + source_data=data, + field_definitions=self.parsed_field_definitions, + field_definition=self.field_definition, + is_data_field=self.is_data_field, + ), + ) + + def _get_handler_for_field_definition(self, field_definition: FieldDefinition) -> CompositeTypeHandler | None: + if field_definition.is_union: + return self._create_union_type + + if field_definition.is_tuple: + if len(field_definition.inner_types) == 2 and field_definition.inner_types[1].annotation is Ellipsis: + return self._create_collection_type + return self._create_tuple_type + + if field_definition.is_mapping: + return self._create_mapping_type + + if field_definition.is_non_string_collection: + return self._create_collection_type + return None + + def _create_transfer_type( + self, + field_definition: FieldDefinition, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + field_name: str, + unique_name: str, + nested_depth: int, + ) -> CompositeType | SimpleType: + exclude = _filter_nested_field(exclude, field_name) + include = _filter_nested_field(include, field_name) + rename_fields = _filter_nested_field_mapping(rename_fields, field_name) + + if composite_type_handler := self._get_handler_for_field_definition(field_definition): + return composite_type_handler( + field_definition=field_definition, + exclude=exclude, + include=include, + rename_fields=rename_fields, + unique_name=unique_name, + nested_depth=nested_depth, + ) + + transfer_model: NestedFieldInfo | None = None + + if self.dto_factory.detect_nested_field(field_definition): + if nested_depth == self.dto_factory.config.max_nested_depth: + raise RecursionError + + nested_field_definitions = self.parse_model( + model_type=field_definition.annotation, + exclude=exclude, + include=include, + rename_fields=rename_fields, + nested_depth=nested_depth + 1, + ) + + transfer_model = NestedFieldInfo( + model=self.create_transfer_model_type(unique_name, nested_field_definitions), + field_definitions=nested_field_definitions, + ) + + return SimpleType(field_definition, nested_field_info=transfer_model) + + def _create_collection_type( + self, + field_definition: FieldDefinition, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + unique_name: str, + nested_depth: int, + ) -> CollectionType: + inner_types = field_definition.inner_types + inner_type = self._create_transfer_type( + field_definition=inner_types[0] if inner_types else FieldDefinition.from_annotation(Any), + exclude=exclude, + include=include, + field_name="0", + unique_name=f"{unique_name}_0", + nested_depth=nested_depth, + rename_fields=rename_fields, + ) + return CollectionType( + field_definition=field_definition, inner_type=inner_type, has_nested=inner_type.has_nested + ) + + def _create_mapping_type( + self, + field_definition: FieldDefinition, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + unique_name: str, + nested_depth: int, + ) -> MappingType: + inner_types = field_definition.inner_types + key_type = self._create_transfer_type( + field_definition=inner_types[0] if inner_types else FieldDefinition.from_annotation(Any), + exclude=exclude, + include=include, + field_name="0", + unique_name=f"{unique_name}_0", + nested_depth=nested_depth, + rename_fields=rename_fields, + ) + value_type = self._create_transfer_type( + field_definition=inner_types[1] if inner_types else FieldDefinition.from_annotation(Any), + exclude=exclude, + include=include, + field_name="1", + unique_name=f"{unique_name}_1", + nested_depth=nested_depth, + rename_fields=rename_fields, + ) + return MappingType( + field_definition=field_definition, + key_type=key_type, + value_type=value_type, + has_nested=key_type.has_nested or value_type.has_nested, + ) + + def _create_tuple_type( + self, + field_definition: FieldDefinition, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + unique_name: str, + nested_depth: int, + ) -> TupleType: + inner_types = tuple( + self._create_transfer_type( + field_definition=inner_type, + exclude=exclude, + include=include, + field_name=str(i), + unique_name=f"{unique_name}_{i}", + nested_depth=nested_depth, + rename_fields=rename_fields, + ) + for i, inner_type in enumerate(field_definition.inner_types) + ) + return TupleType( + field_definition=field_definition, + inner_types=inner_types, + has_nested=any(t.has_nested for t in inner_types), + ) + + def _create_union_type( + self, + field_definition: FieldDefinition, + exclude: AbstractSet[str], + include: AbstractSet[str], + rename_fields: dict[str, str], + unique_name: str, + nested_depth: int, + ) -> UnionType: + inner_types = tuple( + self._create_transfer_type( + field_definition=inner_type, + exclude=exclude, + include=include, + field_name=str(i), + unique_name=f"{unique_name}_{i}", + nested_depth=nested_depth, + rename_fields=rename_fields, + ) + for i, inner_type in enumerate(field_definition.inner_types) + ) + return UnionType( + field_definition=field_definition, + inner_types=inner_types, + has_nested=any(t.has_nested for t in inner_types), + ) + + +def _camelize(value: str, capitalize_first_letter: bool) -> str: + return "".join( + word if index == 0 and not capitalize_first_letter else word.capitalize() + for index, word in enumerate(value.split("_")) + ) + + +def _filter_nested_field(field_name_set: AbstractSet[str], field_name: str) -> AbstractSet[str]: + """Filter a nested field name.""" + return {split[1] for s in field_name_set if (split := s.split(".", 1))[0] == field_name and len(split) > 1} + + +def _filter_nested_field_mapping(field_name_mapping: Mapping[str, str], field_name: str) -> dict[str, str]: + """Filter a nested field name.""" + return { + split[1]: v + for s, v in field_name_mapping.items() + if (split := s.split(".", 1))[0] == field_name and len(split) > 1 + } + + +def _transfer_data( + destination_type: type[Any], + source_data: Any | Collection[Any], + field_definitions: tuple[TransferDTOFieldDefinition, ...], + field_definition: FieldDefinition, + is_data_field: bool, +) -> Any: + """Create instance or iterable of instances of ``destination_type``. + + Args: + destination_type: the model type received by the DTO on type narrowing. + source_data: data that has been parsed and validated via the backend. + field_definitions: model field definitions. + field_definition: the parsed type that represents the handler annotation for which the DTO is being applied. + is_data_field: whether the DTO is being applied to a ``data`` field. + + Returns: + Data parsed into ``destination_type``. + """ + if field_definition.is_non_string_collection: + if not field_definition.is_mapping: + return field_definition.instantiable_origin( + _transfer_data( + destination_type=destination_type, + source_data=item, + field_definitions=field_definitions, + field_definition=field_definition.inner_types[0], + is_data_field=is_data_field, + ) + for item in source_data + ) + return field_definition.instantiable_origin( + ( + key, + _transfer_data( + destination_type=destination_type, + source_data=value, + field_definitions=field_definitions, + field_definition=field_definition.inner_types[1], + is_data_field=is_data_field, + ), + ) + for key, value in source_data.items() # type: ignore[union-attr] + ) + + return _transfer_instance_data( + destination_type=destination_type, + source_instance=source_data, + field_definitions=field_definitions, + is_data_field=is_data_field, + ) + + +def _transfer_instance_data( + destination_type: type[Any], + source_instance: Any, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + is_data_field: bool, +) -> Any: + """Create instance of ``destination_type`` with data from ``source_instance``. + + Args: + destination_type: the model type received by the DTO on type narrowing. + source_instance: primitive data that has been parsed and validated via the backend. + field_definitions: model field definitions. + is_data_field: whether the given field is a 'data' kwarg field. + + Returns: + Data parsed into ``model_type``. + """ + unstructured_data = {} + + for field_definition in field_definitions: + if not is_data_field: + if field_definition.is_excluded: + continue + elif not ( + field_definition.name in source_instance + if isinstance(source_instance, Mapping) + else hasattr(source_instance, field_definition.name) + ): + continue + + transfer_type = field_definition.transfer_type + source_value = ( + source_instance[field_definition.name] + if isinstance(source_instance, Mapping) + else getattr(source_instance, field_definition.name) + ) + + if field_definition.is_partial and is_data_field and source_value is UNSET: + continue + + unstructured_data[field_definition.name] = _transfer_type_data( + source_value=source_value, + transfer_type=transfer_type, + nested_as_dict=destination_type is dict, + is_data_field=is_data_field, + ) + + return destination_type(**unstructured_data) + + +def _transfer_type_data( + source_value: Any, + transfer_type: TransferType, + nested_as_dict: bool, + is_data_field: bool, +) -> Any: + if isinstance(transfer_type, SimpleType) and transfer_type.nested_field_info: + if nested_as_dict: + destination_type: Any = dict + elif is_data_field: + destination_type = transfer_type.field_definition.annotation + else: + destination_type = transfer_type.nested_field_info.model + + return _transfer_instance_data( + destination_type=destination_type, + source_instance=source_value, + field_definitions=transfer_type.nested_field_info.field_definitions, + is_data_field=is_data_field, + ) + + if isinstance(transfer_type, UnionType) and transfer_type.has_nested: + return _transfer_nested_union_type_data( + transfer_type=transfer_type, + source_value=source_value, + is_data_field=is_data_field, + ) + + if isinstance(transfer_type, CollectionType): + if transfer_type.has_nested: + return transfer_type.field_definition.instantiable_origin( + _transfer_type_data( + source_value=item, + transfer_type=transfer_type.inner_type, + nested_as_dict=False, + is_data_field=is_data_field, + ) + for item in source_value + ) + + return transfer_type.field_definition.instantiable_origin(source_value) + return source_value + + +def _transfer_nested_union_type_data( + transfer_type: UnionType, + source_value: Any, + is_data_field: bool, +) -> Any: + for inner_type in transfer_type.inner_types: + if isinstance(inner_type, CompositeType): + raise RuntimeError("Composite inner types not (yet) supported for nested unions.") + + if inner_type.nested_field_info and isinstance( + source_value, + inner_type.nested_field_info.model if is_data_field else inner_type.field_definition.annotation, + ): + return _transfer_instance_data( + destination_type=inner_type.field_definition.annotation + if is_data_field + else inner_type.nested_field_info.model, + source_instance=source_value, + field_definitions=inner_type.nested_field_info.field_definitions, + is_data_field=is_data_field, + ) + return source_value + + +def _create_msgspec_field(field_definition: TransferDTOFieldDefinition) -> Any: + kwargs: dict[str, Any] = {} + if field_definition.is_partial: + kwargs["default"] = UNSET + + elif field_definition.default is not Empty: + kwargs["default"] = field_definition.default + + elif field_definition.default_factory is not None: + kwargs["default_factory"] = field_definition.default_factory + + if field_definition.serialization_name is not None: + kwargs["name"] = field_definition.serialization_name + + return field(**kwargs) + + +def _create_struct_field_meta_for_field_definition(field_definition: TransferDTOFieldDefinition) -> msgspec.Meta | None: + if (kwarg_definition := field_definition.kwarg_definition) is None or not isinstance( + kwarg_definition, KwargDefinition + ): + return None + + return msgspec.Meta( + description=kwarg_definition.description, + examples=[e.value for e in kwarg_definition.examples or []], + ge=kwarg_definition.ge, + gt=kwarg_definition.gt, + le=kwarg_definition.le, + lt=kwarg_definition.lt, + max_length=kwarg_definition.max_length if not field_definition.is_partial else None, + min_length=kwarg_definition.min_length if not field_definition.is_partial else None, + multiple_of=kwarg_definition.multiple_of, + pattern=kwarg_definition.pattern, + title=kwarg_definition.title, + ) + + +def _create_struct_for_field_definitions( + model_name: str, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + rename_strategy: RenameStrategy | dict[str, str] | None, +) -> type[Struct]: + struct_fields: list[tuple[str, type] | tuple[str, type, type]] = [] + + for field_definition in field_definitions: + if field_definition.is_excluded: + continue + + field_type = _create_transfer_model_type_annotation(field_definition.transfer_type) + if field_definition.is_partial: + field_type = Union[field_type, UnsetType] + + if (field_meta := _create_struct_field_meta_for_field_definition(field_definition)) is not None: + field_type = Annotated[field_type, field_meta] + + struct_fields.append( + ( + field_definition.name, + field_type, + _create_msgspec_field(field_definition), + ) + ) + return defstruct(model_name, struct_fields, frozen=True, kw_only=True, rename=rename_strategy) + + +def build_annotation_for_backend( + model_type: type[Any], field_definition: FieldDefinition, transfer_model: type[Struct] +) -> Any: + """A helper to re-build a generic outer type with new inner type. + + Args: + model_type: The original model type. + field_definition: The parsed type that represents the handler annotation for which the DTO is being applied. + transfer_model: The transfer model generated to represent the model type. + + Returns: + Annotation with new inner type if applicable. + """ + if not field_definition.inner_types: + if field_definition.is_subclass_of(model_type): + return transfer_model + return field_definition.annotation + + inner_types = tuple( + build_annotation_for_backend(model_type, inner_type, transfer_model) + for inner_type in field_definition.inner_types + ) + + return field_definition.safe_generic_origin[inner_types] + + +def _should_mark_private(field_definition: DTOFieldDefinition, underscore_fields_private: bool) -> bool: + """Returns ``True`` where a field should be marked as private. + + Fields should be marked as private when: + - the ``underscore_fields_private`` flag is set. + - the field is not already marked. + - the field name is prefixed with an underscore + + Args: + field_definition: defined DTO field + underscore_fields_private: whether fields prefixed with an underscore should be marked as private. + """ + return bool( + underscore_fields_private and field_definition.dto_field.mark is None and field_definition.name.startswith("_") + ) + + +def _should_exclude_field( + field_definition: DTOFieldDefinition, exclude: AbstractSet[str], include: AbstractSet[str], is_data_field: bool +) -> bool: + """Returns ``True`` where a field should be excluded from data transfer. + + Args: + field_definition: defined DTO field + exclude: names of fields to exclude + include: names of fields to exclude + is_data_field: whether the field is a data field + + Returns: + ``True`` if the field should not be included in any data transfer. + """ + field_name = field_definition.name + if field_name in exclude: + return True + if include and field_name not in include and not (any(f.startswith(f"{field_name}.") for f in include)): + return True + if field_definition.dto_field.mark is Mark.PRIVATE: + return True + if is_data_field and field_definition.dto_field.mark is Mark.READ_ONLY: + return True + return not is_data_field and field_definition.dto_field.mark is Mark.WRITE_ONLY + + +def _create_transfer_model_type_annotation(transfer_type: TransferType) -> Any: + """Create a type annotation for a transfer model. + + Uses the parsed type that originates from the data model and the transfer model generated to represent a nested + type to reconstruct the type annotation for the transfer model. + """ + if isinstance(transfer_type, SimpleType): + if transfer_type.nested_field_info: + return transfer_type.nested_field_info.model + return transfer_type.field_definition.annotation + + if isinstance(transfer_type, CollectionType): + return _create_transfer_model_collection_type(transfer_type) + + if isinstance(transfer_type, MappingType): + return _create_transfer_model_mapping_type(transfer_type) + + if isinstance(transfer_type, TupleType): + return _create_transfer_model_tuple_type(transfer_type) + + if isinstance(transfer_type, UnionType): + return _create_transfer_model_union_type(transfer_type) + + raise RuntimeError(f"Unexpected transfer type: {type(transfer_type)}") + + +def _create_transfer_model_collection_type(transfer_type: CollectionType) -> Any: + generic_collection_type = transfer_type.field_definition.safe_generic_origin + inner_type = _create_transfer_model_type_annotation(transfer_type.inner_type) + if transfer_type.field_definition.origin is tuple: + return generic_collection_type[inner_type, ...] + return generic_collection_type[inner_type] + + +def _create_transfer_model_tuple_type(transfer_type: TupleType) -> Any: + inner_types = tuple(_create_transfer_model_type_annotation(t) for t in transfer_type.inner_types) + return transfer_type.field_definition.safe_generic_origin[inner_types] + + +def _create_transfer_model_union_type(transfer_type: UnionType) -> Any: + inner_types = tuple(_create_transfer_model_type_annotation(t) for t in transfer_type.inner_types) + return transfer_type.field_definition.safe_generic_origin[inner_types] + + +def _create_transfer_model_mapping_type(transfer_type: MappingType) -> Any: + key_type = _create_transfer_model_type_annotation(transfer_type.key_type) + value_type = _create_transfer_model_type_annotation(transfer_type.value_type) + return transfer_type.field_definition.safe_generic_origin[key_type, value_type] diff --git a/venv/lib/python3.11/site-packages/litestar/dto/_codegen_backend.py b/venv/lib/python3.11/site-packages/litestar/dto/_codegen_backend.py new file mode 100644 index 0000000..deff908 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/_codegen_backend.py @@ -0,0 +1,541 @@ +"""DTO backends do the heavy lifting of decoding and validating raw bytes into domain models, and +back again, to bytes. +""" + +from __future__ import annotations + +import re +import textwrap +from contextlib import contextmanager, nullcontext +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Generator, + Mapping, + Protocol, + cast, +) + +from msgspec import UNSET + +from litestar.dto._backend import DTOBackend +from litestar.dto._types import ( + CollectionType, + CompositeType, + SimpleType, + TransferDTOFieldDefinition, + TransferType, + UnionType, +) +from litestar.utils.helpers import unique_name_for_scope + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.dto import AbstractDTO + from litestar.types.serialization import LitestarEncodableType + from litestar.typing import FieldDefinition + +__all__ = ("DTOCodegenBackend",) + + +class DTOCodegenBackend(DTOBackend): + __slots__ = ( + "_transfer_to_dict", + "_transfer_to_model_type", + "_transfer_data_from_builtins", + "_transfer_data_from_builtins_with_overrides", + "_encode_data", + ) + + def __init__( + self, + dto_factory: type[AbstractDTO], + field_definition: FieldDefinition, + handler_id: str, + is_data_field: bool, + model_type: type[Any], + wrapper_attribute_name: str | None, + ) -> None: + """Create dto backend instance. + + Args: + dto_factory: The DTO factory class calling this backend. + field_definition: Parsed type. + handler_id: The name of the handler that this backend is for. + is_data_field: Whether the field is a subclass of DTOData. + model_type: Model type. + wrapper_attribute_name: If the data that DTO should operate upon is wrapped in a generic datastructure, + this is the name of the attribute that the data is stored in. + """ + super().__init__( + dto_factory=dto_factory, + field_definition=field_definition, + handler_id=handler_id, + is_data_field=is_data_field, + model_type=model_type, + wrapper_attribute_name=wrapper_attribute_name, + ) + self._transfer_to_dict = self._create_transfer_data_fn( + destination_type=dict, + field_definition=self.field_definition, + ) + self._transfer_to_model_type = self._create_transfer_data_fn( + destination_type=self.model_type, + field_definition=self.field_definition, + ) + self._transfer_data_from_builtins = self._create_transfer_data_fn( + destination_type=self.model_type, + field_definition=self.field_definition, + ) + self._transfer_data_from_builtins_with_overrides = self._create_transfer_data_fn( + destination_type=self.model_type, + field_definition=self.field_definition, + ) + self._encode_data = self._create_transfer_data_fn( + destination_type=self.transfer_model_type, + field_definition=self.field_definition, + ) + + def populate_data_from_builtins(self, builtins: Any, asgi_connection: ASGIConnection) -> Any: + """Populate model instance from builtin types. + + Args: + builtins: Builtin type. + asgi_connection: The current ASGI Connection + + Returns: + Instance or collection of ``model_type`` instances. + """ + if self.dto_data_type: + return self.dto_data_type( + backend=self, + data_as_builtins=self._transfer_to_dict(self.parse_builtins(builtins, asgi_connection)), + ) + return self.transfer_data_from_builtins(self.parse_builtins(builtins, asgi_connection)) + + def transfer_data_from_builtins(self, builtins: Any) -> Any: + """Populate model instance from builtin types. + + Args: + builtins: Builtin type. + + Returns: + Instance or collection of ``model_type`` instances. + """ + return self._transfer_data_from_builtins(builtins) + + def populate_data_from_raw(self, raw: bytes, asgi_connection: ASGIConnection) -> Any: + """Parse raw bytes into instance of `model_type`. + + Args: + raw: bytes + asgi_connection: The current ASGI Connection + + Returns: + Instance or collection of ``model_type`` instances. + """ + if self.dto_data_type: + return self.dto_data_type( + backend=self, + data_as_builtins=self._transfer_to_dict(self.parse_raw(raw, asgi_connection)), + ) + return self._transfer_to_model_type(self.parse_raw(raw, asgi_connection)) + + def encode_data(self, data: Any) -> LitestarEncodableType: + """Encode data into a ``LitestarEncodableType``. + + Args: + data: Data to encode. + + Returns: + Encoded data. + """ + if self.wrapper_attribute_name: + wrapped_transfer = self._encode_data(getattr(data, self.wrapper_attribute_name)) + setattr(data, self.wrapper_attribute_name, wrapped_transfer) + return cast("LitestarEncodableType", data) + + return cast("LitestarEncodableType", self._encode_data(data)) + + def _create_transfer_data_fn( + self, + destination_type: type[Any], + field_definition: FieldDefinition, + ) -> Any: + """Create instance or iterable of instances of ``destination_type``. + + Args: + destination_type: the model type received by the DTO on type narrowing. + field_definition: the parsed type that represents the handler annotation for which the DTO is being applied. + + Returns: + Data parsed into ``destination_type``. + """ + + return TransferFunctionFactory.create_transfer_data( + destination_type=destination_type, + field_definitions=self.parsed_field_definitions, + is_data_field=self.is_data_field, + field_definition=field_definition, + ) + + +class FieldAccessManager(Protocol): + def __call__(self, source_name: str, field_name: str, expect_optional: bool) -> ContextManager[str]: ... + + +class TransferFunctionFactory: + def __init__(self, is_data_field: bool, nested_as_dict: bool) -> None: + self.is_data_field = is_data_field + self._fn_locals: dict[str, Any] = { + "Mapping": Mapping, + "UNSET": UNSET, + } + self._indentation = 1 + self._body = "" + self.names: set[str] = set() + self.nested_as_dict = nested_as_dict + self._re_index_access = re.compile(r"\[['\"](\w+?)['\"]]") + + def _add_to_fn_globals(self, name: str, value: Any) -> str: + unique_name = unique_name_for_scope(name, self._fn_locals) + self._fn_locals[unique_name] = value + return unique_name + + def _create_local_name(self, name: str) -> str: + unique_name = unique_name_for_scope(name, self.names) + self.names.add(unique_name) + return unique_name + + def _make_function( + self, source_value_name: str, return_value_name: str, fn_name: str = "func" + ) -> Callable[[Any], Any]: + """Wrap the current body contents in a function definition and turn it into a callable object""" + source = f"def {fn_name}({source_value_name}):\n{self._body} return {return_value_name}" + ctx: dict[str, Any] = {**self._fn_locals} + exec(source, ctx) # noqa: S102 + return ctx["func"] # type: ignore[no-any-return] + + def _add_stmt(self, stmt: str) -> None: + self._body += textwrap.indent(stmt + "\n", " " * self._indentation) + + @contextmanager + def _start_block(self, expr: str | None = None) -> Generator[None, None, None]: + """Start an indented block. If `expr` is given, use it as the "opening line" + of the block. + """ + if expr is not None: + self._add_stmt(expr) + self._indentation += 1 + yield + self._indentation -= 1 + + @contextmanager + def _try_except_pass(self, exception: str) -> Generator[None, None, None]: + """Enter a `try / except / pass` block. Content written while inside this context + will go into the `try` block. + """ + with self._start_block("try:"): + yield + with self._start_block(expr=f"except {exception}:"): + self._add_stmt("pass") + + @contextmanager + def _access_mapping_item( + self, source_name: str, field_name: str, expect_optional: bool + ) -> Generator[str, None, None]: + """Enter a context within which an item of a mapping can be accessed safely, + i.e. only if it is contained within that mapping. + Yields an expression that accesses the mapping item. Content written while + within this context can use this expression to access the desired value. + """ + value_expr = f"{source_name}['{field_name}']" + + # if we expect an optional item, it's faster to check if it exists beforehand + if expect_optional: + with self._start_block(f"if '{field_name}' in {source_name}:"): + yield value_expr + # the happy path of a try/except will be faster than that, so we use that if + # we expect a value + else: + with self._try_except_pass("KeyError"): + yield value_expr + + @contextmanager + def _access_attribute(self, source_name: str, field_name: str, expect_optional: bool) -> Generator[str, None, None]: + """Enter a context within which an attribute of an object can be accessed + safely, i.e. only if the object actually has the attribute. + Yields an expression that retrieves the object attribute. Content written while + within this context can use this expression to access the desired value. + """ + + value_expr = f"{source_name}.{field_name}" + + # if we expect an optional attribute it's faster to check with hasattr + if expect_optional: + with self._start_block(f"if hasattr({source_name}, '{field_name}'):"): + yield value_expr + # the happy path of a try/except will be faster than that, so we use that if + # we expect a value + else: + with self._try_except_pass("AttributeError"): + yield value_expr + + @classmethod + def create_transfer_instance_data( + cls, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + destination_type: type[Any], + is_data_field: bool, + ) -> Callable[[Any], Any]: + factory = cls(is_data_field=is_data_field, nested_as_dict=destination_type is dict) + tmp_return_type_name = factory._create_local_name("tmp_return_type") + source_instance_name = factory._create_local_name("source_instance") + destination_type_name = factory._add_to_fn_globals("destination_type", destination_type) + factory._create_transfer_instance_data( + tmp_return_type_name=tmp_return_type_name, + source_instance_name=source_instance_name, + destination_type_name=destination_type_name, + field_definitions=field_definitions, + destination_type_is_dict=destination_type is dict, + ) + return factory._make_function(source_value_name=source_instance_name, return_value_name=tmp_return_type_name) + + @classmethod + def create_transfer_type_data( + cls, + transfer_type: TransferType, + is_data_field: bool, + ) -> Callable[[Any], Any]: + factory = cls(is_data_field=is_data_field, nested_as_dict=False) + tmp_return_type_name = factory._create_local_name("tmp_return_type") + source_value_name = factory._create_local_name("source_value") + factory._create_transfer_type_data_body( + transfer_type=transfer_type, + nested_as_dict=False, + assignment_target=tmp_return_type_name, + source_value_name=source_value_name, + ) + return factory._make_function(source_value_name=source_value_name, return_value_name=tmp_return_type_name) + + @classmethod + def create_transfer_data( + cls, + destination_type: type[Any], + field_definitions: tuple[TransferDTOFieldDefinition, ...], + is_data_field: bool, + field_definition: FieldDefinition | None = None, + ) -> Callable[[Any], Any]: + if field_definition and field_definition.is_non_string_collection: + factory = cls( + is_data_field=is_data_field, + nested_as_dict=False, + ) + source_value_name = factory._create_local_name("source_value") + return_value_name = factory._create_local_name("tmp_return_value") + factory._create_transfer_data_body_nested( + field_definitions=field_definitions, + field_definition=field_definition, + destination_type=destination_type, + source_data_name=source_value_name, + assignment_target=return_value_name, + ) + return factory._make_function(source_value_name=source_value_name, return_value_name=return_value_name) + + return cls.create_transfer_instance_data( + destination_type=destination_type, + field_definitions=field_definitions, + is_data_field=is_data_field, + ) + + def _create_transfer_data_body_nested( + self, + field_definition: FieldDefinition, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + destination_type: type[Any], + source_data_name: str, + assignment_target: str, + ) -> None: + origin_name = self._add_to_fn_globals("origin", field_definition.instantiable_origin) + transfer_func = TransferFunctionFactory.create_transfer_data( + is_data_field=self.is_data_field, + destination_type=destination_type, + field_definition=field_definition.inner_types[0], + field_definitions=field_definitions, + ) + transfer_func_name = self._add_to_fn_globals("transfer_data", transfer_func) + if field_definition.is_mapping: + self._add_stmt( + f"{assignment_target} = {origin_name}((key, {transfer_func_name}(item)) for key, item in {source_data_name}.items())" + ) + else: + self._add_stmt( + f"{assignment_target} = {origin_name}({transfer_func_name}(item) for item in {source_data_name})" + ) + + def _create_transfer_instance_data( + self, + tmp_return_type_name: str, + source_instance_name: str, + destination_type_name: str, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + destination_type_is_dict: bool, + ) -> None: + local_dict_name = self._create_local_name("unstructured_data") + self._add_stmt(f"{local_dict_name} = {{}}") + + if field_definitions := tuple(f for f in field_definitions if self.is_data_field or not f.is_excluded): + if len(field_definitions) > 1 and ("." in source_instance_name or "[" in source_instance_name): + # If there's more than one field we have to access, we check if it is + # nested. If it is nested, we assign it to a local variable to avoid + # repeated lookups. This is only a small performance improvement for + # regular attributes, but can be quite significant for properties or + # other types of descriptors, where I/O may be involved, such as the + # case for lazy loaded relationships in SQLAlchemy + if "." in source_instance_name: + level_1, level_2 = source_instance_name.split(".", 1) + else: + level_1, level_2, *_ = self._re_index_access.split(source_instance_name, maxsplit=1) + + new_source_instance_name = self._create_local_name(f"{level_1}_{level_2}") + self._add_stmt(f"{new_source_instance_name} = {source_instance_name}") + source_instance_name = new_source_instance_name + + for source_type in ("mapping", "object"): + if source_type == "mapping": + block_expr = f"if isinstance({source_instance_name}, Mapping):" + access_item = self._access_mapping_item + else: + block_expr = "else:" + access_item = self._access_attribute + + with self._start_block(expr=block_expr): + self._create_transfer_instance_data_inner( + local_dict_name=local_dict_name, + field_definitions=field_definitions, + access_field_safe=access_item, + source_instance_name=source_instance_name, + ) + + # if the destination type is a dict we can reuse our temporary dictionary of + # unstructured data as the "return value" + if not destination_type_is_dict: + self._add_stmt(f"{tmp_return_type_name} = {destination_type_name}(**{local_dict_name})") + else: + self._add_stmt(f"{tmp_return_type_name} = {local_dict_name}") + + def _create_transfer_instance_data_inner( + self, + *, + local_dict_name: str, + field_definitions: tuple[TransferDTOFieldDefinition, ...], + access_field_safe: FieldAccessManager, + source_instance_name: str, + ) -> None: + for field_definition in field_definitions: + with access_field_safe( + source_name=source_instance_name, + field_name=field_definition.name, + expect_optional=field_definition.is_partial or field_definition.is_optional, + ) as source_value_expr: + if self.is_data_field and field_definition.is_partial: + # we assign the source value to a name here, so we can skip + # getting it twice from the source instance + source_value_name = self._create_local_name("source_value") + self._add_stmt(f"{source_value_name} = {source_value_expr}") + ctx = self._start_block(f"if {source_value_name} is not UNSET:") + else: + # in these cases, we only ever access the source value once, so + # we can skip assigning it + source_value_name = source_value_expr + ctx = nullcontext() # type: ignore[assignment] + with ctx: + self._create_transfer_type_data_body( + transfer_type=field_definition.transfer_type, + nested_as_dict=self.nested_as_dict, + source_value_name=source_value_name, + assignment_target=f"{local_dict_name}['{field_definition.name}']", + ) + + def _create_transfer_type_data_body( + self, + transfer_type: TransferType, + nested_as_dict: bool, + source_value_name: str, + assignment_target: str, + ) -> None: + if isinstance(transfer_type, SimpleType) and transfer_type.nested_field_info: + if nested_as_dict: + destination_type: Any = dict + elif self.is_data_field: + destination_type = transfer_type.field_definition.annotation + else: + destination_type = transfer_type.nested_field_info.model + + self._create_transfer_instance_data( + field_definitions=transfer_type.nested_field_info.field_definitions, + tmp_return_type_name=assignment_target, + source_instance_name=source_value_name, + destination_type_name=self._add_to_fn_globals("destination_type", destination_type), + destination_type_is_dict=destination_type is dict, + ) + return + + if isinstance(transfer_type, UnionType) and transfer_type.has_nested: + self._create_transfer_nested_union_type_data( + transfer_type=transfer_type, + source_value_name=source_value_name, + assignment_target=assignment_target, + ) + return + + if isinstance(transfer_type, CollectionType): + origin_name = self._add_to_fn_globals("origin", transfer_type.field_definition.instantiable_origin) + if transfer_type.has_nested: + transfer_type_data_fn = TransferFunctionFactory.create_transfer_type_data( + is_data_field=self.is_data_field, transfer_type=transfer_type.inner_type + ) + transfer_type_data_name = self._add_to_fn_globals("transfer_type_data", transfer_type_data_fn) + self._add_stmt( + f"{assignment_target} = {origin_name}({transfer_type_data_name}(item) for item in {source_value_name})" + ) + return + + self._add_stmt(f"{assignment_target} = {origin_name}({source_value_name})") + return + + self._add_stmt(f"{assignment_target} = {source_value_name}") + + def _create_transfer_nested_union_type_data( + self, + transfer_type: UnionType, + source_value_name: str, + assignment_target: str, + ) -> None: + for inner_type in transfer_type.inner_types: + if isinstance(inner_type, CompositeType): + continue + + if inner_type.nested_field_info: + if self.is_data_field: + constraint_type = inner_type.nested_field_info.model + destination_type = inner_type.field_definition.annotation + else: + constraint_type = inner_type.field_definition.annotation + destination_type = inner_type.nested_field_info.model + + constraint_type_name = self._add_to_fn_globals("constraint_type", constraint_type) + destination_type_name = self._add_to_fn_globals("destination_type", destination_type) + + with self._start_block(f"if isinstance({source_value_name}, {constraint_type_name}):"): + self._create_transfer_instance_data( + destination_type_name=destination_type_name, + destination_type_is_dict=destination_type is dict, + field_definitions=inner_type.nested_field_info.field_definitions, + source_instance_name=source_value_name, + tmp_return_type_name=assignment_target, + ) + return + self._add_stmt(f"{assignment_target} = {source_value_name}") diff --git a/venv/lib/python3.11/site-packages/litestar/dto/_types.py b/venv/lib/python3.11/site-packages/litestar/dto/_types.py new file mode 100644 index 0000000..24e99b7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/_types.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.dto.data_structures import DTOFieldDefinition + +if TYPE_CHECKING: + from typing import Any + + from typing_extensions import Self + + from litestar.typing import FieldDefinition + + +@dataclass(frozen=True) +class NestedFieldInfo: + """Type for representing fields and model type of nested model type.""" + + __slots__ = ("model", "field_definitions") + + model: type[Any] + field_definitions: tuple[TransferDTOFieldDefinition, ...] + + +@dataclass(frozen=True) +class TransferType: + """Type for representing model types for data transfer.""" + + __slots__ = ("field_definition",) + + field_definition: FieldDefinition + + +@dataclass(frozen=True) +class SimpleType(TransferType): + """Represents indivisible, non-composite types.""" + + __slots__ = ("nested_field_info",) + + nested_field_info: NestedFieldInfo | None + """If the type is a 'nested' type, this is the model generated for transfer to/from it.""" + + @property + def has_nested(self) -> bool: + return self.nested_field_info is not None + + +@dataclass(frozen=True) +class CompositeType(TransferType): + """A type that is made up of other types.""" + + __slots__ = ("has_nested",) + + has_nested: bool + """Whether the type represents nested model types within itself.""" + + +@dataclass(frozen=True) +class UnionType(CompositeType): + """Type for representing union types for data transfer.""" + + __slots__ = ("inner_types",) + + inner_types: tuple[CompositeType | SimpleType, ...] + + +@dataclass(frozen=True) +class CollectionType(CompositeType): + """Type for representing collection types for data transfer.""" + + __slots__ = ("inner_type",) + + inner_type: CompositeType | SimpleType + + +@dataclass(frozen=True) +class TupleType(CompositeType): + """Type for representing tuples for data transfer.""" + + __slots__ = ("inner_types",) + + inner_types: tuple[CompositeType | SimpleType, ...] + + +@dataclass(frozen=True) +class MappingType(CompositeType): + """Type for representing mappings for data transfer.""" + + __slots__ = ("key_type", "value_type") + + key_type: CompositeType | SimpleType + value_type: CompositeType | SimpleType + + +@dataclass(frozen=True) +class TransferDTOFieldDefinition(DTOFieldDefinition): + __slots__ = ( + "default_factory", + "dto_field", + "model_name", + "is_excluded", + "is_partial", + "serialization_name", + "transfer_type", + "unique_name", + ) + + transfer_type: TransferType + """Type of the field for transfer.""" + serialization_name: str | None + """Name of the field as it should appear in serialized form.""" + is_partial: bool + """Whether the field is optional for transfer.""" + is_excluded: bool + """Whether the field should be excluded from transfer.""" + + @classmethod + def from_dto_field_definition( + cls, + field_definition: DTOFieldDefinition, + transfer_type: TransferType, + serialization_name: str | None, + is_partial: bool, + is_excluded: bool, + ) -> Self: + return cls( + annotation=field_definition.annotation, + args=field_definition.args, + default=field_definition.default, + default_factory=field_definition.default_factory, + dto_field=field_definition.dto_field, + extra=field_definition.extra, + inner_types=field_definition.inner_types, + instantiable_origin=field_definition.instantiable_origin, + is_excluded=is_excluded, + is_partial=is_partial, + kwarg_definition=field_definition.kwarg_definition, + metadata=field_definition.metadata, + name=field_definition.name, + origin=field_definition.origin, + raw=field_definition.raw, + safe_generic_origin=field_definition.safe_generic_origin, + serialization_name=serialization_name, + transfer_type=transfer_type, + type_wrappers=field_definition.type_wrappers, + model_name=field_definition.model_name, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/dto/base_dto.py b/venv/lib/python3.11/site-packages/litestar/dto/base_dto.py new file mode 100644 index 0000000..991b09f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/base_dto.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import typing +from abc import abstractmethod +from inspect import getmodule +from typing import TYPE_CHECKING, Collection, Generic, TypeVar + +from typing_extensions import NotRequired, TypedDict, get_type_hints + +from litestar.dto._backend import DTOBackend +from litestar.dto._codegen_backend import DTOCodegenBackend +from litestar.dto.config import DTOConfig +from litestar.dto.data_structures import DTOData +from litestar.dto.types import RenameStrategy +from litestar.enums import RequestEncodingType +from litestar.exceptions.dto_exceptions import InvalidAnnotationException +from litestar.types.builtin_types import NoneType +from litestar.types.composite_types import TypeEncodersMap +from litestar.typing import FieldDefinition + +if TYPE_CHECKING: + from typing import Any, ClassVar, Generator + + from typing_extensions import Self + + from litestar._openapi.schema_generation import SchemaCreator + from litestar.connection import ASGIConnection + from litestar.dto.data_structures import DTOFieldDefinition + from litestar.openapi.spec import Reference, Schema + from litestar.types.serialization import LitestarEncodableType + +__all__ = ("AbstractDTO",) + +T = TypeVar("T") + + +class _BackendDict(TypedDict): + data_backend: NotRequired[DTOBackend] + return_backend: NotRequired[DTOBackend] + + +class AbstractDTO(Generic[T]): + """Base class for DTO types.""" + + __slots__ = ("asgi_connection",) + + config: ClassVar[DTOConfig] + """Config objects to define properties of the DTO.""" + model_type: type[T] + """If ``annotation`` is an iterable, this is the inner type, otherwise will be the same as ``annotation``.""" + + _dto_backends: ClassVar[dict[str, _BackendDict]] = {} + + def __init__(self, asgi_connection: ASGIConnection) -> None: + """Create an AbstractDTOFactory type. + + Args: + asgi_connection: A :class:`ASGIConnection <litestar.connection.base.ASGIConnection>` instance. + """ + self.asgi_connection = asgi_connection + + def __class_getitem__(cls, annotation: Any) -> type[Self]: + field_definition = FieldDefinition.from_annotation(annotation) + + if (field_definition.is_optional and len(field_definition.args) > 2) or ( + field_definition.is_union and not field_definition.is_optional + ): + raise InvalidAnnotationException("Unions are currently not supported as type argument to DTOs.") + + if field_definition.is_forward_ref: + raise InvalidAnnotationException("Forward references are not supported as type argument to DTO") + + # if a configuration is not provided, and the type narrowing is a type var, we don't want to create a subclass + config = cls.get_dto_config_from_annotated_type(field_definition) + + if not config: + if field_definition.is_type_var: + return cls + config = cls.config if hasattr(cls, "config") else DTOConfig() + + cls_dict: dict[str, Any] = {"config": config, "_type_backend_map": {}, "_handler_backend_map": {}} + if not field_definition.is_type_var: + cls_dict.update(model_type=field_definition.annotation) + + return type(f"{cls.__name__}[{annotation}]", (cls,), cls_dict) # pyright: ignore + + def decode_builtins(self, value: dict[str, Any]) -> Any: + """Decode a dictionary of Python values into an the DTO's datatype.""" + + backend = self._dto_backends[self.asgi_connection.route_handler.handler_id]["data_backend"] # pyright: ignore + return backend.populate_data_from_builtins(value, self.asgi_connection) + + def decode_bytes(self, value: bytes) -> Any: + """Decode a byte string into an the DTO's datatype.""" + + backend = self._dto_backends[self.asgi_connection.route_handler.handler_id]["data_backend"] # pyright: ignore + return backend.populate_data_from_raw(value, self.asgi_connection) + + def data_to_encodable_type(self, data: T | Collection[T]) -> LitestarEncodableType: + backend = self._dto_backends[self.asgi_connection.route_handler.handler_id]["return_backend"] # pyright: ignore + return backend.encode_data(data) + + @classmethod + @abstractmethod + def generate_field_definitions(cls, model_type: type[Any]) -> Generator[DTOFieldDefinition, None, None]: + """Generate ``FieldDefinition`` instances from ``model_type``. + + Yields: + ``FieldDefinition`` instances. + """ + + @classmethod + @abstractmethod + def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: + """Return ``True`` if ``field_definition`` represents a nested model field. + + Args: + field_definition: inspect type to determine if field represents a nested model. + + Returns: + ``True`` if ``field_definition`` represents a nested model field. + """ + + @classmethod + def is_supported_model_type_field(cls, field_definition: FieldDefinition) -> bool: + """Check support for the given type. + + Args: + field_definition: A :class:`FieldDefinition <litestar.typing.FieldDefinition>` instance. + + Returns: + Whether the type of the field definition is supported by the DTO. + """ + return field_definition.is_subclass_of(cls.model_type) or ( + field_definition.origin + and any( + cls.resolve_model_type(inner_field).is_subclass_of(cls.model_type) + for inner_field in field_definition.inner_types + ) + ) + + @classmethod + def create_for_field_definition( + cls, + field_definition: FieldDefinition, + handler_id: str, + backend_cls: type[DTOBackend] | None = None, + ) -> None: + """Creates a DTO subclass for a field definition. + + Args: + field_definition: A :class:`FieldDefinition <litestar.typing.FieldDefinition>` instance. + handler_id: ID of the route handler for which to create a DTO instance. + backend_cls: Alternative DTO backend class to use + + Returns: + None + """ + + if handler_id not in cls._dto_backends: + cls._dto_backends[handler_id] = {} + + backend_context = cls._dto_backends[handler_id] + key = "data_backend" if field_definition.name == "data" else "return_backend" + + if key not in backend_context: + model_type_field_definition = cls.resolve_model_type(field_definition=field_definition) + wrapper_attribute_name: str | None = None + + if not model_type_field_definition.is_subclass_of(cls.model_type): + if resolved_generic_result := cls.resolve_generic_wrapper_type( + field_definition=model_type_field_definition + ): + model_type_field_definition, field_definition, wrapper_attribute_name = resolved_generic_result + else: + raise InvalidAnnotationException( + f"DTO narrowed with '{cls.model_type}', handler type is '{field_definition.annotation}'" + ) + + if backend_cls is None: + backend_cls = DTOCodegenBackend if cls.config.experimental_codegen_backend else DTOBackend + elif backend_cls is DTOCodegenBackend and cls.config.experimental_codegen_backend is False: + backend_cls = DTOBackend + + backend_context[key] = backend_cls( # type: ignore[literal-required] + dto_factory=cls, + field_definition=field_definition, + model_type=model_type_field_definition.annotation, + wrapper_attribute_name=wrapper_attribute_name, + is_data_field=field_definition.name == "data", + handler_id=handler_id, + ) + + @classmethod + def create_openapi_schema( + cls, field_definition: FieldDefinition, handler_id: str, schema_creator: SchemaCreator + ) -> Reference | Schema: + """Create an OpenAPI request body. + + Returns: + OpenAPI request body. + """ + key = "data_backend" if field_definition.name == "data" else "return_backend" + backend = cls._dto_backends[handler_id][key] # type: ignore[literal-required] + return schema_creator.for_field_definition(FieldDefinition.from_annotation(backend.annotation)) + + @classmethod + def resolve_generic_wrapper_type( + cls, field_definition: FieldDefinition + ) -> tuple[FieldDefinition, FieldDefinition, str] | None: + """Handle where DTO supported data is wrapped in a generic container type. + + Args: + field_definition: A parsed type annotation that represents the annotation used to narrow the DTO type. + + Returns: + The data model type. + """ + if field_definition.origin and ( + inner_fields := [ + inner_field + for inner_field in field_definition.inner_types + if cls.resolve_model_type(inner_field).is_subclass_of(cls.model_type) + ] + ): + inner_field = inner_fields[0] + model_field_definition = cls.resolve_model_type(inner_field) + + for attr, attr_type in cls.get_model_type_hints(field_definition.origin).items(): + if isinstance(attr_type.annotation, TypeVar) or any( + isinstance(t.annotation, TypeVar) for t in attr_type.inner_types + ): + if attr_type.is_non_string_collection: + # the inner type of the collection type is the type var, so we need to specialize the + # collection type with the DTO supported type. + specialized_annotation = attr_type.safe_generic_origin[model_field_definition.annotation] + return model_field_definition, FieldDefinition.from_annotation(specialized_annotation), attr + return model_field_definition, inner_field, attr + return None + + @staticmethod + def get_model_type_hints( + model_type: type[Any], namespace: dict[str, Any] | None = None + ) -> dict[str, FieldDefinition]: + """Retrieve type annotations for ``model_type``. + + Args: + model_type: Any type-annotated class. + namespace: Optional namespace to use for resolving type hints. + + Returns: + Parsed type hints for ``model_type`` resolved within the scope of its module. + """ + namespace = namespace or {} + namespace.update(vars(typing)) + namespace.update( + { + "TypeEncodersMap": TypeEncodersMap, + "DTOConfig": DTOConfig, + "RenameStrategy": RenameStrategy, + "RequestEncodingType": RequestEncodingType, + } + ) + + if model_module := getmodule(model_type): + namespace.update(vars(model_module)) + + return { + k: FieldDefinition.from_kwarg(annotation=v, name=k) + for k, v in get_type_hints(model_type, localns=namespace, include_extras=True).items() # pyright: ignore + } + + @staticmethod + def get_dto_config_from_annotated_type(field_definition: FieldDefinition) -> DTOConfig | None: + """Extract data type and config instances from ``Annotated`` annotation. + + Args: + field_definition: A parsed type annotation that represents the annotation used to narrow the DTO type. + + Returns: + The type and config object extracted from the annotation. + """ + return next((item for item in field_definition.metadata if isinstance(item, DTOConfig)), None) + + @classmethod + def resolve_model_type(cls, field_definition: FieldDefinition) -> FieldDefinition: + """Resolve the data model type from a parsed type. + + Args: + field_definition: A parsed type annotation that represents the annotation used to narrow the DTO type. + + Returns: + A :class:`FieldDefinition <.typing.FieldDefinition>` that represents the data model type. + """ + if field_definition.is_optional: + return cls.resolve_model_type( + next(t for t in field_definition.inner_types if not t.is_subclass_of(NoneType)) + ) + + if field_definition.is_subclass_of(DTOData): + return cls.resolve_model_type(field_definition.inner_types[0]) + + if field_definition.is_collection: + if field_definition.is_mapping: + return cls.resolve_model_type(field_definition.inner_types[1]) + + if field_definition.is_tuple: + if any(t is Ellipsis for t in field_definition.args): + return cls.resolve_model_type(field_definition.inner_types[0]) + elif field_definition.is_non_string_collection: + return cls.resolve_model_type(field_definition.inner_types[0]) + + return field_definition diff --git a/venv/lib/python3.11/site-packages/litestar/dto/config.py b/venv/lib/python3.11/site-packages/litestar/dto/config.py new file mode 100644 index 0000000..e213d17 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/config.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from typing import AbstractSet + + from litestar.dto.types import RenameStrategy + +__all__ = ("DTOConfig",) + + +@dataclass(frozen=True) +class DTOConfig: + """Control the generated DTO.""" + + exclude: AbstractSet[str] = field(default_factory=set) + """Explicitly exclude fields from the generated DTO. + + If exclude is specified, all fields not specified in exclude will be included by default. + + Notes: + - The field names are dot-separated paths to nested fields, e.g. ``"address.street"`` will + exclude the ``"street"`` field from a nested ``"address"`` model. + - 'exclude' mutually exclusive with 'include' - specifying both values will raise an + ``ImproperlyConfiguredException``. + """ + include: AbstractSet[str] = field(default_factory=set) + """Explicitly include fields in the generated DTO. + + If include is specified, all fields not specified in include will be excluded by default. + + Notes: + - The field names are dot-separated paths to nested fields, e.g. ``"address.street"`` will + include the ``"street"`` field from a nested ``"address"`` model. + - 'include' mutually exclusive with 'exclude' - specifying both values will raise an + ``ImproperlyConfiguredException``. + """ + rename_fields: dict[str, str] = field(default_factory=dict) + """Mapping of field names, to new name.""" + rename_strategy: RenameStrategy | None = None + """Rename all fields using a pre-defined strategy or a custom strategy. + + The pre-defined strategies are: `upper`, `lower`, `camel`, `pascal`. + + A custom strategy is any callable that accepts a string as an argument and + return a string. + + Fields defined in ``rename_fields`` are ignored.""" + max_nested_depth: int = 1 + """The maximum depth of nested items allowed for data transfer.""" + partial: bool = False + """Allow transfer of partial data.""" + underscore_fields_private: bool = True + """Fields starting with an underscore are considered private and excluded from data transfer.""" + experimental_codegen_backend: bool | None = None + """Use the experimental codegen backend""" + + def __post_init__(self) -> None: + if self.include and self.exclude: + raise ImproperlyConfiguredException( + "'include' and 'exclude' are mutually exclusive options, please use one of them" + ) diff --git a/venv/lib/python3.11/site-packages/litestar/dto/data_structures.py b/venv/lib/python3.11/site-packages/litestar/dto/data_structures.py new file mode 100644 index 0000000..a5c3386 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/data_structures.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from litestar.typing import FieldDefinition + +if TYPE_CHECKING: + from typing import Any, Callable + + from litestar.dto import DTOField + from litestar.dto._backend import DTOBackend + +T = TypeVar("T") + + +class DTOData(Generic[T]): + """DTO validated data and utility methods.""" + + __slots__ = ("_backend", "_data_as_builtins") + + def __init__(self, backend: DTOBackend, data_as_builtins: Any) -> None: + self._backend = backend + self._data_as_builtins = data_as_builtins + + def create_instance(self, **kwargs: Any) -> T: + """Create an instance of the DTO validated data. + + Args: + **kwargs: Additional data to create the instance with. Takes precedence over DTO validated data. + """ + data = dict(self._data_as_builtins) + for k, v in kwargs.items(): + _set_nested_dict_value(data, k.split("__"), v) + return self._backend.transfer_data_from_builtins(data) # type: ignore[no-any-return] + + def update_instance(self, instance: T, **kwargs: Any) -> T: + """Update an instance with the DTO validated data. + + Args: + instance: The instance to update. + **kwargs: Additional data to update the instance with. Takes precedence over DTO validated data. + """ + data = {**self._data_as_builtins, **kwargs} + for k, v in data.items(): + setattr(instance, k, v) + return instance + + def as_builtins(self) -> Any: + """Return the DTO validated data as builtins.""" + return self._data_as_builtins + + +def _set_nested_dict_value(d: dict[str, Any], keys: list[str], value: Any) -> None: + if len(keys) == 1: + d[keys[0]] = value + else: + key = keys[0] + d.setdefault(key, {}) + _set_nested_dict_value(d[key], keys[1:], value) + + +@dataclass(frozen=True) +class DTOFieldDefinition(FieldDefinition): + """A model field representation for purposes of generating a DTO backend model type.""" + + __slots__ = ( + "default_factory", + "dto_field", + "model_name", + ) + + model_name: str + """The name of the model for which the field is generated.""" + default_factory: Callable[[], Any] | None + """Default factory of the field.""" + dto_field: DTOField + """DTO field configuration.""" + + @classmethod + def from_field_definition( + cls, + field_definition: FieldDefinition, + model_name: str, + default_factory: Callable[[], Any] | None, + dto_field: DTOField, + ) -> DTOFieldDefinition: + """Create a :class:`FieldDefinition` from a :class:`FieldDefinition`. + + Args: + field_definition: A :class:`FieldDefinition` to create a :class:`FieldDefinition` from. + model_name: The name of the model. + default_factory: Default factory function, if any. + dto_field: DTOField instance. + + Returns: + A :class:`FieldDefinition` instance. + """ + return DTOFieldDefinition( + annotation=field_definition.annotation, + args=field_definition.args, + default=field_definition.default, + default_factory=default_factory, + dto_field=dto_field, + extra=field_definition.extra, + inner_types=field_definition.inner_types, + instantiable_origin=field_definition.instantiable_origin, + kwarg_definition=field_definition.kwarg_definition, + metadata=field_definition.metadata, + model_name=model_name, + name=field_definition.name, + origin=field_definition.origin, + raw=field_definition.raw, + safe_generic_origin=field_definition.safe_generic_origin, + type_wrappers=field_definition.type_wrappers, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py b/venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py new file mode 100644 index 0000000..554b0f3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import MISSING, fields, replace +from typing import TYPE_CHECKING, Generic, TypeVar + +from litestar.dto.base_dto import AbstractDTO +from litestar.dto.data_structures import DTOFieldDefinition +from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.params import DependencyKwarg, KwargDefinition +from litestar.types.empty import Empty + +if TYPE_CHECKING: + from typing import Collection, Generator + + from litestar.types.protocols import DataclassProtocol + from litestar.typing import FieldDefinition + + +__all__ = ("DataclassDTO", "T") + +T = TypeVar("T", bound="DataclassProtocol | Collection[DataclassProtocol]") +AnyDataclass = TypeVar("AnyDataclass", bound="DataclassProtocol") + + +class DataclassDTO(AbstractDTO[T], Generic[T]): + """Support for domain modelling with dataclasses.""" + + @classmethod + def generate_field_definitions( + cls, model_type: type[DataclassProtocol] + ) -> Generator[DTOFieldDefinition, None, None]: + dc_fields = {f.name: f for f in fields(model_type)} + for key, field_definition in cls.get_model_type_hints(model_type).items(): + if not (dc_field := dc_fields.get(key)): + continue + + default = dc_field.default if dc_field.default is not MISSING else Empty + default_factory = dc_field.default_factory if dc_field.default_factory is not MISSING else None + field_defintion = replace( + DTOFieldDefinition.from_field_definition( + field_definition=field_definition, + default_factory=default_factory, + dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()), + model_name=model_type.__name__, + ), + name=key, + default=default, + ) + + yield ( + replace(field_defintion, default=Empty, kwarg_definition=default) + if isinstance(default, (KwargDefinition, DependencyKwarg)) + else field_defintion + ) + + @classmethod + def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: + return hasattr(field_definition.annotation, "__dataclass_fields__") diff --git a/venv/lib/python3.11/site-packages/litestar/dto/field.py b/venv/lib/python3.11/site-packages/litestar/dto/field.py new file mode 100644 index 0000000..7ef8a39 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/field.py @@ -0,0 +1,50 @@ +"""DTO domain types.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Literal + +__all__ = ( + "DTO_FIELD_META_KEY", + "DTOField", + "Mark", + "dto_field", +) + +DTO_FIELD_META_KEY = "__dto__" + + +class Mark(str, Enum): + """For marking field definitions on domain models.""" + + READ_ONLY = "read-only" + """To mark a field that can be read, but not updated by clients.""" + WRITE_ONLY = "write-only" + """To mark a field that can be written to, but not read by clients.""" + PRIVATE = "private" + """To mark a field that can neither be read or updated by clients.""" + + +@dataclass +class DTOField: + """For configuring DTO behavior on model fields.""" + + mark: Mark | Literal["read-only", "write-only", "private"] | None = None + """Mark the field as read-only, or private.""" + + +def dto_field(mark: Literal["read-only", "write-only", "private"] | Mark) -> dict[str, DTOField]: + """Create a field metadata mapping. + + Args: + mark: A DTO mark for the field, e.g., "read-only". + + Returns: + A dict for setting as field metadata, such as the dataclass "metadata" field key, or the SQLAlchemy "info" + field. + + Marking a field automates its inclusion/exclusion from DTO field definitions, depending on the DTO's purpose. + """ + return {DTO_FIELD_META_KEY: DTOField(mark=Mark(mark))} diff --git a/venv/lib/python3.11/site-packages/litestar/dto/msgspec_dto.py b/venv/lib/python3.11/site-packages/litestar/dto/msgspec_dto.py new file mode 100644 index 0000000..826a1d2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/msgspec_dto.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Generic, TypeVar + +from msgspec import NODEFAULT, Struct, structs + +from litestar.dto.base_dto import AbstractDTO +from litestar.dto.data_structures import DTOFieldDefinition +from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.types.empty import Empty + +if TYPE_CHECKING: + from typing import Any, Collection, Generator + + from litestar.typing import FieldDefinition + + +__all__ = ("MsgspecDTO",) + +T = TypeVar("T", bound="Struct | Collection[Struct]") + + +class MsgspecDTO(AbstractDTO[T], Generic[T]): + """Support for domain modelling with Msgspec.""" + + @classmethod + def generate_field_definitions(cls, model_type: type[Struct]) -> Generator[DTOFieldDefinition, None, None]: + msgspec_fields = {f.name: f for f in structs.fields(model_type)} + + def default_or_empty(value: Any) -> Any: + return Empty if value is NODEFAULT else value + + def default_or_none(value: Any) -> Any: + return None if value is NODEFAULT else value + + for key, field_definition in cls.get_model_type_hints(model_type).items(): + msgspec_field = msgspec_fields[key] + dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField()) + + yield replace( + DTOFieldDefinition.from_field_definition( + field_definition=field_definition, + dto_field=dto_field, + model_name=model_type.__name__, + default_factory=default_or_none(msgspec_field.default_factory), + ), + default=default_or_empty(msgspec_field.default), + name=key, + ) + + @classmethod + def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: + return field_definition.is_subclass_of(Struct) diff --git a/venv/lib/python3.11/site-packages/litestar/dto/types.py b/venv/lib/python3.11/site-packages/litestar/dto/types.py new file mode 100644 index 0000000..f154e49 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/types.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, Literal + + from typing_extensions import TypeAlias + +__all__ = ("RenameStrategy",) + +RenameStrategy: TypeAlias = 'Literal["lower", "upper", "camel", "pascal", "kebab"] | Callable[[str], str]' +"""A pre-defined strategy or a custom callback for converting DTO field names.""" diff --git a/venv/lib/python3.11/site-packages/litestar/enums.py b/venv/lib/python3.11/site-packages/litestar/enums.py new file mode 100644 index 0000000..a660228 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/enums.py @@ -0,0 +1,90 @@ +from enum import Enum + +__all__ = ( + "CompressionEncoding", + "HttpMethod", + "MediaType", + "OpenAPIMediaType", + "ParamType", + "RequestEncodingType", + "ScopeType", +) + + +class HttpMethod(str, Enum): + """An Enum for HTTP methods.""" + + DELETE = "DELETE" + GET = "GET" + HEAD = "HEAD" + OPTIONS = "OPTIONS" + PATCH = "PATCH" + POST = "POST" + PUT = "PUT" + + +class MediaType(str, Enum): + """An Enum for ``Content-Type`` header values.""" + + JSON = "application/json" + MESSAGEPACK = "application/x-msgpack" + HTML = "text/html" + TEXT = "text/plain" + CSS = "text/css" + XML = "application/xml" + + +class OpenAPIMediaType(str, Enum): + """An Enum for OpenAPI specific response ``Content-Type`` header values.""" + + OPENAPI_YAML = "application/vnd.oai.openapi" + OPENAPI_JSON = "application/vnd.oai.openapi+json" + + +class RequestEncodingType(str, Enum): + """An Enum for request ``Content-Type`` header values designating encoding formats.""" + + JSON = "application/json" + MESSAGEPACK = "application/x-msgpack" + MULTI_PART = "multipart/form-data" + URL_ENCODED = "application/x-www-form-urlencoded" + + +class ScopeType(str, Enum): + """An Enum for the 'http' key stored under Scope. + + Notes: + - ``asgi`` is used by Litestar internally and is not part of the specification. + """ + + HTTP = "http" + WEBSOCKET = "websocket" + ASGI = "asgi" + + +class ParamType(str, Enum): + """An Enum for the types of parameters a request can receive.""" + + PATH = "path" + QUERY = "query" + COOKIE = "cookie" + HEADER = "header" + + +class CompressionEncoding(str, Enum): + """An Enum for supported compression encodings.""" + + GZIP = "gzip" + BROTLI = "br" + + +class ASGIExtension(str, Enum): + """ASGI extension keys: https://asgi.readthedocs.io/en/latest/extensions.html""" + + WS_DENIAL = "websocket.http.response" + SERVER_PUSH = "http.response.push" + ZERO_COPY_SEND_EXTENSION = "http.response.zerocopysend" + PATH_SEND = "http.response.pathsend" + TLS = "tls" + EARLY_HINTS = "http.response.early_hint" + HTTP_TRAILERS = "http.response.trailers" diff --git a/venv/lib/python3.11/site-packages/litestar/events/__init__.py b/venv/lib/python3.11/site-packages/litestar/events/__init__.py new file mode 100644 index 0000000..a291141 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/events/__init__.py @@ -0,0 +1,4 @@ +from .emitter import BaseEventEmitterBackend, SimpleEventEmitter +from .listener import EventListener, listener + +__all__ = ("EventListener", "SimpleEventEmitter", "BaseEventEmitterBackend", "listener") diff --git a/venv/lib/python3.11/site-packages/litestar/events/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/events/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bc8ee50 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/events/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/events/__pycache__/emitter.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/events/__pycache__/emitter.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f59ec24 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/events/__pycache__/emitter.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/events/__pycache__/listener.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/events/__pycache__/listener.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..666bcbc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/events/__pycache__/listener.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/events/emitter.py b/venv/lib/python3.11/site-packages/litestar/events/emitter.py new file mode 100644 index 0000000..7c33c9e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/events/emitter.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import math +import sys +from abc import ABC, abstractmethod +from collections import defaultdict +from contextlib import AsyncExitStack +from functools import partial +from typing import TYPE_CHECKING, Any, Sequence + +if sys.version_info < (3, 9): + from typing import AsyncContextManager +else: + from contextlib import AbstractAsyncContextManager as AsyncContextManager + +import anyio + +from litestar.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from types import TracebackType + + from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + + from litestar.events.listener import EventListener + +__all__ = ("BaseEventEmitterBackend", "SimpleEventEmitter") + + +class BaseEventEmitterBackend(AsyncContextManager["BaseEventEmitterBackend"], ABC): + """Abstract class used to define event emitter backends.""" + + __slots__ = ("listeners",) + + listeners: defaultdict[str, set[EventListener]] + + def __init__(self, listeners: Sequence[EventListener]) -> None: + """Create an event emitter instance. + + Args: + listeners: A list of listeners. + """ + self.listeners = defaultdict(set) + for listener in listeners: + for event_id in listener.event_ids: + self.listeners[event_id].add(listener) + + @abstractmethod + def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None: + """Emit an event to all attached listeners. + + Args: + event_id: The ID of the event to emit, e.g 'my_event'. + *args: args to pass to the listener(s). + **kwargs: kwargs to pass to the listener(s) + + Returns: + None + """ + raise NotImplementedError("not implemented") + + +class SimpleEventEmitter(BaseEventEmitterBackend): + """Event emitter the works only in the current process""" + + __slots__ = ("_queue", "_exit_stack", "_receive_stream", "_send_stream") + + def __init__(self, listeners: Sequence[EventListener]) -> None: + """Create an event emitter instance. + + Args: + listeners: A list of listeners. + """ + super().__init__(listeners=listeners) + self._receive_stream: MemoryObjectReceiveStream | None = None + self._send_stream: MemoryObjectSendStream | None = None + self._exit_stack: AsyncExitStack | None = None + + async def _worker(self, receive_stream: MemoryObjectReceiveStream) -> None: + """Run items from ``receive_stream`` in a task group. + + Returns: + None + """ + async with receive_stream, anyio.create_task_group() as task_group: + async for item in receive_stream: + fn, args, kwargs = item + if kwargs: + fn = partial(fn, **kwargs) + task_group.start_soon(fn, *args) # pyright: ignore[reportGeneralTypeIssues] + + async def __aenter__(self) -> SimpleEventEmitter: + self._exit_stack = AsyncExitStack() + send_stream, receive_stream = anyio.create_memory_object_stream(math.inf) # type: ignore[var-annotated] + self._send_stream = send_stream + task_group = anyio.create_task_group() + + await self._exit_stack.enter_async_context(task_group) + await self._exit_stack.enter_async_context(send_stream) + task_group.start_soon(self._worker, receive_stream) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self._exit_stack: + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + self._exit_stack = None + self._send_stream = None + + def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None: + """Emit an event to all attached listeners. + + Args: + event_id: The ID of the event to emit, e.g 'my_event'. + *args: args to pass to the listener(s). + **kwargs: kwargs to pass to the listener(s) + + Returns: + None + """ + if not (self._send_stream and self._exit_stack): + raise RuntimeError("Emitter not initialized") + + if listeners := self.listeners.get(event_id): + for listener in listeners: + self._send_stream.send_nowait((listener.fn, args, kwargs)) + return + raise ImproperlyConfiguredException(f"no event listeners are registered for event ID: {event_id}") diff --git a/venv/lib/python3.11/site-packages/litestar/events/listener.py b/venv/lib/python3.11/site-packages/litestar/events/listener.py new file mode 100644 index 0000000..63c9848 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/events/listener.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.utils import ensure_async_callable + +if TYPE_CHECKING: + from litestar.types import AnyCallable, AsyncAnyCallable + +__all__ = ("EventListener", "listener") + +logger = logging.getLogger(__name__) + + +class EventListener: + """Decorator for event listeners""" + + __slots__ = ("event_ids", "fn", "listener_id") + + fn: AsyncAnyCallable + + def __init__(self, *event_ids: str) -> None: + """Create a decorator for event handlers. + + Args: + *event_ids: The id of the event to listen to or a list of + event ids to listen to. + """ + self.event_ids: frozenset[str] = frozenset(event_ids) + + def __call__(self, fn: AnyCallable) -> EventListener: + """Decorate a callable by wrapping it inside an instance of EventListener. + + Args: + fn: Callable to decorate. + + Returns: + An instance of EventListener + """ + if not callable(fn): + raise ImproperlyConfiguredException("EventListener instance should be called as a decorator on a callable") + + self.fn = self.wrap_in_error_handler(ensure_async_callable(fn)) + + return self + + @staticmethod + def wrap_in_error_handler(fn: AsyncAnyCallable) -> AsyncAnyCallable: + """Wrap a listener function to handle errors. + + Listeners are executed concurrently in a TaskGroup, so we need to ensure that exceptions do not propagate + to the task group which results in any other unfinished listeners to be cancelled, and the receive stream to + be closed. + + See https://github.com/litestar-org/litestar/issues/2809 + + Args: + fn: The listener function to wrap. + """ + + async def wrapped(*args: Any, **kwargs: Any) -> None: + """Wrap a listener function to handle errors.""" + try: + await fn(*args, **kwargs) + except Exception as exc: + logger.exception("Error while executing listener %s: %s", fn.__name__, exc) + + return wrapped + + def __hash__(self) -> int: + return hash(self.event_ids) + hash(self.fn) + + +listener = EventListener diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/__init__.py b/venv/lib/python3.11/site-packages/litestar/exceptions/__init__.py new file mode 100644 index 0000000..09065c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/__init__.py @@ -0,0 +1,42 @@ +from .base_exceptions import LitestarException, LitestarWarning, MissingDependencyException, SerializationException +from .dto_exceptions import DTOFactoryException, InvalidAnnotationException +from .http_exceptions import ( + ClientException, + HTTPException, + ImproperlyConfiguredException, + InternalServerException, + MethodNotAllowedException, + NoRouteMatchFoundException, + NotAuthorizedException, + NotFoundException, + PermissionDeniedException, + ServiceUnavailableException, + TemplateNotFoundException, + TooManyRequestsException, + ValidationException, +) +from .websocket_exceptions import WebSocketDisconnect, WebSocketException + +__all__ = ( + "ClientException", + "DTOFactoryException", + "HTTPException", + "ImproperlyConfiguredException", + "InternalServerException", + "InvalidAnnotationException", + "LitestarException", + "LitestarWarning", + "MethodNotAllowedException", + "MissingDependencyException", + "NoRouteMatchFoundException", + "NotAuthorizedException", + "NotFoundException", + "PermissionDeniedException", + "SerializationException", + "ServiceUnavailableException", + "TemplateNotFoundException", + "TooManyRequestsException", + "ValidationException", + "WebSocketDisconnect", + "WebSocketException", +) diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1538618 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/base_exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/base_exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..efd605d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/base_exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/dto_exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/dto_exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..24633ee --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/dto_exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/http_exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/http_exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..76176d4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/http_exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/websocket_exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/websocket_exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..90c2ce2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/__pycache__/websocket_exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/base_exceptions.py b/venv/lib/python3.11/site-packages/litestar/exceptions/base_exceptions.py new file mode 100644 index 0000000..bbd4040 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/base_exceptions.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import Any + +__all__ = ("MissingDependencyException", "SerializationException", "LitestarException", "LitestarWarning") + + +class LitestarException(Exception): + """Base exception class from which all Litestar exceptions inherit.""" + + detail: str + + def __init__(self, *args: Any, detail: str = "") -> None: + """Initialize ``LitestarException``. + + Args: + *args: args are converted to :class:`str` before passing to :class:`Exception` + detail: detail of the exception. + """ + str_args = [str(arg) for arg in args if arg] + if not detail: + if str_args: + detail, *str_args = str_args + elif hasattr(self, "detail"): + detail = self.detail + self.detail = detail + super().__init__(*str_args) + + def __repr__(self) -> str: + if self.detail: + return f"{self.__class__.__name__} - {self.detail}" + return self.__class__.__name__ + + def __str__(self) -> str: + return " ".join((*self.args, self.detail)).strip() + + +class MissingDependencyException(LitestarException, ImportError): + """Missing optional dependency. + + This exception is raised only when a module depends on a dependency that has not been installed. + """ + + def __init__(self, package: str, install_package: str | None = None, extra: str | None = None) -> None: + super().__init__( + f"Package {package!r} is not installed but required. You can install it by running " + f"'pip install litestar[{extra or install_package or package}]' to install litestar with the required extra " + f"or 'pip install {install_package or package}' to install the package separately" + ) + + +class SerializationException(LitestarException): + """Encoding or decoding of an object failed.""" + + +class LitestarWarning(UserWarning): + """Base class for Litestar warnings""" diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/dto_exceptions.py b/venv/lib/python3.11/site-packages/litestar/exceptions/dto_exceptions.py new file mode 100644 index 0000000..037e3c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/dto_exceptions.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from litestar.exceptions import LitestarException + +__all__ = ("DTOFactoryException", "InvalidAnnotationException") + + +class DTOFactoryException(LitestarException): + """Base DTO exception type.""" + + +class InvalidAnnotationException(DTOFactoryException): + """Unexpected DTO type argument.""" diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/http_exceptions.py b/venv/lib/python3.11/site-packages/litestar/exceptions/http_exceptions.py new file mode 100644 index 0000000..bd384c3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/http_exceptions.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Any + +from litestar.exceptions.base_exceptions import LitestarException +from litestar.status_codes import ( + HTTP_400_BAD_REQUEST, + HTTP_401_UNAUTHORIZED, + HTTP_403_FORBIDDEN, + HTTP_404_NOT_FOUND, + HTTP_405_METHOD_NOT_ALLOWED, + HTTP_429_TOO_MANY_REQUESTS, + HTTP_500_INTERNAL_SERVER_ERROR, + HTTP_503_SERVICE_UNAVAILABLE, +) + +__all__ = ( + "ClientException", + "HTTPException", + "ImproperlyConfiguredException", + "InternalServerException", + "MethodNotAllowedException", + "NoRouteMatchFoundException", + "NotAuthorizedException", + "NotFoundException", + "PermissionDeniedException", + "ServiceUnavailableException", + "TemplateNotFoundException", + "TooManyRequestsException", + "ValidationException", +) + + +class HTTPException(LitestarException): + """Base exception for HTTP error responses. + + These exceptions carry information to construct an HTTP response. + """ + + status_code: int = HTTP_500_INTERNAL_SERVER_ERROR + """Exception status code.""" + detail: str + """Exception details or message.""" + headers: dict[str, str] | None + """Headers to attach to the response.""" + extra: dict[str, Any] | list[Any] | None + """An extra mapping to attach to the exception.""" + + def __init__( + self, + *args: Any, + detail: str = "", + status_code: int | None = None, + headers: dict[str, str] | None = None, + extra: dict[str, Any] | list[Any] | None = None, + ) -> None: + """Initialize ``HTTPException``. + + Set ``detail`` and ``args`` if not provided. + + Args: + *args: if ``detail`` kwarg not provided, first arg should be error detail. + detail: Exception details or message. Will default to args[0] if not provided. + status_code: Exception HTTP status code. + headers: Headers to set on the response. + extra: An extra mapping to attach to the exception. + """ + super().__init__(*args, detail=detail) + self.status_code = status_code or self.status_code + self.extra = extra + self.headers = headers + if not self.detail: + self.detail = HTTPStatus(self.status_code).phrase + self.args = (f"{self.status_code}: {self.detail}", *self.args) + + def __repr__(self) -> str: + return f"{self.status_code} - {self.__class__.__name__} - {self.detail}" + + def __str__(self) -> str: + return " ".join(self.args).strip() + + +class ImproperlyConfiguredException(HTTPException, ValueError): + """Application has improper configuration.""" + + +class ClientException(HTTPException): + """Client error.""" + + status_code: int = HTTP_400_BAD_REQUEST + + +class ValidationException(ClientException, ValueError): + """Client data validation error.""" + + +class NotAuthorizedException(ClientException): + """Request lacks valid authentication credentials for the requested resource.""" + + status_code = HTTP_401_UNAUTHORIZED + + +class PermissionDeniedException(ClientException): + """Request understood, but not authorized.""" + + status_code = HTTP_403_FORBIDDEN + + +class NotFoundException(ClientException, ValueError): + """Cannot find the requested resource.""" + + status_code = HTTP_404_NOT_FOUND + + +class MethodNotAllowedException(ClientException): + """Server knows the request method, but the target resource doesn't support this method.""" + + status_code = HTTP_405_METHOD_NOT_ALLOWED + + +class TooManyRequestsException(ClientException): + """Request limits have been exceeded.""" + + status_code = HTTP_429_TOO_MANY_REQUESTS + + +class InternalServerException(HTTPException): + """Server encountered an unexpected condition that prevented it from fulfilling the request.""" + + status_code: int = HTTP_500_INTERNAL_SERVER_ERROR + + +class ServiceUnavailableException(InternalServerException): + """Server is not ready to handle the request.""" + + status_code = HTTP_503_SERVICE_UNAVAILABLE + + +class NoRouteMatchFoundException(InternalServerException): + """A route with the given name could not be found.""" + + +class TemplateNotFoundException(InternalServerException): + """Referenced template could not be found.""" + + def __init__(self, *args: Any, template_name: str) -> None: + """Initialize ``TemplateNotFoundException``. + + Args: + *args (Any): Passed through to ``super().__init__()`` - should not include ``detail``. + template_name (str): Name of template that could not be found. + """ + super().__init__(*args, detail=f"Template {template_name} not found.") diff --git a/venv/lib/python3.11/site-packages/litestar/exceptions/websocket_exceptions.py b/venv/lib/python3.11/site-packages/litestar/exceptions/websocket_exceptions.py new file mode 100644 index 0000000..2fed9ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/exceptions/websocket_exceptions.py @@ -0,0 +1,40 @@ +from typing import Any + +from litestar.exceptions.base_exceptions import LitestarException +from litestar.status_codes import WS_1000_NORMAL_CLOSURE + +__all__ = ("WebSocketDisconnect", "WebSocketException") + + +class WebSocketException(LitestarException): + """Exception class for websocket related events.""" + + code: int + """Exception code. For custom exceptions, this should be a number in the 4000+ range. Other codes can be found in + ``litestar.status_code`` with the ``WS_`` prefix. + """ + + def __init__(self, *args: Any, detail: str, code: int = 4500) -> None: + """Initialize ``WebSocketException``. + + Args: + *args: Any exception args. + detail: Exception details. + code: Exception code. Should be a number in the >= 1000. + """ + super().__init__(*args, detail=detail) + self.code = code + + +class WebSocketDisconnect(WebSocketException): + """Exception class for websocket disconnect events.""" + + def __init__(self, *args: Any, detail: str, code: int = WS_1000_NORMAL_CLOSURE) -> None: + """Initialize ``WebSocketDisconnect``. + + Args: + *args: Any exception args. + detail: Exception details. + code: Exception code. Should be a number in the >= 1000. + """ + super().__init__(*args, detail=detail, code=code) diff --git a/venv/lib/python3.11/site-packages/litestar/file_system.py b/venv/lib/python3.11/site-packages/litestar/file_system.py new file mode 100644 index 0000000..fcb77c7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/file_system.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from stat import S_ISDIR +from typing import TYPE_CHECKING, Any, AnyStr, cast + +from anyio import AsyncFile, Path, open_file + +from litestar.concurrency import sync_to_thread +from litestar.exceptions import InternalServerException, NotAuthorizedException +from litestar.types.file_types import FileSystemProtocol +from litestar.utils.predicates import is_async_callable + +__all__ = ("BaseLocalFileSystem", "FileSystemAdapter") + + +if TYPE_CHECKING: + from os import stat_result + + from _typeshed import OpenBinaryMode + + from litestar.types import PathType + from litestar.types.file_types import FileInfo + + +class BaseLocalFileSystem(FileSystemProtocol): + """Base class for a local file system.""" + + async def info(self, path: PathType, **kwargs: Any) -> FileInfo: + """Retrieve information about a given file path. + + Args: + path: A file path. + **kwargs: Any additional kwargs. + + Returns: + A dictionary of file info. + """ + result = await Path(path).stat() + return await FileSystemAdapter.parse_stat_result(path=path, result=result) + + async def open(self, file: PathType, mode: str, buffering: int = -1) -> AsyncFile[AnyStr]: # pyright: ignore + """Return a file-like object from the filesystem. + + Notes: + - The return value must be a context-manager + + Args: + file: Path to the target file. + mode: Mode, similar to the built ``open``. + buffering: Buffer size. + """ + return await open_file(file=file, mode=mode, buffering=buffering) # type: ignore[call-overload, no-any-return] + + +class FileSystemAdapter: + """Wrapper around a ``FileSystemProtocol``, normalising its interface.""" + + def __init__(self, file_system: FileSystemProtocol) -> None: + """Initialize an adapter from a given ``file_system`` + + Args: + file_system: A filesystem class adhering to the :class:`FileSystemProtocol <litestar.types.FileSystemProtocol>` + """ + self.file_system = file_system + + async def info(self, path: PathType) -> FileInfo: + """Proxies the call to the underlying FS Spec's ``info`` method, ensuring it's done in an async fashion and with + strong typing. + + Args: + path: A file path to load the info for. + + Returns: + A dictionary of file info. + """ + try: + awaitable = ( + self.file_system.info(str(path)) + if is_async_callable(self.file_system.info) + else sync_to_thread(self.file_system.info, str(path)) + ) + return cast("FileInfo", await awaitable) + except FileNotFoundError as e: + raise e + except PermissionError as e: + raise NotAuthorizedException(f"failed to read {path} due to missing permissions") from e + except OSError as e: # pragma: no cover + raise InternalServerException from e + + async def open( + self, + file: PathType, + mode: OpenBinaryMode = "rb", + buffering: int = -1, + ) -> AsyncFile[bytes]: + """Return a file-like object from the filesystem. + + Notes: + - The return value must function correctly in a context ``with`` block. + + Args: + file: Path to the target file. + mode: Mode, similar to the built ``open``. + buffering: Buffer size. + """ + try: + if is_async_callable(self.file_system.open): # pyright: ignore + return cast( + "AsyncFile[bytes]", + await self.file_system.open( + file=file, + mode=mode, + buffering=buffering, + ), + ) + return AsyncFile(await sync_to_thread(self.file_system.open, file, mode, buffering)) # type: ignore[arg-type] + except PermissionError as e: + raise NotAuthorizedException(f"failed to open {file} due to missing permissions") from e + except OSError as e: + raise InternalServerException from e + + @staticmethod + async def parse_stat_result(path: PathType, result: stat_result) -> FileInfo: + """Convert a ``stat_result`` instance into a ``FileInfo``. + + Args: + path: The file path for which the :func:`stat_result <os.stat_result>` is provided. + result: The :func:`stat_result <os.stat_result>` instance. + + Returns: + A dictionary of file info. + """ + file_info: FileInfo = { + "created": result.st_ctime, + "gid": result.st_gid, + "ino": result.st_ino, + "islink": await Path(path).is_symlink(), + "mode": result.st_mode, + "mtime": result.st_mtime, + "name": str(path), + "nlink": result.st_nlink, + "size": result.st_size, + "type": "directory" if S_ISDIR(result.st_mode) else "file", + "uid": result.st_uid, + } + + if file_info["islink"]: + file_info["destination"] = str(await Path(path).readlink()).encode("utf-8") + try: + file_info["size"] = (await Path(path).stat()).st_size + except OSError: # pragma: no cover + file_info["size"] = result.st_size + + return file_info diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/__init__.py b/venv/lib/python3.11/site-packages/litestar/handlers/__init__.py new file mode 100644 index 0000000..822fe7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/__init__.py @@ -0,0 +1,29 @@ +from .asgi_handlers import ASGIRouteHandler, asgi +from .base import BaseRouteHandler +from .http_handlers import HTTPRouteHandler, delete, get, head, patch, post, put, route +from .websocket_handlers import ( + WebsocketListener, + WebsocketListenerRouteHandler, + WebsocketRouteHandler, + websocket, + websocket_listener, +) + +__all__ = ( + "ASGIRouteHandler", + "BaseRouteHandler", + "HTTPRouteHandler", + "WebsocketListener", + "WebsocketRouteHandler", + "WebsocketListenerRouteHandler", + "asgi", + "delete", + "get", + "head", + "patch", + "post", + "put", + "route", + "websocket", + "websocket_listener", +) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6976d76 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/asgi_handlers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/asgi_handlers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..85de8f8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/asgi_handlers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2331fc7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/asgi_handlers.py b/venv/lib/python3.11/site-packages/litestar/handlers/asgi_handlers.py new file mode 100644 index 0000000..91f3517 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/asgi_handlers.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, Sequence + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers.base import BaseRouteHandler +from litestar.types.builtin_types import NoneType +from litestar.utils.predicates import is_async_callable + +__all__ = ("ASGIRouteHandler", "asgi") + + +if TYPE_CHECKING: + from litestar.types import ( + ExceptionHandlersMap, + Guard, + MaybePartial, # noqa: F401 + ) + + +class ASGIRouteHandler(BaseRouteHandler): + """ASGI Route Handler decorator. + + Use this decorator to decorate ASGI applications. + """ + + __slots__ = ("is_mount", "is_static") + + def __init__( + self, + path: str | Sequence[str] | None = None, + *, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + is_mount: bool = False, + is_static: bool = False, + signature_namespace: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``ASGIRouteHandler``. + + Args: + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + name: A string identifying the route handler. + opt: A string key mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + path: A path fragment for the route handler function or a list of path fragments. If not given defaults to + ``/`` + is_mount: A boolean dictating whether the handler's paths should be regarded as mount paths. Mount path + accept any arbitrary paths that begin with the defined prefixed path. For example, a mount with the path + ``/some-path/`` will accept requests for ``/some-path/`` and any sub path under this, e.g. + ``/some-path/sub-path/`` etc. + is_static: A boolean dictating whether the handler's paths should be regarded as static paths. Static paths + are used to deliver static files. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + self.is_mount = is_mount or is_static + self.is_static = is_static + super().__init__( + path, + exception_handlers=exception_handlers, + guards=guards, + name=name, + opt=opt, + signature_namespace=signature_namespace, + **kwargs, + ) + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it's set by inspecting its return annotations.""" + super()._validate_handler_function() + + if not self.parsed_fn_signature.return_type.is_subclass_of(NoneType): + raise ImproperlyConfiguredException("ASGI handler functions should return 'None'") + + if any(key not in self.parsed_fn_signature.parameters for key in ("scope", "send", "receive")): + raise ImproperlyConfiguredException( + "ASGI handler functions should define 'scope', 'send' and 'receive' arguments" + ) + if not is_async_callable(self.fn): + raise ImproperlyConfiguredException("Functions decorated with 'asgi' must be async functions") + + +asgi = ASGIRouteHandler diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/base.py b/venv/lib/python3.11/site-packages/litestar/handlers/base.py new file mode 100644 index 0000000..9dbb70e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/base.py @@ -0,0 +1,577 @@ +from __future__ import annotations + +from copy import copy +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, cast + +from litestar._signature import SignatureModel +from litestar.config.app import ExperimentalFeatures +from litestar.di import Provide +from litestar.dto import DTOData +from litestar.exceptions import ImproperlyConfiguredException +from litestar.plugins import DIPlugin, PluginRegistry +from litestar.serialization import default_deserializer, default_serializer +from litestar.types import ( + Dependencies, + Empty, + ExceptionHandlersMap, + Guard, + Middleware, + TypeDecodersSequence, + TypeEncodersMap, +) +from litestar.typing import FieldDefinition +from litestar.utils import ensure_async_callable, get_name, normalize_path +from litestar.utils.helpers import unwrap_partial +from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace + +if TYPE_CHECKING: + from typing_extensions import Self + + from litestar.app import Litestar + from litestar.connection import ASGIConnection + from litestar.controller import Controller + from litestar.dto import AbstractDTO + from litestar.dto._backend import DTOBackend + from litestar.params import ParameterKwarg + from litestar.router import Router + from litestar.types import AnyCallable, AsyncAnyCallable, ExceptionHandler + from litestar.types.empty import EmptyType + +__all__ = ("BaseRouteHandler",) + + +class BaseRouteHandler: + """Base route handler. + + Serves as a subclass for all route handlers + """ + + __slots__ = ( + "_fn", + "_parsed_data_field", + "_parsed_fn_signature", + "_parsed_return_field", + "_resolved_data_dto", + "_resolved_dependencies", + "_resolved_guards", + "_resolved_layered_parameters", + "_resolved_return_dto", + "_resolved_signature_namespace", + "_resolved_type_decoders", + "_resolved_type_encoders", + "_signature_model", + "dependencies", + "dto", + "exception_handlers", + "guards", + "middleware", + "name", + "opt", + "owner", + "paths", + "return_dto", + "signature_namespace", + "type_decoders", + "type_encoders", + ) + + def __init__( + self, + path: str | Sequence[str] | None = None, + *, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``HTTPRouteHandler``. + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature + modelling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + self._parsed_fn_signature: ParsedSignature | EmptyType = Empty + self._parsed_return_field: FieldDefinition | EmptyType = Empty + self._parsed_data_field: FieldDefinition | None | EmptyType = Empty + self._resolved_data_dto: type[AbstractDTO] | None | EmptyType = Empty + self._resolved_dependencies: dict[str, Provide] | EmptyType = Empty + self._resolved_guards: list[Guard] | EmptyType = Empty + self._resolved_layered_parameters: dict[str, FieldDefinition] | EmptyType = Empty + self._resolved_return_dto: type[AbstractDTO] | None | EmptyType = Empty + self._resolved_signature_namespace: dict[str, Any] | EmptyType = Empty + self._resolved_type_decoders: TypeDecodersSequence | EmptyType = Empty + self._resolved_type_encoders: TypeEncodersMap | EmptyType = Empty + self._signature_model: type[SignatureModel] | EmptyType = Empty + + self.dependencies = dependencies + self.dto = dto + self.exception_handlers = exception_handlers + self.guards = guards + self.middleware = middleware + self.name = name + self.opt = dict(opt or {}) + self.opt.update(**kwargs) + self.owner: Controller | Router | None = None + self.return_dto = return_dto + self.signature_namespace = add_types_to_signature_namespace( + signature_types or [], dict(signature_namespace or {}) + ) + self.type_decoders = type_decoders + self.type_encoders = type_encoders + + self.paths = ( + {normalize_path(p) for p in path} if path and isinstance(path, list) else {normalize_path(path or "/")} # type: ignore[arg-type] + ) + + def __call__(self, fn: AsyncAnyCallable) -> Self: + """Replace a function with itself.""" + self._fn = fn + return self + + @property + def handler_id(self) -> str: + """A unique identifier used for generation of DTOs.""" + return f"{self!s}::{sum(id(layer) for layer in self.ownership_layers)}" + + @property + def default_deserializer(self) -> Callable[[Any, Any], Any]: + """Get a default deserializer for the route handler. + + Returns: + A default deserializer for the route handler. + + """ + return partial(default_deserializer, type_decoders=self.resolve_type_decoders()) + + @property + def default_serializer(self) -> Callable[[Any], Any]: + """Get a default serializer for the route handler. + + Returns: + A default serializer for the route handler. + + """ + return partial(default_serializer, type_encoders=self.resolve_type_encoders()) + + @property + def signature_model(self) -> type[SignatureModel]: + """Get the signature model for the route handler. + + Returns: + A signature model for the route handler. + + """ + if self._signature_model is Empty: + self._signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=cast("AnyCallable", self.fn), + parsed_signature=self.parsed_fn_signature, + data_dto=self.resolve_data_dto(), + type_decoders=self.resolve_type_decoders(), + ) + return self._signature_model + + @property + def fn(self) -> AsyncAnyCallable: + """Get the handler function. + + Raises: + ImproperlyConfiguredException: if handler fn is not set. + + Returns: + Handler function + """ + if not hasattr(self, "_fn"): + raise ImproperlyConfiguredException("No callable has been registered for this handler") + return self._fn + + @property + def parsed_fn_signature(self) -> ParsedSignature: + """Return the parsed signature of the handler function. + + This method is memoized so the computation occurs only once. + + Returns: + A ParsedSignature instance + """ + if self._parsed_fn_signature is Empty: + self._parsed_fn_signature = ParsedSignature.from_fn( + unwrap_partial(self.fn), self.resolve_signature_namespace() + ) + + return self._parsed_fn_signature + + @property + def parsed_return_field(self) -> FieldDefinition: + if self._parsed_return_field is Empty: + self._parsed_return_field = self.parsed_fn_signature.return_type + return self._parsed_return_field + + @property + def parsed_data_field(self) -> FieldDefinition | None: + if self._parsed_data_field is Empty: + self._parsed_data_field = self.parsed_fn_signature.parameters.get("data") + return self._parsed_data_field + + @property + def handler_name(self) -> str: + """Get the name of the handler function. + + Raises: + ImproperlyConfiguredException: if handler fn is not set. + + Returns: + Name of the handler function + """ + return get_name(unwrap_partial(self.fn)) + + @property + def dependency_name_set(self) -> set[str]: + """Set of all dependency names provided in the handler's ownership layers.""" + layered_dependencies = (layer.dependencies or {} for layer in self.ownership_layers) + return {name for layer in layered_dependencies for name in layer} # pyright: ignore + + @property + def ownership_layers(self) -> list[Self | Controller | Router]: + """Return the handler layers from the app down to the route handler. + + ``app -> ... -> route handler`` + """ + layers = [] + + cur: Any = self + while cur: + layers.append(cur) + cur = cur.owner + + return list(reversed(layers)) + + @property + def app(self) -> Litestar: + return cast("Litestar", self.ownership_layers[0]) + + def resolve_type_encoders(self) -> TypeEncodersMap: + """Return a merged type_encoders mapping. + + This method is memoized so the computation occurs only once. + + Returns: + A dict of type encoders + """ + if self._resolved_type_encoders is Empty: + self._resolved_type_encoders = {} + + for layer in self.ownership_layers: + if type_encoders := getattr(layer, "type_encoders", None): + self._resolved_type_encoders.update(type_encoders) + return cast("TypeEncodersMap", self._resolved_type_encoders) + + def resolve_type_decoders(self) -> TypeDecodersSequence: + """Return a merged type_encoders mapping. + + This method is memoized so the computation occurs only once. + + Returns: + A dict of type encoders + """ + if self._resolved_type_decoders is Empty: + self._resolved_type_decoders = [] + + for layer in self.ownership_layers: + if type_decoders := getattr(layer, "type_decoders", None): + self._resolved_type_decoders.extend(list(type_decoders)) + return cast("TypeDecodersSequence", self._resolved_type_decoders) + + def resolve_layered_parameters(self) -> dict[str, FieldDefinition]: + """Return all parameters declared above the handler.""" + if self._resolved_layered_parameters is Empty: + parameter_kwargs: dict[str, ParameterKwarg] = {} + + for layer in self.ownership_layers: + parameter_kwargs.update(getattr(layer, "parameters", {}) or {}) + + self._resolved_layered_parameters = { + key: FieldDefinition.from_kwarg(name=key, annotation=parameter.annotation, kwarg_definition=parameter) + for key, parameter in parameter_kwargs.items() + } + + return self._resolved_layered_parameters + + def resolve_guards(self) -> list[Guard]: + """Return all guards in the handlers scope, starting from highest to current layer.""" + if self._resolved_guards is Empty: + self._resolved_guards = [] + + for layer in self.ownership_layers: + self._resolved_guards.extend(layer.guards or []) # pyright: ignore + + self._resolved_guards = cast( + "list[Guard]", [ensure_async_callable(guard) for guard in self._resolved_guards] + ) + + return self._resolved_guards + + def _get_plugin_registry(self) -> PluginRegistry | None: + from litestar.app import Litestar + + root_owner = self.ownership_layers[0] + if isinstance(root_owner, Litestar): + return root_owner.plugins + return None + + def resolve_dependencies(self) -> dict[str, Provide]: + """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope.""" + plugin_registry = self._get_plugin_registry() + if self._resolved_dependencies is Empty: + self._resolved_dependencies = {} + for layer in self.ownership_layers: + for key, provider in (layer.dependencies or {}).items(): + self._resolved_dependencies[key] = self._resolve_dependency( + key=key, provider=provider, plugin_registry=plugin_registry + ) + + return self._resolved_dependencies + + def _resolve_dependency( + self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None + ) -> Provide: + if not isinstance(provider, Provide): + provider = Provide(provider) + + if self._resolved_dependencies is not Empty: # pragma: no cover + self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider) + + if not getattr(provider, "parsed_fn_signature", None): + dependency = unwrap_partial(provider.dependency) + plugin: DIPlugin | None = None + if plugin_registry: + plugin = next( + (p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)), + None, + ) + if plugin: + signature, init_type_hints = plugin.get_typed_init(dependency) + provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints) + else: + provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace()) + + if not getattr(provider, "signature_model", None): + provider.signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=provider.dependency, + parsed_signature=provider.parsed_fn_signature, + data_dto=self.resolve_data_dto(), + type_decoders=self.resolve_type_decoders(), + ) + return provider + + def resolve_middleware(self) -> list[Middleware]: + """Build the middleware stack for the RouteHandler and return it. + + The middlewares are added from top to bottom (``app -> router -> controller -> route handler``) and then + reversed. + """ + resolved_middleware: list[Middleware] = [] + for layer in self.ownership_layers: + resolved_middleware.extend(layer.middleware or []) # pyright: ignore + return list(reversed(resolved_middleware)) + + def resolve_exception_handlers(self) -> ExceptionHandlersMap: + """Resolve the exception_handlers by starting from the route handler and moving up. + + This method is memoized so the computation occurs only once. + """ + resolved_exception_handlers: dict[int | type[Exception], ExceptionHandler] = {} + for layer in self.ownership_layers: + resolved_exception_handlers.update(layer.exception_handlers or {}) # pyright: ignore + return resolved_exception_handlers + + def resolve_opts(self) -> None: + """Build the route handler opt dictionary by going from top to bottom. + + When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the + layer closest to the response handler will take precedence. + """ + + opt: dict[str, Any] = {} + for layer in self.ownership_layers: + opt.update(layer.opt or {}) # pyright: ignore + + self.opt = opt + + def resolve_signature_namespace(self) -> dict[str, Any]: + """Build the route handler signature namespace dictionary by going from top to bottom. + + When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the + layer closest to the response handler will take precedence. + """ + if self._resolved_layered_parameters is Empty: + ns: dict[str, Any] = {} + for layer in self.ownership_layers: + ns.update(layer.signature_namespace) + + self._resolved_signature_namespace = ns + return cast("dict[str, Any]", self._resolved_signature_namespace) + + def _get_dto_backend_cls(self) -> type[DTOBackend] | None: + if ExperimentalFeatures.DTO_CODEGEN in self.app.experimental_features: + from litestar.dto._codegen_backend import DTOCodegenBackend + + return DTOCodegenBackend + return None + + def resolve_data_dto(self) -> type[AbstractDTO] | None: + """Resolve the data_dto by starting from the route handler and moving up. + If a handler is found it is returned, otherwise None is set. + This method is memoized so the computation occurs only once. + + Returns: + An optional :class:`DTO type <.dto.base_dto.AbstractDTO>` + """ + if self._resolved_data_dto is Empty: + if data_dtos := cast( + "list[type[AbstractDTO] | None]", + [layer.dto for layer in self.ownership_layers if layer.dto is not Empty], + ): + data_dto: type[AbstractDTO] | None = data_dtos[-1] + elif self.parsed_data_field and ( + plugins_for_data_type := [ + plugin + for plugin in self.app.plugins.serialization + if self.parsed_data_field.match_predicate_recursively(plugin.supports_type) + ] + ): + data_dto = plugins_for_data_type[0].create_dto_for_type(self.parsed_data_field) + else: + data_dto = None + + if self.parsed_data_field and data_dto: + data_dto.create_for_field_definition( + field_definition=self.parsed_data_field, + handler_id=self.handler_id, + backend_cls=self._get_dto_backend_cls(), + ) + + self._resolved_data_dto = data_dto + + return self._resolved_data_dto + + def resolve_return_dto(self) -> type[AbstractDTO] | None: + """Resolve the return_dto by starting from the route handler and moving up. + If a handler is found it is returned, otherwise None is set. + This method is memoized so the computation occurs only once. + + Returns: + An optional :class:`DTO type <.dto.base_dto.AbstractDTO>` + """ + if self._resolved_return_dto is Empty: + if return_dtos := cast( + "list[type[AbstractDTO] | None]", + [layer.return_dto for layer in self.ownership_layers if layer.return_dto is not Empty], + ): + return_dto: type[AbstractDTO] | None = return_dtos[-1] + elif plugins_for_return_type := [ + plugin + for plugin in self.app.plugins.serialization + if self.parsed_return_field.match_predicate_recursively(plugin.supports_type) + ]: + return_dto = plugins_for_return_type[0].create_dto_for_type(self.parsed_return_field) + else: + return_dto = self.resolve_data_dto() + + if return_dto and return_dto.is_supported_model_type_field(self.parsed_return_field): + return_dto.create_for_field_definition( + field_definition=self.parsed_return_field, + handler_id=self.handler_id, + backend_cls=self._get_dto_backend_cls(), + ) + self._resolved_return_dto = return_dto + else: + self._resolved_return_dto = None + + return self._resolved_return_dto + + async def authorize_connection(self, connection: ASGIConnection) -> None: + """Ensure the connection is authorized by running all the route guards in scope.""" + for guard in self.resolve_guards(): + await guard(connection, copy(self)) # type: ignore[misc] + + @staticmethod + def _validate_dependency_is_unique(dependencies: dict[str, Provide], key: str, provider: Provide) -> None: + """Validate that a given provider has not been already defined under a different key.""" + for dependency_key, value in dependencies.items(): + if provider == value: + raise ImproperlyConfiguredException( + f"Provider for key {key} is already defined under the different key {dependency_key}. " + f"If you wish to override a provider, it must have the same key." + ) + + def on_registration(self, app: Litestar) -> None: + """Called once per handler when the app object is instantiated. + + Args: + app: The :class:`Litestar<.app.Litestar>` app object. + + Returns: + None + """ + self._validate_handler_function() + self.resolve_dependencies() + self.resolve_guards() + self.resolve_middleware() + self.resolve_opts() + self.resolve_data_dto() + self.resolve_return_dto() + + def _validate_handler_function(self) -> None: + """Validate the route handler function once set by inspecting its return annotations.""" + if ( + self.parsed_data_field is not None + and self.parsed_data_field.is_subclass_of(DTOData) + and not self.resolve_data_dto() + ): + raise ImproperlyConfiguredException( + f"Handler function {self.handler_name} has a data parameter that is a subclass of DTOData but no " + "DTO has been registered for it." + ) + + def __str__(self) -> str: + """Return a unique identifier for the route handler. + + Returns: + A string + """ + target: type[AsyncAnyCallable] | AsyncAnyCallable # pyright: ignore + target = unwrap_partial(self.fn) + if not hasattr(target, "__qualname__"): + target = type(target) + return f"{target.__module__}.{target.__qualname__}" diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__init__.py b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__init__.py new file mode 100644 index 0000000..844f046 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from .base import HTTPRouteHandler, route +from .decorators import delete, get, head, patch, post, put + +__all__ = ( + "HTTPRouteHandler", + "delete", + "get", + "head", + "patch", + "post", + "put", + "route", +) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..848ac8f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..22a7f67 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..eb39166 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/decorators.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/decorators.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0acd7c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/__pycache__/decorators.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/_utils.py b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/_utils.py new file mode 100644 index 0000000..ec95145 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/_utils.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +from functools import lru_cache +from inspect import isawaitable +from typing import TYPE_CHECKING, Any, Sequence, cast + +from litestar.enums import HttpMethod +from litestar.exceptions import ValidationException +from litestar.response import Response +from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT +from litestar.types.builtin_types import NoneType + +if TYPE_CHECKING: + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.connection import Request + from litestar.datastructures import Cookie, ResponseHeader + from litestar.types import AfterRequestHookHandler, ASGIApp, AsyncAnyCallable, Method, TypeEncodersMap + from litestar.typing import FieldDefinition + +__all__ = ( + "create_data_handler", + "create_generic_asgi_response_handler", + "create_response_handler", + "get_default_status_code", + "is_empty_response_annotation", + "normalize_headers", + "normalize_http_method", +) + + +def create_data_handler( + after_request: AfterRequestHookHandler | None, + background: BackgroundTask | BackgroundTasks | None, + cookies: frozenset[Cookie], + headers: frozenset[ResponseHeader], + media_type: str, + response_class: type[Response], + status_code: int, + type_encoders: TypeEncodersMap | None, +) -> AsyncAnyCallable: + """Create a handler function for arbitrary data. + + Args: + after_request: An after request handler. + background: A background task or background tasks. + cookies: A set of pre-defined cookies. + headers: A set of response headers. + media_type: The response media type. + response_class: The response class to use. + status_code: The response status code. + type_encoders: A mapping of types to encoder functions. + + Returns: + A handler function. + + """ + + async def handler( + data: Any, + request: Request[Any, Any, Any], + app: Litestar, + **kwargs: Any, + ) -> ASGIApp: + if isawaitable(data): + data = await data + + response = response_class( + background=background, + content=data, + media_type=media_type, + status_code=status_code, + type_encoders=type_encoders, + ) + + if after_request: + response = await after_request(response) # type: ignore[arg-type,misc] + + return response.to_asgi_response(app=None, request=request, headers=normalize_headers(headers), cookies=cookies) # pyright: ignore + + return handler + + +def create_generic_asgi_response_handler(after_request: AfterRequestHookHandler | None) -> AsyncAnyCallable: + """Create a handler function for Responses. + + Args: + after_request: An after request handler. + + Returns: + A handler function. + """ + + async def handler(data: ASGIApp, **kwargs: Any) -> ASGIApp: + return await after_request(data) if after_request else data # type: ignore[arg-type, misc, no-any-return] + + return handler + + +@lru_cache(1024) +def normalize_headers(headers: frozenset[ResponseHeader]) -> dict[str, str]: + """Given a dictionary of ResponseHeader, filter them and return a dictionary of values. + + Args: + headers: A dictionary of :class:`ResponseHeader <litestar.datastructures.ResponseHeader>` values + + Returns: + A string keyed dictionary of normalized values + """ + return { + header.name: cast("str", header.value) # we know value to be a string at this point because we validate it + # that it's not None when initializing a header with documentation_only=True + for header in headers + if not header.documentation_only + } + + +def create_response_handler( + after_request: AfterRequestHookHandler | None, + background: BackgroundTask | BackgroundTasks | None, + cookies: frozenset[Cookie], + headers: frozenset[ResponseHeader], + media_type: str, + status_code: int, + type_encoders: TypeEncodersMap | None, +) -> AsyncAnyCallable: + """Create a handler function for Litestar Responses. + + Args: + after_request: An after request handler. + background: A background task or background tasks. + cookies: A set of pre-defined cookies. + headers: A set of response headers. + media_type: The response media type. + status_code: The response status code. + type_encoders: A mapping of types to encoder functions. + + Returns: + A handler function. + """ + + normalized_headers = normalize_headers(headers) + cookie_list = list(cookies) + + async def handler( + data: Response, + app: Litestar, + request: Request, + **kwargs: Any, # kwargs is for return dto + ) -> ASGIApp: + response = await after_request(data) if after_request else data # type:ignore[arg-type,misc] + return response.to_asgi_response( # type: ignore[no-any-return] + app=None, + background=background, + cookies=cookie_list, + headers=normalized_headers, + media_type=media_type, + request=request, + status_code=status_code, + type_encoders=type_encoders, + ) + + return handler + + +def normalize_http_method(http_methods: HttpMethod | Method | Sequence[HttpMethod | Method]) -> set[Method]: + """Normalize HTTP method(s) into a set of upper-case method names. + + Args: + http_methods: A value for http method. + + Returns: + A normalized set of http methods. + """ + output: set[str] = set() + + if isinstance(http_methods, str): + http_methods = [http_methods] # pyright: ignore + + for method in http_methods: + method_name = method.value.upper() if isinstance(method, HttpMethod) else method.upper() + if method_name not in HTTP_METHOD_NAMES: + raise ValidationException(f"Invalid HTTP method: {method_name}") + output.add(method_name) + + return cast("set[Method]", output) + + +def get_default_status_code(http_methods: set[Method]) -> int: + """Return the default status code for a given set of HTTP methods. + + Args: + http_methods: A set of method strings + + Returns: + A status code + """ + if HttpMethod.POST in http_methods: + return HTTP_201_CREATED + if HttpMethod.DELETE in http_methods: + return HTTP_204_NO_CONTENT + return HTTP_200_OK + + +def is_empty_response_annotation(return_annotation: FieldDefinition) -> bool: + """Return whether the return annotation is an empty response. + + Args: + return_annotation: A return annotation. + + Returns: + Whether the return annotation is an empty response. + """ + return ( + return_annotation.is_subclass_of(NoneType) + or return_annotation.is_subclass_of(Response) + and return_annotation.has_inner_subclass_of(NoneType) + ) + + +HTTP_METHOD_NAMES = {m.value for m in HttpMethod} diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/base.py b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/base.py new file mode 100644 index 0000000..757253e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/base.py @@ -0,0 +1,591 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast + +from litestar._layers.utils import narrow_response_cookies, narrow_response_headers +from litestar.connection import Request +from litestar.datastructures.cookie import Cookie +from litestar.datastructures.response_header import ResponseHeader +from litestar.enums import HttpMethod, MediaType +from litestar.exceptions import ( + HTTPException, + ImproperlyConfiguredException, +) +from litestar.handlers.base import BaseRouteHandler +from litestar.handlers.http_handlers._utils import ( + create_data_handler, + create_generic_asgi_response_handler, + create_response_handler, + get_default_status_code, + is_empty_response_annotation, + normalize_http_method, +) +from litestar.openapi.spec import Operation +from litestar.response import Response +from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED +from litestar.types import ( + AfterRequestHookHandler, + AfterResponseHookHandler, + AnyCallable, + ASGIApp, + BeforeRequestHookHandler, + CacheKeyBuilder, + Dependencies, + Empty, + EmptyType, + ExceptionHandlersMap, + Guard, + Method, + Middleware, + ResponseCookies, + ResponseHeaders, + TypeEncodersMap, +) +from litestar.utils import ensure_async_callable +from litestar.utils.predicates import is_async_callable +from litestar.utils.warnings import warn_implicit_sync_to_thread, warn_sync_to_thread_with_async_callable + +if TYPE_CHECKING: + from typing import Any, Awaitable, Callable + + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.config.response_cache import CACHE_FOREVER + from litestar.datastructures import CacheControlHeader, ETag + from litestar.dto import AbstractDTO + from litestar.openapi.datastructures import ResponseSpec + from litestar.openapi.spec import SecurityRequirement + from litestar.types.callable_types import AsyncAnyCallable, OperationIDCreator + from litestar.types.composite_types import TypeDecodersSequence + +__all__ = ("HTTPRouteHandler", "route") + + +class ResponseHandlerMap(TypedDict): + default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType + response_type_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType + + +class HTTPRouteHandler(BaseRouteHandler): + """HTTP Route Decorator. + + Use this decorator to decorate an HTTP handler with multiple methods. + """ + + __slots__ = ( + "_resolved_after_response", + "_resolved_before_request", + "_response_handler_mapping", + "_resolved_include_in_schema", + "_resolved_tags", + "_resolved_security", + "after_request", + "after_response", + "background", + "before_request", + "cache", + "cache_control", + "cache_key_builder", + "content_encoding", + "content_media_type", + "deprecated", + "description", + "etag", + "has_sync_callable", + "http_methods", + "include_in_schema", + "media_type", + "operation_class", + "operation_id", + "raises", + "request_class", + "response_class", + "response_cookies", + "response_description", + "response_headers", + "responses", + "security", + "status_code", + "summary", + "sync_to_thread", + "tags", + "template_name", + ) + + has_sync_callable: bool + + def __init__( + self, + path: str | Sequence[str] | None = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + http_method: HttpMethod | Method | Sequence[HttpMethod | Method], + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + signature_namespace: Mapping[str, Any] | None = None, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``HTTPRouteHandler``. + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a + number of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <.enums.HttpMethod>` or a list of these that correlates to the methods the route + handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a valid IANA + Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` and + ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``"base64"``. + content_media_type: A string designating the media-type of the content, e.g. ``"image/png"``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if not http_method: + raise ImproperlyConfiguredException("An http_method kwarg is required") + + self.http_methods = normalize_http_method(http_methods=http_method) + self.status_code = status_code or get_default_status_code(http_methods=self.http_methods) + + super().__init__( + path=path, + dependencies=dependencies, + dto=dto, + exception_handlers=exception_handlers, + guards=guards, + middleware=middleware, + name=name, + opt=opt, + return_dto=return_dto, + signature_namespace=signature_namespace, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) + + self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore + self.after_response = ensure_async_callable(after_response) if after_response else None + self.background = background + self.before_request = ensure_async_callable(before_request) if before_request else None + self.cache = cache + self.cache_control = cache_control + self.cache_key_builder = cache_key_builder + self.etag = etag + self.media_type: MediaType | str = media_type or "" + self.request_class = request_class + self.response_class = response_class + self.response_cookies: Sequence[Cookie] | None = narrow_response_cookies(response_cookies) + self.response_headers: Sequence[ResponseHeader] | None = narrow_response_headers(response_headers) + + self.sync_to_thread = sync_to_thread + # OpenAPI related attributes + self.content_encoding = content_encoding + self.content_media_type = content_media_type + self.deprecated = deprecated + self.description = description + self.include_in_schema = include_in_schema + self.operation_class = operation_class + self.operation_id = operation_id + self.raises = raises + self.response_description = response_description + self.summary = summary + self.tags = tags + self.security = security + self.responses = responses + # memoized attributes, defaulted to Empty + self._resolved_after_response: AsyncAnyCallable | None | EmptyType = Empty + self._resolved_before_request: AsyncAnyCallable | None | EmptyType = Empty + self._response_handler_mapping: ResponseHandlerMap = {"default_handler": Empty, "response_type_handler": Empty} + self._resolved_include_in_schema: bool | EmptyType = Empty + self._resolved_security: list[SecurityRequirement] | EmptyType = Empty + self._resolved_tags: list[str] | EmptyType = Empty + + def __call__(self, fn: AnyCallable) -> HTTPRouteHandler: + """Replace a function with itself.""" + if not is_async_callable(fn): + if self.sync_to_thread is None: + warn_implicit_sync_to_thread(fn, stacklevel=3) + elif self.sync_to_thread is not None: + warn_sync_to_thread_with_async_callable(fn, stacklevel=3) + + super().__call__(fn) + return self + + def resolve_request_class(self) -> type[Request]: + """Return the closest custom Request class in the owner graph or the default Request class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`Request <.connection.Request>` class for the route handler. + """ + return next( + (layer.request_class for layer in reversed(self.ownership_layers) if layer.request_class is not None), + Request, + ) + + def resolve_response_class(self) -> type[Response]: + """Return the closest custom Response class in the owner graph or the default Response class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`Response <.response.Response>` class for the route handler. + """ + return next( + (layer.response_class for layer in reversed(self.ownership_layers) if layer.response_class is not None), + Response, + ) + + def resolve_response_headers(self) -> frozenset[ResponseHeader]: + """Return all header parameters in the scope of the handler function. + + Returns: + A dictionary mapping keys to :class:`ResponseHeader <.datastructures.ResponseHeader>` instances. + """ + resolved_response_headers: dict[str, ResponseHeader] = {} + + for layer in self.ownership_layers: + if layer_response_headers := layer.response_headers: + if isinstance(layer_response_headers, Mapping): + # this can't happen unless you manually set response_headers on an instance, which would result in a + # type-checking error on everything but the controller. We cover this case nevertheless + resolved_response_headers.update( + {name: ResponseHeader(name=name, value=value) for name, value in layer_response_headers.items()} + ) + else: + resolved_response_headers.update({h.name: h for h in layer_response_headers}) + for extra_header in ("cache_control", "etag"): + if header_model := getattr(layer, extra_header, None): + resolved_response_headers[header_model.HEADER_NAME] = ResponseHeader( + name=header_model.HEADER_NAME, + value=header_model.to_header(), + documentation_only=header_model.documentation_only, + ) + + return frozenset(resolved_response_headers.values()) + + def resolve_response_cookies(self) -> frozenset[Cookie]: + """Return a list of Cookie instances. Filters the list to ensure each cookie key is unique. + + Returns: + A list of :class:`Cookie <.datastructures.Cookie>` instances. + """ + response_cookies: set[Cookie] = set() + for layer in reversed(self.ownership_layers): + if layer_response_cookies := layer.response_cookies: + if isinstance(layer_response_cookies, Mapping): + # this can't happen unless you manually set response_cookies on an instance, which would result in a + # type-checking error on everything but the controller. We cover this case nevertheless + response_cookies.update( + {Cookie(key=key, value=value) for key, value in layer_response_cookies.items()} + ) + else: + response_cookies.update(cast("set[Cookie]", layer_response_cookies)) + return frozenset(response_cookies) + + def resolve_before_request(self) -> AsyncAnyCallable | None: + """Resolve the before_handler handler by starting from the route handler and moving up. + + If a handler is found it is returned, otherwise None is set. + This method is memoized so the computation occurs only once. + + Returns: + An optional :class:`before request lifecycle hook handler <.types.BeforeRequestHookHandler>` + """ + if self._resolved_before_request is Empty: + before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request] + self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None + return cast("AsyncAnyCallable | None", self._resolved_before_request) + + def resolve_after_response(self) -> AsyncAnyCallable | None: + """Resolve the after_response handler by starting from the route handler and moving up. + + If a handler is found it is returned, otherwise None is set. + This method is memoized so the computation occurs only once. + + Returns: + An optional :class:`after response lifecycle hook handler <.types.AfterResponseHookHandler>` + """ + if self._resolved_after_response is Empty: + after_response_handlers: list[AsyncAnyCallable] = [ + layer.after_response # type: ignore[misc] + for layer in self.ownership_layers + if layer.after_response + ] + self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None + + return cast("AsyncAnyCallable | None", self._resolved_after_response) + + def resolve_include_in_schema(self) -> bool: + """Resolve the 'include_in_schema' property by starting from the route handler and moving up. + + If 'include_in_schema' is found in any of the ownership layers, the last value found is returned. + If not found in any layer, the default value ``True`` is returned. + + Returns: + bool: The resolved 'include_in_schema' property. + """ + if self._resolved_include_in_schema is Empty: + include_in_schemas = [ + i.include_in_schema for i in self.ownership_layers if isinstance(i.include_in_schema, bool) + ] + self._resolved_include_in_schema = include_in_schemas[-1] if include_in_schemas else True + + return self._resolved_include_in_schema + + def resolve_security(self) -> list[SecurityRequirement]: + """Resolve the security property by starting from the route handler and moving up. + + Security requirements are additive, so the security requirements of the route handler are the sum of all + security requirements of the ownership layers. + + Returns: + list[SecurityRequirement]: The resolved security property. + """ + if self._resolved_security is Empty: + self._resolved_security = [] + for layer in self.ownership_layers: + if isinstance(layer.security, Sequence): + self._resolved_security.extend(layer.security) + + return self._resolved_security + + def resolve_tags(self) -> list[str]: + """Resolve the tags property by starting from the route handler and moving up. + + Tags are additive, so the tags of the route handler are the sum of all tags of the ownership layers. + + Returns: + list[str]: A sorted list of unique tags. + """ + if self._resolved_tags is Empty: + tag_set = set() + for layer in self.ownership_layers: + for tag in layer.tags or []: + tag_set.add(tag) + self._resolved_tags = sorted(tag_set) + + return self._resolved_tags + + def get_response_handler(self, is_response_type_data: bool = False) -> Callable[[Any], Awaitable[ASGIApp]]: + """Resolve the response_handler function for the route handler. + + This method is memoized so the computation occurs only once. + + Args: + is_response_type_data: Whether to return a handler for 'Response' instances. + + Returns: + Async Callable to handle an HTTP Request + """ + if self._response_handler_mapping["default_handler"] is Empty: + after_request_handlers: list[AsyncAnyCallable] = [ + layer.after_request # type: ignore[misc] + for layer in self.ownership_layers + if layer.after_request + ] + after_request = cast( + "AfterRequestHookHandler | None", + after_request_handlers[-1] if after_request_handlers else None, + ) + + media_type = self.media_type.value if isinstance(self.media_type, Enum) else self.media_type + response_class = self.resolve_response_class() + headers = self.resolve_response_headers() + cookies = self.resolve_response_cookies() + type_encoders = self.resolve_type_encoders() + + return_type = self.parsed_fn_signature.return_type + return_annotation = return_type.annotation + + self._response_handler_mapping["response_type_handler"] = response_type_handler = create_response_handler( + after_request=after_request, + background=self.background, + cookies=cookies, + headers=headers, + media_type=media_type, + status_code=self.status_code, + type_encoders=type_encoders, + ) + + if return_type.is_subclass_of(Response): + self._response_handler_mapping["default_handler"] = response_type_handler + elif is_async_callable(return_annotation) or return_annotation is ASGIApp: + self._response_handler_mapping["default_handler"] = create_generic_asgi_response_handler( + after_request=after_request + ) + else: + self._response_handler_mapping["default_handler"] = create_data_handler( + after_request=after_request, + background=self.background, + cookies=cookies, + headers=headers, + media_type=media_type, + response_class=response_class, + status_code=self.status_code, + type_encoders=type_encoders, + ) + + return cast( + "Callable[[Any], Awaitable[ASGIApp]]", + self._response_handler_mapping["response_type_handler"] + if is_response_type_data + else self._response_handler_mapping["default_handler"], + ) + + async def to_response(self, app: Litestar, data: Any, request: Request) -> ASGIApp: + """Return a :class:`Response <.response.Response>` from the handler by resolving and calling it. + + Args: + app: The :class:`Litestar <litestar.app.Litestar>` app instance + data: Either an instance of a :class:`Response <.response.Response>`, + a Response instance or an arbitrary value. + request: A :class:`Request <.connection.Request>` instance + + Returns: + A Response instance + """ + if return_dto_type := self.resolve_return_dto(): + data = return_dto_type(request).data_to_encodable_type(data) + + response_handler = self.get_response_handler(is_response_type_data=isinstance(data, Response)) + return await response_handler(app=app, data=data, request=request) # type: ignore[call-arg] + + def on_registration(self, app: Litestar) -> None: + super().on_registration(app) + self.resolve_after_response() + self.resolve_include_in_schema() + self.has_sync_callable = not is_async_callable(self.fn) + + if self.has_sync_callable and self.sync_to_thread: + self._fn = ensure_async_callable(self.fn) + self.has_sync_callable = False + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it is set by inspecting its return annotations.""" + super()._validate_handler_function() + + return_type = self.parsed_fn_signature.return_type + + if return_type.annotation is Empty: + raise ImproperlyConfiguredException( + "A return value of a route handler function should be type annotated. " + "If your function doesn't return a value, annotate it as returning 'None'." + ) + + if ( + self.status_code < 200 or self.status_code in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} + ) and not is_empty_response_annotation(return_type): + raise ImproperlyConfiguredException( + "A status code 204, 304 or in the range below 200 does not support a response body. " + "If the function should return a value, change the route handler status code to an appropriate value.", + ) + + if not self.media_type: + if return_type.is_subclass_of((str, bytes)) or return_type.annotation is AnyStr: + self.media_type = MediaType.TEXT + elif not return_type.is_subclass_of(Response): + self.media_type = MediaType.JSON + + if "socket" in self.parsed_fn_signature.parameters: + raise ImproperlyConfiguredException("The 'socket' kwarg is not supported with http handlers") + + if "data" in self.parsed_fn_signature.parameters and "GET" in self.http_methods: + raise ImproperlyConfiguredException("'data' kwarg is unsupported for 'GET' request handlers") + + +route = HTTPRouteHandler diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/decorators.py b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/decorators.py new file mode 100644 index 0000000..1ae72e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/http_handlers/decorators.py @@ -0,0 +1,1096 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.enums import HttpMethod, MediaType +from litestar.exceptions import HTTPException, ImproperlyConfiguredException +from litestar.openapi.spec import Operation +from litestar.response.file import ASGIFileResponse, File +from litestar.types import Empty, TypeDecodersSequence +from litestar.types.builtin_types import NoneType +from litestar.utils import is_class_and_subclass + +from .base import HTTPRouteHandler + +if TYPE_CHECKING: + from typing import Any, Mapping, Sequence + + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.config.response_cache import CACHE_FOREVER + from litestar.connection import Request + from litestar.datastructures import CacheControlHeader, ETag + from litestar.dto import AbstractDTO + from litestar.openapi.datastructures import ResponseSpec + from litestar.openapi.spec import SecurityRequirement + from litestar.response import Response + from litestar.types import ( + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeRequestHookHandler, + CacheKeyBuilder, + Dependencies, + EmptyType, + ExceptionHandlersMap, + Guard, + Middleware, + ResponseCookies, + ResponseHeaders, + TypeEncodersMap, + ) + from litestar.types.callable_types import OperationIDCreator + + +__all__ = ("get", "head", "post", "put", "patch", "delete") + +MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP = "semantic route handlers cannot define http_method" + + +class delete(HTTPRouteHandler): + """DELETE Route Decorator. + + Use this decorator to decorate an HTTP handler for DELETE requests. + """ + + def __init__( + self, + path: str | None | Sequence[str] = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``delete`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number + of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <litestar.enums.HttpMethod>` or a list of these that correlates to the methods the + route handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a + valid IANA Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` + and ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``base64``. + content_media_type: A string designating the media-type of the content, e.g. ``image/png``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if "http_method" in kwargs: + raise ImproperlyConfiguredException(MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP) + super().__init__( + after_request=after_request, + after_response=after_response, + background=background, + before_request=before_request, + cache=cache, + cache_control=cache_control, + cache_key_builder=cache_key_builder, + content_encoding=content_encoding, + content_media_type=content_media_type, + dependencies=dependencies, + deprecated=deprecated, + description=description, + dto=dto, + etag=etag, + exception_handlers=exception_handlers, + guards=guards, + http_method=HttpMethod.DELETE, + include_in_schema=include_in_schema, + media_type=media_type, + middleware=middleware, + name=name, + operation_class=operation_class, + operation_id=operation_id, + opt=opt, + path=path, + raises=raises, + request_class=request_class, + response_class=response_class, + response_cookies=response_cookies, + response_description=response_description, + response_headers=response_headers, + responses=responses, + return_dto=return_dto, + security=security, + signature_namespace=signature_namespace, + status_code=status_code, + summary=summary, + sync_to_thread=sync_to_thread, + tags=tags, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) + + +class get(HTTPRouteHandler): + """GET Route Decorator. + + Use this decorator to decorate an HTTP handler for GET requests. + """ + + def __init__( + self, + path: str | None | Sequence[str] = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``get``. + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number + of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <litestar.enums.HttpMethod>` or a list of these that correlates to the methods the + route handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a + valid IANA Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` and + ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``base64``. + content_media_type: A string designating the media-type of the content, e.g. ``image/png``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if "http_method" in kwargs: + raise ImproperlyConfiguredException(MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP) + + super().__init__( + after_request=after_request, + after_response=after_response, + background=background, + before_request=before_request, + cache=cache, + cache_control=cache_control, + cache_key_builder=cache_key_builder, + content_encoding=content_encoding, + content_media_type=content_media_type, + dependencies=dependencies, + deprecated=deprecated, + description=description, + dto=dto, + etag=etag, + exception_handlers=exception_handlers, + guards=guards, + http_method=HttpMethod.GET, + include_in_schema=include_in_schema, + media_type=media_type, + middleware=middleware, + name=name, + operation_class=operation_class, + operation_id=operation_id, + opt=opt, + path=path, + raises=raises, + request_class=request_class, + response_class=response_class, + response_cookies=response_cookies, + response_description=response_description, + response_headers=response_headers, + responses=responses, + return_dto=return_dto, + security=security, + signature_namespace=signature_namespace, + status_code=status_code, + summary=summary, + sync_to_thread=sync_to_thread, + tags=tags, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) + + +class head(HTTPRouteHandler): + """HEAD Route Decorator. + + Use this decorator to decorate an HTTP handler for HEAD requests. + """ + + def __init__( + self, + path: str | None | Sequence[str] = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + signature_namespace: Mapping[str, Any] | None = None, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``head``. + + Notes: + - A response to a head request cannot include a body. + See: [MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/HEAD). + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number + of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <litestar.enums.HttpMethod>` or a list of these that correlates to the methods the + route handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a + valid IANA Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` and + ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``base64``. + content_media_type: A string designating the media-type of the content, e.g. ``image/png``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if "http_method" in kwargs: + raise ImproperlyConfiguredException(MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP) + + super().__init__( + after_request=after_request, + after_response=after_response, + background=background, + before_request=before_request, + cache=cache, + cache_control=cache_control, + cache_key_builder=cache_key_builder, + content_encoding=content_encoding, + content_media_type=content_media_type, + dependencies=dependencies, + deprecated=deprecated, + description=description, + dto=dto, + etag=etag, + exception_handlers=exception_handlers, + guards=guards, + http_method=HttpMethod.HEAD, + include_in_schema=include_in_schema, + media_type=media_type, + middleware=middleware, + name=name, + operation_class=operation_class, + operation_id=operation_id, + opt=opt, + path=path, + raises=raises, + request_class=request_class, + response_class=response_class, + response_cookies=response_cookies, + response_description=response_description, + response_headers=response_headers, + responses=responses, + return_dto=return_dto, + security=security, + signature_namespace=signature_namespace, + status_code=status_code, + summary=summary, + sync_to_thread=sync_to_thread, + tags=tags, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it is set by inspecting its return annotations.""" + super()._validate_handler_function() + + # we allow here File and File because these have special setting for head responses + return_annotation = self.parsed_fn_signature.return_type.annotation + if not ( + return_annotation in {NoneType, None} + or is_class_and_subclass(return_annotation, File) + or is_class_and_subclass(return_annotation, ASGIFileResponse) + ): + raise ImproperlyConfiguredException("A response to a head request should not have a body") + + +class patch(HTTPRouteHandler): + """PATCH Route Decorator. + + Use this decorator to decorate an HTTP handler for PATCH requests. + """ + + def __init__( + self, + path: str | None | Sequence[str] = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``patch``. + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number + of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <litestar.enums.HttpMethod>` or a list of these that correlates to the methods the + route handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a + valid IANA Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` and + ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``base64``. + content_media_type: A string designating the media-type of the content, e.g. ``image/png``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if "http_method" in kwargs: + raise ImproperlyConfiguredException(MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP) + super().__init__( + after_request=after_request, + after_response=after_response, + background=background, + before_request=before_request, + cache=cache, + cache_control=cache_control, + cache_key_builder=cache_key_builder, + content_encoding=content_encoding, + content_media_type=content_media_type, + dependencies=dependencies, + deprecated=deprecated, + description=description, + dto=dto, + etag=etag, + exception_handlers=exception_handlers, + guards=guards, + http_method=HttpMethod.PATCH, + include_in_schema=include_in_schema, + media_type=media_type, + middleware=middleware, + name=name, + operation_class=operation_class, + operation_id=operation_id, + opt=opt, + path=path, + raises=raises, + request_class=request_class, + response_class=response_class, + response_cookies=response_cookies, + response_description=response_description, + response_headers=response_headers, + responses=responses, + return_dto=return_dto, + security=security, + signature_namespace=signature_namespace, + status_code=status_code, + summary=summary, + sync_to_thread=sync_to_thread, + tags=tags, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) + + +class post(HTTPRouteHandler): + """POST Route Decorator. + + Use this decorator to decorate an HTTP handler for POST requests. + """ + + def __init__( + self, + path: str | None | Sequence[str] = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``post`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number + of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <litestar.enums.HttpMethod>` or a list of these that correlates to the methods the + route handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a + valid IANA Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` and + ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``base64``. + content_media_type: A string designating the media-type of the content, e.g. ``image/png``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if "http_method" in kwargs: + raise ImproperlyConfiguredException(MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP) + super().__init__( + after_request=after_request, + after_response=after_response, + background=background, + before_request=before_request, + cache=cache, + cache_control=cache_control, + cache_key_builder=cache_key_builder, + content_encoding=content_encoding, + content_media_type=content_media_type, + dependencies=dependencies, + deprecated=deprecated, + description=description, + dto=dto, + exception_handlers=exception_handlers, + etag=etag, + guards=guards, + http_method=HttpMethod.POST, + include_in_schema=include_in_schema, + media_type=media_type, + middleware=middleware, + name=name, + operation_class=operation_class, + operation_id=operation_id, + opt=opt, + path=path, + raises=raises, + request_class=request_class, + response_class=response_class, + response_cookies=response_cookies, + response_description=response_description, + response_headers=response_headers, + responses=responses, + return_dto=return_dto, + signature_namespace=signature_namespace, + security=security, + status_code=status_code, + summary=summary, + sync_to_thread=sync_to_thread, + tags=tags, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) + + +class put(HTTPRouteHandler): + """PUT Route Decorator. + + Use this decorator to decorate an HTTP handler for PUT requests. + """ + + def __init__( + self, + path: str | None | Sequence[str] = None, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache: bool | int | type[CACHE_FOREVER] = False, + cache_control: CacheControlHeader | None = None, + cache_key_builder: CacheKeyBuilder | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + media_type: MediaType | str | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + status_code: int | None = None, + sync_to_thread: bool | None = None, + # OpenAPI related attributes + content_encoding: str | None = None, + content_media_type: str | None = None, + deprecated: bool = False, + description: str | None = None, + include_in_schema: bool | EmptyType = Empty, + operation_class: type[Operation] = Operation, + operation_id: str | OperationIDCreator | None = None, + raises: Sequence[type[HTTPException]] | None = None, + response_description: str | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, + summary: str | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``put`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. + If not given defaults to ``/`` + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, + bypassing the route handler. + cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number + of seconds (e.g. ``120``) to cache the response. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` that will be added to the response. + cache_key_builder: A :class:`cache-key builder function <.types.CacheKeyBuilder>`. Allows for customization + of the cache key if caching is configured on the application level. + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` that will be added to the response. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + http_method: An :class:`http method string <.types.Method>`, a member of the enum + :class:`HttpMethod <litestar.enums.HttpMethod>` or a list of these that correlates to the methods the + route handler function should handle. + media_type: A member of the :class:`MediaType <.enums.MediaType>` enum or a string with a + valid IANA Media-Type. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's + default response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + responses: A mapping of additional status codes and a description of their expected content. + This information will be included in the OpenAPI schema + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + status_code: An http status code for the response. Defaults to ``200`` for mixed method or ``GET``, ``PUT`` and + ``PATCH``, ``201`` for ``POST`` and ``204`` for ``DELETE``. + sync_to_thread: A boolean dictating whether the handler function will be executed in a worker thread or the + main event loop. This has an effect only for sync handler functions. See using sync handler functions. + content_encoding: A string describing the encoding of the content, e.g. ``base64``. + content_media_type: A string designating the media-type of the content, e.g. ``image/png``. + deprecated: A boolean dictating whether this route should be marked as deprecated in the OpenAPI schema. + description: Text used for the route's schema description section. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + operation_class: :class:`Operation <.openapi.spec.operation.Operation>` to be used with the route's OpenAPI schema. + operation_id: Either a string or a callable returning a string. An identifier used for the route's schema operationId. + raises: A list of exception classes extending from litestar.HttpException that is used for the OpenAPI documentation. + This list should describe all exceptions raised within the route handler's function/method. The Litestar + ValidationException will be added automatically for the schema if any validation is involved. + response_description: Text used for the route's response schema description section. + security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. + summary: Text used for the route's schema summary section. + tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + if "http_method" in kwargs: + raise ImproperlyConfiguredException(MSG_SEMANTIC_ROUTE_HANDLER_WITH_HTTP) + super().__init__( + after_request=after_request, + after_response=after_response, + background=background, + before_request=before_request, + cache=cache, + cache_control=cache_control, + cache_key_builder=cache_key_builder, + content_encoding=content_encoding, + content_media_type=content_media_type, + dependencies=dependencies, + deprecated=deprecated, + description=description, + dto=dto, + exception_handlers=exception_handlers, + etag=etag, + guards=guards, + http_method=HttpMethod.PUT, + include_in_schema=include_in_schema, + media_type=media_type, + middleware=middleware, + name=name, + operation_class=operation_class, + operation_id=operation_id, + opt=opt, + path=path, + raises=raises, + request_class=request_class, + response_class=response_class, + response_cookies=response_cookies, + response_description=response_description, + response_headers=response_headers, + responses=responses, + return_dto=return_dto, + security=security, + signature_namespace=signature_namespace, + status_code=status_code, + summary=summary, + sync_to_thread=sync_to_thread, + tags=tags, + type_decoders=type_decoders, + type_encoders=type_encoders, + **kwargs, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__init__.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__init__.py new file mode 100644 index 0000000..5b24734 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from litestar.handlers.websocket_handlers.listener import ( + WebsocketListener, + WebsocketListenerRouteHandler, + websocket_listener, +) +from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler, websocket + +__all__ = ( + "WebsocketListener", + "WebsocketListenerRouteHandler", + "WebsocketRouteHandler", + "websocket", + "websocket_listener", +) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f6d1115 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c5ae4c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/listener.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/listener.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..38b8219 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/listener.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/route_handler.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/route_handler.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0e92ccd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/route_handler.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/_utils.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/_utils.py new file mode 100644 index 0000000..bcd90ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/_utils.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from functools import wraps +from inspect import Parameter, Signature +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict + +from msgspec.json import Encoder as JsonEncoder + +from litestar.di import Provide +from litestar.serialization import decode_json +from litestar.types.builtin_types import NoneType +from litestar.utils import ensure_async_callable +from litestar.utils.helpers import unwrap_partial + +if TYPE_CHECKING: + from litestar import WebSocket + from litestar.handlers.websocket_handlers.listener import WebsocketListenerRouteHandler + from litestar.types import AnyCallable + from litestar.utils.signature import ParsedSignature + + +def create_handle_receive(listener: WebsocketListenerRouteHandler) -> Callable[[WebSocket], Coroutine[Any, None, None]]: + if data_dto := listener.resolve_data_dto(): + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return data_dto(socket).decode_bytes( + received_data.encode("utf-8") if isinstance(received_data, str) else received_data + ) + + elif listener.parsed_data_field and listener.parsed_data_field.annotation is str: + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return received_data.decode("utf-8") if isinstance(received_data, bytes) else received_data + + elif listener.parsed_data_field and listener.parsed_data_field.annotation is bytes: + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return received_data.encode("utf-8") if isinstance(received_data, str) else received_data + + else: + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return decode_json(value=received_data, type_decoders=socket.route_handler.resolve_type_decoders()) + + return handle_receive + + +def create_handle_send( + listener: WebsocketListenerRouteHandler, +) -> Callable[[WebSocket, Any], Coroutine[None, None, None]]: + json_encoder = JsonEncoder(enc_hook=listener.default_serializer) + + if return_dto := listener.resolve_return_dto(): + + async def handle_send(socket: WebSocket, data: Any) -> None: + encoded_data = return_dto(socket).data_to_encodable_type(data) + data = json_encoder.encode(encoded_data) + await socket.send_data(data=data, mode=listener._send_mode) + + elif listener.parsed_return_field.is_subclass_of((str, bytes)) or ( + listener.parsed_return_field.is_optional and listener.parsed_return_field.has_inner_subclass_of((str, bytes)) + ): + + async def handle_send(socket: WebSocket, data: Any) -> None: + await socket.send_data(data=data, mode=listener._send_mode) + + else: + + async def handle_send(socket: WebSocket, data: Any) -> None: + data = json_encoder.encode(data) + await socket.send_data(data=data, mode=listener._send_mode) + + return handle_send + + +class ListenerHandler: + __slots__ = ("_can_send_data", "_fn", "_listener", "_pass_socket") + + def __init__( + self, + listener: WebsocketListenerRouteHandler, + fn: AnyCallable, + parsed_signature: ParsedSignature, + namespace: dict[str, Any], + ) -> None: + self._can_send_data = not parsed_signature.return_type.is_subclass_of(NoneType) + self._fn = ensure_async_callable(fn) + self._listener = listener + self._pass_socket = "socket" in parsed_signature.parameters + + async def __call__( + self, + *args: Any, + socket: WebSocket, + connection_lifespan_dependencies: Dict[str, Any], # noqa: UP006 + **kwargs: Any, + ) -> None: + lifespan_mananger = self._listener._connection_lifespan or self._listener.default_connection_lifespan + handle_send = self._listener.resolve_send_handler() if self._can_send_data else None + handle_receive = self._listener.resolve_receive_handler() + + if self._pass_socket: + kwargs["socket"] = socket + + async with lifespan_mananger(**connection_lifespan_dependencies): + while True: + received_data = await handle_receive(socket) + data = await self._fn(*args, data=received_data, **kwargs) + if handle_send: + await handle_send(socket, data) + + +def create_handler_signature(callback_signature: Signature) -> Signature: + """Creates a :class:`Signature` for the handler function for signature modelling. + + This is required for two reasons: + + 1. the :class:`.handlers.WebsocketHandler` signature model cannot contain the ``data`` parameter, which is + required for :class:`.handlers.websocket_listener` handlers. + 2. the :class;`.handlers.WebsocketHandler` signature model must include the ``socket`` parameter, which is + optional for :class:`.handlers.websocket_listener` handlers. + + Args: + callback_signature: The :class:`Signature` of the listener callback. + + Returns: + The :class:`Signature` for the listener callback as required for signature modelling. + """ + new_params = [p for p in callback_signature.parameters.values() if p.name != "data"] + if "socket" not in callback_signature.parameters: + new_params.append(Parameter(name="socket", kind=Parameter.KEYWORD_ONLY, annotation="WebSocket")) + + new_params.append( + Parameter(name="connection_lifespan_dependencies", kind=Parameter.KEYWORD_ONLY, annotation="Dict[str, Any]") + ) + + return callback_signature.replace(parameters=new_params) + + +def create_stub_dependency(src: AnyCallable) -> Provide: + """Create a stub dependency, accepting any kwargs defined in ``src``, and + wrap it in ``Provide`` + """ + src = unwrap_partial(src) + + @wraps(src) + async def stub(**kwargs: Any) -> Dict[str, Any]: # noqa: UP006 + return kwargs + + return Provide(stub) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/listener.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/listener.py new file mode 100644 index 0000000..86fefc9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/listener.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Mapping, + Optional, + cast, + overload, +) + +from litestar._signature import SignatureModel +from litestar.connection import WebSocket +from litestar.exceptions import ImproperlyConfiguredException, WebSocketDisconnect +from litestar.types import ( + AnyCallable, + Dependencies, + Empty, + EmptyType, + ExceptionHandler, + Guard, + Middleware, + TypeEncodersMap, +) +from litestar.utils import ensure_async_callable +from litestar.utils.signature import ParsedSignature, get_fn_type_hints + +from ._utils import ( + ListenerHandler, + create_handle_receive, + create_handle_send, + create_handler_signature, + create_stub_dependency, +) +from .route_handler import WebsocketRouteHandler + +if TYPE_CHECKING: + from typing import Coroutine + + from typing_extensions import Self + + from litestar import Router + from litestar.dto import AbstractDTO + from litestar.types.asgi_types import WebSocketMode + from litestar.types.composite_types import TypeDecodersSequence + +__all__ = ("WebsocketListener", "WebsocketListenerRouteHandler", "websocket_listener") + + +class WebsocketListenerRouteHandler(WebsocketRouteHandler): + """A websocket listener that automatically accepts a connection, handles disconnects, + invokes a callback function every time new data is received and sends any data + returned + """ + + __slots__ = { + "connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept", + "on_accept": "Callback invoked after a WebSocket connection has been accepted", + "on_disconnect": "Callback invoked after a WebSocket connection has been closed", + "weboscket_class": "WebSocket class", + "_connection_lifespan": None, + "_handle_receive": None, + "_handle_send": None, + "_receive_mode": None, + "_send_mode": None, + } + + @overload + def __init__( + self, + path: str | list[str] | None = None, + *, + connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + receive_mode: WebSocketMode = "text", + send_mode: WebSocketMode = "text", + name: str | None = None, + opt: dict[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: ... + + @overload + def __init__( + self, + path: str | list[str] | None = None, + *, + connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + receive_mode: WebSocketMode = "text", + send_mode: WebSocketMode = "text", + name: str | None = None, + on_accept: AnyCallable | None = None, + on_disconnect: AnyCallable | None = None, + opt: dict[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: ... + + def __init__( + self, + path: str | list[str] | None = None, + *, + connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept, + connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + receive_mode: WebSocketMode = "text", + send_mode: WebSocketMode = "text", + name: str | None = None, + on_accept: AnyCallable | None = None, + on_disconnect: AnyCallable | None = None, + opt: dict[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``WebsocketRouteHandler`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + connection_accept_handler: A callable that accepts a :class:`WebSocket <.connection.WebSocket>` instance + and returns a coroutine that when awaited, will accept the connection. Defaults to ``WebSocket.accept``. + connection_lifespan: An asynchronous context manager, handling the lifespan of the connection. By default, + it calls the ``connection_accept_handler``, ``on_connect`` and ``on_disconnect``. Can request any + dependencies, for example the :class:`WebSocket <.connection.WebSocket>` connection + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + receive_mode: Websocket mode to receive data in, either `text` or `binary`. + send_mode: Websocket mode to receive data in, either `text` or `binary`. + name: A string identifying the route handler. + on_accept: Callback invoked after a connection has been accepted. Can request any dependencies, for example + the :class:`WebSocket <.connection.WebSocket>` connection + on_disconnect: Callback invoked after a connection has been closed. Can request any dependencies, for + example the :class:`WebSocket <.connection.WebSocket>` connection + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature + modelling. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ + if connection_lifespan and any([on_accept, on_disconnect, connection_accept_handler is not WebSocket.accept]): + raise ImproperlyConfiguredException( + "connection_lifespan can not be used with connection hooks " + "(on_accept, on_disconnect, connection_accept_handler)", + ) + + self._receive_mode: WebSocketMode = receive_mode + self._send_mode: WebSocketMode = send_mode + self._connection_lifespan = connection_lifespan + self._send_handler: Callable[[WebSocket, Any], Coroutine[None, None, None]] | EmptyType = Empty + self._receive_handler: Callable[[WebSocket], Any] | EmptyType = Empty + + self.connection_accept_handler = connection_accept_handler + self.on_accept = ensure_async_callable(on_accept) if on_accept else None + self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None + self.type_decoders = type_decoders + self.type_encoders = type_encoders + self.websocket_class = websocket_class + + listener_dependencies = dict(dependencies or {}) + + listener_dependencies["connection_lifespan_dependencies"] = create_stub_dependency( + connection_lifespan or self.default_connection_lifespan + ) + + if self.on_accept: + listener_dependencies["on_accept_dependencies"] = create_stub_dependency(self.on_accept) + + if self.on_disconnect: + listener_dependencies["on_disconnect_dependencies"] = create_stub_dependency(self.on_disconnect) + + super().__init__( + path=path, + dependencies=listener_dependencies, + exception_handlers=exception_handlers, + guards=guards, + middleware=middleware, + name=name, + opt=opt, + signature_namespace=signature_namespace, + dto=dto, + return_dto=return_dto, + type_decoders=type_decoders, + type_encoders=type_encoders, + websocket_class=websocket_class, + **kwargs, + ) + + def __call__(self, fn: AnyCallable) -> Self: + parsed_signature = ParsedSignature.from_fn(fn, self.resolve_signature_namespace()) + + if "data" not in parsed_signature.parameters: + raise ImproperlyConfiguredException("Websocket listeners must accept a 'data' parameter") + + for param in ("request", "body"): + if param in parsed_signature.parameters: + raise ImproperlyConfiguredException(f"The {param} kwarg is not supported with websocket listeners") + + # we are manipulating the signature of the decorated function below, so we must store the original values for + # use elsewhere. + self._parsed_return_field = parsed_signature.return_type + self._parsed_data_field = parsed_signature.parameters.get("data") + self._parsed_fn_signature = ParsedSignature.from_signature( + create_handler_signature(parsed_signature.original_signature), + fn_type_hints={ + **get_fn_type_hints(fn, namespace=self.resolve_signature_namespace()), + **get_fn_type_hints(ListenerHandler.__call__, namespace=self.resolve_signature_namespace()), + }, + ) + + return super().__call__( + ListenerHandler( + listener=self, fn=fn, parsed_signature=parsed_signature, namespace=self.resolve_signature_namespace() + ) + ) + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it's set by inspecting its return annotations.""" + # validation occurs in the call method + + @property + def signature_model(self) -> type[SignatureModel]: + """Get the signature model for the route handler. + + Returns: + A signature model for the route handler. + + """ + if self._signature_model is Empty: + self._signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=cast("AnyCallable", self.fn), + parsed_signature=self.parsed_fn_signature, + type_decoders=self.resolve_type_decoders(), + ) + return self._signature_model + + @asynccontextmanager + async def default_connection_lifespan( + self, + socket: WebSocket, + on_accept_dependencies: Optional[Dict[str, Any]] = None, # noqa: UP006, UP007 + on_disconnect_dependencies: Optional[Dict[str, Any]] = None, # noqa: UP006, UP007 + ) -> AsyncGenerator[None, None]: + """Handle the connection lifespan of a :class:`WebSocket <.connection.WebSocket>`. + + Args: + socket: The :class:`WebSocket <.connection.WebSocket>` connection + on_accept_dependencies: Dependencies requested by the :attr:`on_accept` hook + on_disconnect_dependencies: Dependencies requested by the :attr:`on_disconnect` hook + + By, default this will + + - Call :attr:`connection_accept_handler` to accept a connection + - Call :attr:`on_accept` if defined after a connection has been accepted + - Call :attr:`on_disconnect` upon leaving the context + """ + await self.connection_accept_handler(socket) + + if self.on_accept: + await self.on_accept(**(on_accept_dependencies or {})) + + try: + yield + except WebSocketDisconnect: + pass + finally: + if self.on_disconnect: + await self.on_disconnect(**(on_disconnect_dependencies or {})) + + def resolve_receive_handler(self) -> Callable[[WebSocket], Any]: + if self._receive_handler is Empty: + self._receive_handler = create_handle_receive(self) + return self._receive_handler + + def resolve_send_handler(self) -> Callable[[WebSocket, Any], Coroutine[None, None, None]]: + if self._send_handler is Empty: + self._send_handler = create_handle_send(self) + return self._send_handler + + +websocket_listener = WebsocketListenerRouteHandler + + +class WebsocketListener(ABC): + path: str | list[str] | None = None + """A path fragment for the route handler function or a sequence of path fragments. If not given defaults to ``/``""" + dependencies: Dependencies | None = None + """A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances.""" + dto: type[AbstractDTO] | None | EmptyType = Empty + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and validation of request data""" + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None + """A mapping of status codes and/or exception types to handler functions.""" + guards: list[Guard] | None = None + """A sequence of :class:`Guard <.types.Guard>` callables.""" + middleware: list[Middleware] | None = None + """A sequence of :class:`Middleware <.types.Middleware>`.""" + on_accept: AnyCallable | None = None + """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been accepted. Can receive any dependencies""" + on_disconnect: AnyCallable | None = None + """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been disconnected. Can receive any dependencies""" + receive_mode: WebSocketMode = "text" + """:class:`WebSocket <.connection.WebSocket>` mode to receive data in, either ``text`` or ``binary``.""" + send_mode: WebSocketMode = "text" + """Websocket mode to send data in, either `text` or `binary`.""" + name: str | None = None + """A string identifying the route handler.""" + opt: dict[str, Any] | None = None + """ + A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you + have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + """ + return_dto: type[AbstractDTO] | None | EmptyType = Empty + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data.""" + signature_namespace: Mapping[str, Any] | None = None + """ + A mapping of names to types for use in forward reference resolution during signature modelling. + """ + type_decoders: TypeDecodersSequence | None = None + """ + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + """ + type_encoders: TypeEncodersMap | None = None + """ + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + """ + websocket_class: type[WebSocket] | None = None + """ + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ + + def __init__(self, owner: Router) -> None: + """Initialize a WebsocketListener instance. + + Args: + owner: The :class:`Router <.router.Router>` instance that owns this listener. + """ + self._owner = owner + + def to_handler(self) -> WebsocketListenerRouteHandler: + handler = WebsocketListenerRouteHandler( + dependencies=self.dependencies, + dto=self.dto, + exception_handlers=self.exception_handlers, + guards=self.guards, + middleware=self.middleware, + send_mode=self.send_mode, + receive_mode=self.receive_mode, + name=self.name, + on_accept=self.on_accept, + on_disconnect=self.on_disconnect, + opt=self.opt, + path=self.path, + return_dto=self.return_dto, + signature_namespace=self.signature_namespace, + type_decoders=self.type_decoders, + type_encoders=self.type_encoders, + websocket_class=self.websocket_class, + )(self.on_receive) + handler.owner = self._owner + return handler + + @abstractmethod + def on_receive(self, *args: Any, **kwargs: Any) -> Any: + """Called after data has been received from the WebSocket. + + This should take a ``data`` argument, receiving the processed WebSocket data, + and can additionally include handler dependencies such as ``state``, or other + regular dependencies. + + Data returned from this function will be serialized and sent via the socket + according to handler configuration. + """ + raise NotImplementedError diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/route_handler.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/route_handler.py new file mode 100644 index 0000000..edb49c3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/route_handler.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping + +from litestar.connection import WebSocket +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers import BaseRouteHandler +from litestar.types.builtin_types import NoneType +from litestar.utils.predicates import is_async_callable + +if TYPE_CHECKING: + from litestar.types import Dependencies, ExceptionHandler, Guard, Middleware + + +class WebsocketRouteHandler(BaseRouteHandler): + """Websocket route handler decorator. + + Use this decorator to decorate websocket handler functions. + """ + + def __init__( + self, + path: str | list[str] | None = None, + *, + dependencies: Dependencies | None = None, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + name: str | None = None, + opt: dict[str, Any] | None = None, + signature_namespace: Mapping[str, Any] | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``WebsocketRouteHandler`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ + self.websocket_class = websocket_class + + super().__init__( + path=path, + dependencies=dependencies, + exception_handlers=exception_handlers, + guards=guards, + middleware=middleware, + name=name, + opt=opt, + signature_namespace=signature_namespace, + **kwargs, + ) + + def resolve_websocket_class(self) -> type[WebSocket]: + """Return the closest custom WebSocket class in the owner graph or the default Websocket class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`WebSocket <.connection.WebSocket>` class for the route handler. + """ + return next( + (layer.websocket_class for layer in reversed(self.ownership_layers) if layer.websocket_class is not None), + WebSocket, + ) + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it's set by inspecting its return annotations.""" + super()._validate_handler_function() + + if not self.parsed_fn_signature.return_type.is_subclass_of(NoneType): + raise ImproperlyConfiguredException("Websocket handler functions should return 'None'") + + if "socket" not in self.parsed_fn_signature.parameters: + raise ImproperlyConfiguredException("Websocket handlers must set a 'socket' kwarg") + + for param in ("request", "body", "data"): + if param in self.parsed_fn_signature.parameters: + raise ImproperlyConfiguredException(f"The {param} kwarg is not supported with websocket handlers") + + if not is_async_callable(self.fn): + raise ImproperlyConfiguredException("Functions decorated with 'websocket' must be async functions") + + +websocket = WebsocketRouteHandler diff --git a/venv/lib/python3.11/site-packages/litestar/logging/__init__.py b/venv/lib/python3.11/site-packages/litestar/logging/__init__.py new file mode 100644 index 0000000..b05ceba --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/__init__.py @@ -0,0 +1,3 @@ +from .config import BaseLoggingConfig, LoggingConfig, StructLoggingConfig + +__all__ = ("BaseLoggingConfig", "StructLoggingConfig", "LoggingConfig") diff --git a/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6a830eb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ba06e72 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f1c97d4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/picologging.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/picologging.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1488885 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/picologging.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/standard.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/standard.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ceda61d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/__pycache__/standard.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/logging/_utils.py b/venv/lib/python3.11/site-packages/litestar/logging/_utils.py new file mode 100644 index 0000000..ee67b9f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/_utils.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import Any + +__all__ = ("resolve_handlers",) + + +def resolve_handlers(handlers: list[Any]) -> list[Any]: + """Convert list of string of handlers to the object of respective handler. + + Indexing the list performs the evaluation of the object. + + Args: + handlers: An instance of 'ConvertingList' + + Returns: + A list of resolved handlers. + + Notes: + Due to missing typing in 'typeshed' we cannot type this as ConvertingList for now. + """ + return [handlers[i] for i in range(len(handlers))] diff --git a/venv/lib/python3.11/site-packages/litestar/logging/config.py b/venv/lib/python3.11/site-packages/litestar/logging/config.py new file mode 100644 index 0000000..a4a3713 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/config.py @@ -0,0 +1,509 @@ +from __future__ import annotations + +import sys +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field, fields +from importlib.util import find_spec +from logging import INFO +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +from litestar.exceptions import ImproperlyConfiguredException, MissingDependencyException +from litestar.serialization.msgspec_hooks import _msgspec_json_encoder +from litestar.utils.deprecation import deprecated + +__all__ = ("BaseLoggingConfig", "LoggingConfig", "StructLoggingConfig") + + +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import NoReturn + + # these imports are duplicated on purpose so sphinx autodoc can find and link them + from structlog.types import BindableLogger, Processor, WrappedLogger + from structlog.typing import EventDict + + from litestar.types import Logger, Scope + from litestar.types.callable_types import ExceptionLoggingHandler, GetLogger + + +try: + from structlog.types import BindableLogger, Processor, WrappedLogger +except ImportError: + BindableLogger = Any # type: ignore[assignment, misc] + Processor = Any # type: ignore[misc] + WrappedLogger = Any # type: ignore[misc] + + +default_handlers: dict[str, dict[str, Any]] = { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "standard", + }, + "queue_listener": { + "class": "litestar.logging.standard.QueueListenerHandler", + "level": "DEBUG", + "formatter": "standard", + }, +} + +if sys.version_info >= (3, 12, 0): + default_handlers["queue_listener"].update( + { + "class": "logging.handlers.QueueHandler", + "queue": { + "()": "queue.Queue", + "maxsize": -1, + }, + "listener": "litestar.logging.standard.LoggingQueueListener", + "handlers": ["console"], + } + ) + + # do not format twice, the console handler will do the job + del default_handlers["queue_listener"]["formatter"] + + +default_picologging_handlers: dict[str, dict[str, Any]] = { + "console": { + "class": "picologging.StreamHandler", + "level": "DEBUG", + "formatter": "standard", + }, + "queue_listener": { + "class": "litestar.logging.picologging.QueueListenerHandler", + "level": "DEBUG", + "formatter": "standard", + }, +} + + +def get_logger_placeholder(_: str | None = None) -> NoReturn: + """Raise: An :class:`ImproperlyConfiguredException <.exceptions.ImproperlyConfiguredException>`""" + raise ImproperlyConfiguredException( + "cannot call '.get_logger' without passing 'logging_config' to the Litestar constructor first" + ) + + +def _get_default_handlers() -> dict[str, dict[str, Any]]: + """Return the default logging handlers for the config. + + Returns: + A dictionary of logging handlers + """ + if find_spec("picologging"): + return default_picologging_handlers + return default_handlers + + +def _default_exception_logging_handler_factory( + is_struct_logger: bool, traceback_line_limit: int +) -> ExceptionLoggingHandler: + """Create an exception logging handler function. + + Args: + is_struct_logger: Whether the logger is a structlog instance. + traceback_line_limit: Maximal number of lines to log from the + traceback. + + Returns: + An exception logging handler. + """ + + def _default_exception_logging_handler(logger: Logger, scope: Scope, tb: list[str]) -> None: + # we limit the length of the stack trace to 20 lines. + first_line = tb.pop(0) + + if is_struct_logger: + logger.exception( + "Uncaught Exception", + connection_type=scope["type"], + path=scope["path"], + traceback="".join(tb[-traceback_line_limit:]), + ) + else: + stack_trace = first_line + "".join(tb[-traceback_line_limit:]) + logger.exception( + "exception raised on %s connection to route %s\n\n%s", scope["type"], scope["path"], stack_trace + ) + + return _default_exception_logging_handler + + +class BaseLoggingConfig(ABC): + """Abstract class that should be extended by logging configs.""" + + __slots__ = ("log_exceptions", "traceback_line_limit", "exception_logging_handler") + + log_exceptions: Literal["always", "debug", "never"] + """Should exceptions be logged, defaults to log exceptions when ``app.debug == True``'""" + traceback_line_limit: int + """Max number of lines to print for exception traceback""" + exception_logging_handler: ExceptionLoggingHandler | None + """Handler function for logging exceptions.""" + + @abstractmethod + def configure(self) -> GetLogger: + """Return logger with the given configuration. + + Returns: + A 'logging.getLogger' like function. + """ + raise NotImplementedError("abstract method") + + @staticmethod + def set_level(logger: Any, level: int) -> None: + """Provides a consistent interface to call `setLevel` for all loggers.""" + raise NotImplementedError("abstract method") + + +@dataclass +class LoggingConfig(BaseLoggingConfig): + """Configuration class for standard logging. + + Notes: + - If 'picologging' is installed it will be used by default. + """ + + version: Literal[1] = field(default=1) + """The only valid value at present is 1.""" + incremental: bool = field(default=False) + """Whether the configuration is to be interpreted as incremental to the existing configuration. + + Notes: + - This option is ignored for 'picologging' + """ + disable_existing_loggers: bool = field(default=False) + """Whether any existing non-root loggers are to be disabled.""" + filters: dict[str, dict[str, Any]] | None = field(default=None) + """A dict in which each key is a filter id and each value is a dict describing how to configure the corresponding + Filter instance. + """ + propagate: bool = field(default=True) + """If messages must propagate to handlers higher up the logger hierarchy from this logger.""" + formatters: dict[str, dict[str, Any]] = field( + default_factory=lambda: { + "standard": {"format": "%(levelname)s - %(asctime)s - %(name)s - %(module)s - %(message)s"} + } + ) + handlers: dict[str, dict[str, Any]] = field(default_factory=_get_default_handlers) + """A dict in which each key is a handler id and each value is a dict describing how to configure the corresponding + Handler instance. + """ + loggers: dict[str, dict[str, Any]] = field( + default_factory=lambda: { + "litestar": {"level": "INFO", "handlers": ["queue_listener"], "propagate": False}, + } + ) + """A dict in which each key is a logger name and each value is a dict describing how to configure the corresponding + Logger instance. + """ + root: dict[str, dict[str, Any] | list[Any] | str] = field( + default_factory=lambda: { + "handlers": ["queue_listener"], + "level": "INFO", + } + ) + """This will be the configuration for the root logger. + + Processing of the configuration will be as for any logger, except that the propagate setting will not be applicable. + """ + configure_root_logger: bool = field(default=True) + """Should the root logger be configured, defaults to True for ease of configuration.""" + log_exceptions: Literal["always", "debug", "never"] = field(default="debug") + """Should exceptions be logged, defaults to log exceptions when 'app.debug == True'""" + traceback_line_limit: int = field(default=20) + """Max number of lines to print for exception traceback""" + exception_logging_handler: ExceptionLoggingHandler | None = field(default=None) + """Handler function for logging exceptions.""" + + def __post_init__(self) -> None: + if "queue_listener" not in self.handlers: + self.handlers["queue_listener"] = _get_default_handlers()["queue_listener"] + + if "litestar" not in self.loggers: + self.loggers["litestar"] = { + "level": "INFO", + "handlers": ["queue_listener"], + "propagate": False, + } + + if self.log_exceptions != "never" and self.exception_logging_handler is None: + self.exception_logging_handler = _default_exception_logging_handler_factory( + is_struct_logger=False, traceback_line_limit=self.traceback_line_limit + ) + + def configure(self) -> GetLogger: + """Return logger with the given configuration. + + Returns: + A 'logging.getLogger' like function. + """ + + excluded_fields: tuple[str, ...] + if "picologging" in " ".join([handler["class"] for handler in self.handlers.values()]): + try: + from picologging import config, getLogger + except ImportError as e: + raise MissingDependencyException("picologging") from e + + excluded_fields = ("incremental", "configure_root_logger") + else: + from logging import config, getLogger # type: ignore[no-redef, assignment] + + excluded_fields = ("configure_root_logger",) + + values = { + _field.name: getattr(self, _field.name) + for _field in fields(self) + if getattr(self, _field.name) is not None and _field.name not in excluded_fields + } + + if not self.configure_root_logger: + values.pop("root") + config.dictConfig(values) + return cast("Callable[[str], Logger]", getLogger) + + @staticmethod + def set_level(logger: Logger, level: int) -> None: + """Provides a consistent interface to call `setLevel` for all loggers.""" + logger.setLevel(level) + + +class StructlogEventFilter: + """Remove keys from the log event. + + Add an instance to the processor chain. + + .. code-block:: python + :caption: Examples + + structlog.configure( + ..., + processors=[ + ..., + EventFilter(["color_message"]), + ..., + ], + ) + + """ + + def __init__(self, filter_keys: Iterable[str]) -> None: + """Initialize the EventFilter. + + Args: + filter_keys: Iterable of string keys to be excluded from the log event. + """ + self.filter_keys = filter_keys + + def __call__(self, _: WrappedLogger, __: str, event_dict: EventDict) -> EventDict: + """Receive the log event, and filter keys. + + Args: + _ (): + __ (): + event_dict (): The data to be logged. + + Returns: + The log event with any key in `self.filter_keys` removed. + """ + for key in self.filter_keys: + event_dict.pop(key, None) + return event_dict + + +def default_json_serializer(value: EventDict, **_: Any) -> bytes: + return _msgspec_json_encoder.encode(value) + + +def stdlib_json_serializer(value: EventDict, **_: Any) -> str: # pragma: no cover + return _msgspec_json_encoder.encode(value).decode("utf-8") + + +def default_structlog_processors(as_json: bool = True) -> list[Processor]: # pyright: ignore + """Set the default processors for structlog. + + Returns: + An optional list of processors. + """ + try: + import structlog + from structlog.dev import RichTracebackFormatter + + if as_json: + return [ + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.format_exc_info, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.JSONRenderer(serializer=default_json_serializer), + ] + return [ + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + structlog.dev.ConsoleRenderer( + colors=True, exception_formatter=RichTracebackFormatter(max_frames=1, show_locals=False, width=80) + ), + ] + + except ImportError: + return [] + + +def default_structlog_standard_lib_processors(as_json: bool = True) -> list[Processor]: # pyright: ignore + """Set the default processors for structlog stdlib. + + Returns: + An optional list of processors. + """ + try: + import structlog + from structlog.dev import RichTracebackFormatter + + if as_json: + return [ + structlog.processors.TimeStamper(fmt="iso"), + structlog.stdlib.add_log_level, + structlog.stdlib.ExtraAdder(), + StructlogEventFilter(["color_message"]), + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.processors.JSONRenderer(serializer=stdlib_json_serializer), + ] + return [ + structlog.processors.TimeStamper(fmt="iso"), + structlog.stdlib.add_log_level, + structlog.stdlib.ExtraAdder(), + StructlogEventFilter(["color_message"]), + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.dev.ConsoleRenderer( + colors=True, exception_formatter=RichTracebackFormatter(max_frames=1, show_locals=False, width=80) + ), + ] + except ImportError: + return [] + + +def default_logger_factory(as_json: bool = True) -> Callable[..., WrappedLogger] | None: + """Set the default logger factory for structlog. + + Returns: + An optional logger factory. + """ + try: + import structlog + + if as_json: + return structlog.BytesLoggerFactory() + return structlog.WriteLoggerFactory() + except ImportError: + return None + + +@dataclass +class StructLoggingConfig(BaseLoggingConfig): + """Configuration class for structlog. + + Notes: + - requires ``structlog`` to be installed. + """ + + processors: list[Processor] | None = field(default=None) # pyright: ignore + """Iterable of structlog logging processors.""" + standard_lib_logging_config: LoggingConfig | None = field(default=None) # pyright: ignore + """Optional customized standard logging configuration. + + Use this when you need to modify the standard library outside of the Structlog pre-configured implementation. + """ + wrapper_class: type[BindableLogger] | None = field(default=None) # pyright: ignore + """Structlog bindable logger.""" + context_class: dict[str, Any] | None = None + """Context class (a 'contextvar' context) for the logger.""" + logger_factory: Callable[..., WrappedLogger] | None = field(default=None) # pyright: ignore + """Logger factory to use.""" + cache_logger_on_first_use: bool = field(default=True) + """Whether to cache the logger configuration and reuse.""" + log_exceptions: Literal["always", "debug", "never"] = field(default="debug") + """Should exceptions be logged, defaults to log exceptions when 'app.debug == True'""" + traceback_line_limit: int = field(default=20) + """Max number of lines to print for exception traceback""" + exception_logging_handler: ExceptionLoggingHandler | None = field(default=None) + """Handler function for logging exceptions.""" + pretty_print_tty: bool = field(default=True) + """Pretty print log output when run from an interactive terminal.""" + + def __post_init__(self) -> None: + if self.processors is None: + self.processors = default_structlog_processors(not sys.stderr.isatty() and self.pretty_print_tty) + if self.logger_factory is None: + self.logger_factory = default_logger_factory(not sys.stderr.isatty() and self.pretty_print_tty) + if self.log_exceptions != "never" and self.exception_logging_handler is None: + self.exception_logging_handler = _default_exception_logging_handler_factory( + is_struct_logger=True, traceback_line_limit=self.traceback_line_limit + ) + try: + import structlog + + if self.standard_lib_logging_config is None: + self.standard_lib_logging_config = LoggingConfig( + formatters={ + "standard": { + "()": structlog.stdlib.ProcessorFormatter, + "processors": default_structlog_standard_lib_processors( + as_json=not sys.stderr.isatty() and self.pretty_print_tty + ), + } + } + ) + except ImportError: + self.standard_lib_logging_config = LoggingConfig() + + def configure(self) -> GetLogger: + """Return logger with the given configuration. + + Returns: + A 'logging.getLogger' like function. + """ + try: + import structlog + except ImportError as e: + raise MissingDependencyException("structlog") from e + + structlog.configure( + **{ + k: v + for k, v in asdict(self).items() + if k + not in ( + "standard_lib_logging_config", + "log_exceptions", + "traceback_line_limit", + "exception_logging_handler", + "pretty_print_tty", + ) + } + ) + return structlog.get_logger + + @staticmethod + def set_level(logger: Logger, level: int) -> None: + """Provides a consistent interface to call `setLevel` for all loggers.""" + + try: + import structlog + + structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(level)) + except ImportError: + """""" + return + + +@deprecated(version="2.6.0", removal_in="3.0.0", alternative="`StructLoggingConfig.set_level`") +def default_wrapper_class(log_level: int = INFO) -> type[BindableLogger] | None: # pragma: no cover # pyright: ignore + try: # pragma: no cover + import structlog + + return structlog.make_filtering_bound_logger(log_level) + except ImportError: + return None diff --git a/venv/lib/python3.11/site-packages/litestar/logging/picologging.py b/venv/lib/python3.11/site-packages/litestar/logging/picologging.py new file mode 100644 index 0000000..2cd599f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/picologging.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import atexit +from queue import Queue +from typing import Any + +from litestar.exceptions import MissingDependencyException +from litestar.logging._utils import resolve_handlers + +__all__ = ("QueueListenerHandler",) + + +try: + import picologging # noqa: F401 +except ImportError as e: + raise MissingDependencyException("picologging") from e + +from picologging import StreamHandler +from picologging.handlers import QueueHandler, QueueListener + + +class QueueListenerHandler(QueueHandler): + """Configure queue listener and handler to support non-blocking logging configuration.""" + + def __init__(self, handlers: list[Any] | None = None) -> None: + """Initialize ``QueueListenerHandler``. + + Args: + handlers: Optional 'ConvertingList' + + Notes: + - Requires ``picologging`` to be installed. + """ + super().__init__(Queue(-1)) + handlers = resolve_handlers(handlers) if handlers else [StreamHandler()] + self.listener = QueueListener(self.queue, *handlers) + self.listener.start() + + atexit.register(self.listener.stop) diff --git a/venv/lib/python3.11/site-packages/litestar/logging/standard.py b/venv/lib/python3.11/site-packages/litestar/logging/standard.py new file mode 100644 index 0000000..131c0ed --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/logging/standard.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import atexit +from logging import Handler, LogRecord, StreamHandler +from logging.handlers import QueueHandler, QueueListener +from queue import Queue +from typing import Any + +from litestar.logging._utils import resolve_handlers + +__all__ = ("LoggingQueueListener", "QueueListenerHandler") + + +class LoggingQueueListener(QueueListener): + """Custom ``QueueListener`` which starts and stops the listening process.""" + + def __init__(self, queue: Queue[LogRecord], *handlers: Handler, respect_handler_level: bool = False) -> None: + """Initialize ``LoggingQueueListener``. + + Args: + queue: The queue to send messages to + *handlers: A list of handlers which will handle entries placed on the queue + respect_handler_level: If ``respect_handler_level`` is ``True``, a handler's level is respected (compared with the level for the message) when deciding whether to pass messages to that handler + """ + super().__init__(queue, *handlers, respect_handler_level=respect_handler_level) + self.start() + atexit.register(self.stop) + + +class QueueListenerHandler(QueueHandler): + """Configure queue listener and handler to support non-blocking logging configuration. + + .. caution:: + + This handler doesn't work with Python >= 3.12 and ``logging.config.dictConfig``. It might + be deprecated in the future. Please use ``logging.QueueHandler`` instead. + """ + + def __init__(self, handlers: list[Any] | None = None) -> None: + """Initialize ``QueueListenerHandler``. + + Args: + handlers: Optional 'ConvertingList' + """ + super().__init__(Queue(-1)) + handlers = resolve_handlers(handlers) if handlers else [StreamHandler()] + self.listener = LoggingQueueListener(self.queue, *handlers) # type: ignore[arg-type] diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/__init__.py new file mode 100644 index 0000000..7024e54 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__init__.py @@ -0,0 +1,17 @@ +from litestar.middleware.authentication import ( + AbstractAuthenticationMiddleware, + AuthenticationResult, +) +from litestar.middleware.base import ( + AbstractMiddleware, + DefineMiddleware, + MiddlewareProtocol, +) + +__all__ = ( + "AbstractAuthenticationMiddleware", + "AbstractMiddleware", + "AuthenticationResult", + "DefineMiddleware", + "MiddlewareProtocol", +) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c807c7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..95f4515 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/allowed_hosts.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/allowed_hosts.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..48bfb2a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/allowed_hosts.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/authentication.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/authentication.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..193563c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/authentication.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6a0ef6f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/cors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/cors.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f4277f4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/cors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/csrf.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/csrf.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4679eea --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/csrf.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/logging.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/logging.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..20db6ff --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/logging.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/rate_limit.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/rate_limit.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..83090e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/rate_limit.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/response_cache.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/response_cache.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1672eb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/__pycache__/response_cache.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/_utils.py b/venv/lib/python3.11/site-packages/litestar/middleware/_utils.py new file mode 100644 index 0000000..778a508 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/_utils.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Pattern, Sequence + +from litestar.exceptions import ImproperlyConfiguredException + +__all__ = ("build_exclude_path_pattern", "should_bypass_middleware") + + +if TYPE_CHECKING: + from litestar.types import Method, Scope, Scopes + + +def build_exclude_path_pattern(*, exclude: str | list[str] | None = None) -> Pattern | None: + """Build single path pattern from list of patterns to opt-out from middleware processing. + + Args: + exclude: A pattern or a list of patterns. + + Returns: + An optional pattern to match against scope["path"] to opt-out from middleware processing. + """ + if exclude is None: + return None + + try: + return re.compile("|".join(exclude)) if isinstance(exclude, list) else re.compile(exclude) + except re.error as e: # pragma: no cover + raise ImproperlyConfiguredException( + "Unable to compile exclude patterns for middleware. Please make sure you passed a valid regular expression." + ) from e + + +def should_bypass_middleware( + *, + exclude_http_methods: Sequence[Method] | None = None, + exclude_opt_key: str | None = None, + exclude_path_pattern: Pattern | None = None, + scope: Scope, + scopes: Scopes, +) -> bool: + """Determine weather a middleware should be bypassed. + + Args: + exclude_http_methods: A sequence of http methods that do not require authentication. + exclude_opt_key: Key in ``opt`` with which a route handler can "opt-out" of a middleware. + exclude_path_pattern: If this pattern matches scope["path"], the middleware should be bypassed. + scope: The ASGI scope. + scopes: A set with the ASGI scope types that are supported by the middleware. + + Returns: + A boolean indicating if a middleware should be bypassed + """ + if scope["type"] not in scopes: + return True + + if exclude_opt_key and scope["route_handler"].opt.get(exclude_opt_key): + return True + + if exclude_http_methods and scope.get("method") in exclude_http_methods: + return True + + return bool( + exclude_path_pattern + and exclude_path_pattern.findall( + scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"] + ) + ) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py b/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py new file mode 100644 index 0000000..0172176 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Pattern + +from litestar.datastructures import URL, MutableScopeHeaders +from litestar.middleware.base import AbstractMiddleware +from litestar.response.base import ASGIResponse +from litestar.response.redirect import ASGIRedirectResponse +from litestar.status_codes import HTTP_400_BAD_REQUEST + +__all__ = ("AllowedHostsMiddleware",) + + +if TYPE_CHECKING: + from litestar.config.allowed_hosts import AllowedHostsConfig + from litestar.types import ASGIApp, Receive, Scope, Send + + +class AllowedHostsMiddleware(AbstractMiddleware): + """Middleware ensuring the host of a request originated in a trusted host.""" + + def __init__(self, app: ASGIApp, config: AllowedHostsConfig) -> None: + """Initialize ``AllowedHostsMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of AllowedHostsConfig. + """ + + super().__init__(app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes=config.scopes) + + self.allowed_hosts_regex: Pattern | None = None + self.redirect_domains: Pattern | None = None + + if any(host == "*" for host in config.allowed_hosts): + return + + allowed_hosts: set[str] = { + rf".*\.{host.replace('*.', '')}$" if host.startswith("*.") else host for host in config.allowed_hosts + } + + self.allowed_hosts_regex = re.compile("|".join(sorted(allowed_hosts))) # pyright: ignore + + if config.www_redirect and ( + redirect_domains := {host.replace("www.", "") for host in config.allowed_hosts if host.startswith("www.")} + ): + self.redirect_domains = re.compile("|".join(sorted(redirect_domains))) # pyright: ignore + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if self.allowed_hosts_regex is None: + await self.app(scope, receive, send) + return + + headers = MutableScopeHeaders(scope=scope) + if host := headers.get("host", headers.get("x-forwarded-host", "")).split(":")[0]: + if self.allowed_hosts_regex.fullmatch(host): + await self.app(scope, receive, send) + return + + if self.redirect_domains is not None and self.redirect_domains.fullmatch(host): + url = URL.from_scope(scope) + redirect_url = url.with_replacements(netloc=f"www.{url.netloc}") + redirect_response = ASGIRedirectResponse(path=str(redirect_url)) + await redirect_response(scope, receive, send) + return + + response = ASGIResponse(body=b'{"message":"invalid host header"}', status_code=HTTP_400_BAD_REQUEST) + await response(scope, receive, send) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/authentication.py b/venv/lib/python3.11/site-packages/litestar/middleware/authentication.py new file mode 100644 index 0000000..9502df0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/authentication.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Sequence + +from litestar.connection import ASGIConnection +from litestar.enums import HttpMethod, ScopeType +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) + +__all__ = ("AbstractAuthenticationMiddleware", "AuthenticationResult") + + +if TYPE_CHECKING: + from litestar.types import ASGIApp, Method, Receive, Scope, Scopes, Send + + +@dataclass +class AuthenticationResult: + """Pydantic model for authentication data.""" + + __slots__ = ("user", "auth") + + user: Any + """The user model, this can be any value corresponding to a user of the API.""" + auth: Any + """The auth value, this can for example be a JWT token.""" + + +class AbstractAuthenticationMiddleware(ABC): + """Abstract AuthenticationMiddleware that allows users to create their own AuthenticationMiddleware by extending it + and overriding :meth:`AbstractAuthenticationMiddleware.authenticate_request`. + """ + + __slots__ = ( + "app", + "exclude", + "exclude_http_methods", + "exclude_opt_key", + "scopes", + ) + + def __init__( + self, + app: ASGIApp, + exclude: str | list[str] | None = None, + exclude_from_auth_key: str = "exclude_from_auth", + exclude_http_methods: Sequence[Method] | None = None, + scopes: Scopes | None = None, + ) -> None: + """Initialize ``AbstractAuthenticationMiddleware``. + + Args: + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + exclude: A pattern or list of patterns to skip in the authentication middleware. + exclude_from_auth_key: An identifier to use on routes to disable authentication for a particular route. + exclude_http_methods: A sequence of http methods that do not require authentication. + scopes: ASGI scopes processed by the authentication middleware. + """ + self.app = app + self.exclude = build_exclude_path_pattern(exclude=exclude) + self.exclude_http_methods = (HttpMethod.OPTIONS,) if exclude_http_methods is None else exclude_http_methods + self.exclude_opt_key = exclude_from_auth_key + self.scopes = scopes or {ScopeType.HTTP, ScopeType.WEBSOCKET} + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if not should_bypass_middleware( + exclude_http_methods=self.exclude_http_methods, + exclude_opt_key=self.exclude_opt_key, + exclude_path_pattern=self.exclude, + scope=scope, + scopes=self.scopes, + ): + auth_result = await self.authenticate_request(ASGIConnection(scope)) + scope["user"] = auth_result.user + scope["auth"] = auth_result.auth + await self.app(scope, receive, send) + + @abstractmethod + async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult: + """Receive the http connection and return an :class:`AuthenticationResult`. + + Notes: + - This method must be overridden by subclasses. + + Args: + connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance. + + Raises: + NotAuthorizedException | PermissionDeniedException: if authentication fails. + + Returns: + An instance of :class:`AuthenticationResult <litestar.middleware.authentication.AuthenticationResult>`. + """ + raise NotImplementedError("authenticate_request must be overridden by subclasses") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/base.py b/venv/lib/python3.11/site-packages/litestar/middleware/base.py new file mode 100644 index 0000000..43106c9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/base.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable + +from litestar.enums import ScopeType +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) + +__all__ = ("AbstractMiddleware", "DefineMiddleware", "MiddlewareProtocol") + + +if TYPE_CHECKING: + from litestar.types import Scopes + from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send + + +@runtime_checkable +class MiddlewareProtocol(Protocol): + """Abstract middleware protocol.""" + + __slots__ = ("app",) + + app: ASGIApp + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Execute the ASGI middleware. + + Called by the previous middleware in the stack if a response is not awaited prior. + + Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its + closure. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + +class DefineMiddleware: + """Container enabling passing ``*args`` and ``**kwargs`` to Middleware class constructors and factory functions.""" + + __slots__ = ("middleware", "args", "kwargs") + + def __init__(self, middleware: Callable[..., ASGIApp], *args: Any, **kwargs: Any) -> None: + """Initialize ``DefineMiddleware``. + + Args: + middleware: A callable that returns an ASGIApp. + *args: Positional arguments to pass to the callable. + **kwargs: Key word arguments to pass to the callable. + + Notes: + The callable will be passed a kwarg ``app``, which is the next ASGI app to call in the middleware stack. + It therefore must define such a kwarg. + """ + self.middleware = middleware + self.args = args + self.kwargs = kwargs + + def __call__(self, app: ASGIApp) -> ASGIApp: + """Call the middleware constructor or factory. + + Args: + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + + Returns: + Calls :class:`DefineMiddleware.middleware <.DefineMiddleware>` and returns the ASGIApp created. + """ + + return self.middleware(*self.args, app=app, **self.kwargs) + + +class AbstractMiddleware: + """Abstract middleware providing base functionality common to all middlewares, for dynamically engaging/bypassing + the middleware based on paths, ``opt``-keys and scope types. + + When implementing new middleware, this class should be used as a base. + """ + + scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET} + exclude: str | list[str] | None = None + exclude_opt_key: str | None = None + + def __init__( + self, + app: ASGIApp, + exclude: str | list[str] | None = None, + exclude_opt_key: str | None = None, + scopes: Scopes | None = None, + ) -> None: + """Initialize the middleware. + + Args: + app: The ``next`` ASGI app to call. + exclude: A pattern or list of patterns to match against a request's path. + If a match is found, the middleware will be skipped. + exclude_opt_key: An identifier that is set in the route handler + ``opt`` key which allows skipping the middleware. + scopes: ASGI scope types, should be a set including + either or both 'ScopeType.HTTP' and 'ScopeType.WEBSOCKET'. + """ + self.app = app + self.scopes = scopes or self.scopes + self.exclude_opt_key = exclude_opt_key or self.exclude_opt_key + self.exclude_pattern = build_exclude_path_pattern(exclude=(exclude or self.exclude)) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + original__call__ = cls.__call__ + + async def wrapped_call(self: AbstractMiddleware, scope: Scope, receive: Receive, send: Send) -> None: + if should_bypass_middleware( + scope=scope, + scopes=self.scopes, + exclude_path_pattern=self.exclude_pattern, + exclude_opt_key=self.exclude_opt_key, + ): + await self.app(scope, receive, send) + else: + await original__call__(self, scope, receive, send) # pyright: ignore + + # https://github.com/python/mypy/issues/2427#issuecomment-384229898 + setattr(cls, "__call__", wrapped_call) + + @abstractmethod + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Execute the ASGI middleware. + + Called by the previous middleware in the stack if a response is not awaited prior. + + Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its + closure. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + raise NotImplementedError("abstract method must be implemented") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__init__.py new file mode 100644 index 0000000..0885932 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__init__.py @@ -0,0 +1,4 @@ +from litestar.middleware.compression.facade import CompressionFacade +from litestar.middleware.compression.middleware import CompressionMiddleware + +__all__ = ("CompressionMiddleware", "CompressionFacade") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..80ea058 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/brotli_facade.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/brotli_facade.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7378c0f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/brotli_facade.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/facade.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/facade.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d336c8f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/facade.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/gzip_facade.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/gzip_facade.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..66e1df4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/gzip_facade.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a683673 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/brotli_facade.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/brotli_facade.py new file mode 100644 index 0000000..3d01950 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/brotli_facade.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from litestar.enums import CompressionEncoding +from litestar.exceptions import MissingDependencyException +from litestar.middleware.compression.facade import CompressionFacade + +try: + from brotli import MODE_FONT, MODE_GENERIC, MODE_TEXT, Compressor +except ImportError as e: + raise MissingDependencyException("brotli") from e + + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + + +class BrotliCompression(CompressionFacade): + __slots__ = ("compressor", "buffer", "compression_encoding") + + encoding = CompressionEncoding.BROTLI + + def __init__( + self, + buffer: BytesIO, + compression_encoding: Literal[CompressionEncoding.BROTLI] | str, + config: CompressionConfig, + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + modes: dict[Literal["generic", "text", "font"], int] = { + "text": int(MODE_TEXT), + "font": int(MODE_FONT), + "generic": int(MODE_GENERIC), + } + self.compressor = Compressor( + quality=config.brotli_quality, + mode=modes[config.brotli_mode], + lgwin=config.brotli_lgwin, + lgblock=config.brotli_lgblock, + ) + + def write(self, body: bytes) -> None: + self.buffer.write(self.compressor.process(body)) + self.buffer.write(self.compressor.flush()) + + def close(self) -> None: + self.buffer.write(self.compressor.finish()) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/facade.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/facade.py new file mode 100644 index 0000000..0074b57 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/facade.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Protocol + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + from litestar.enums import CompressionEncoding + + +class CompressionFacade(Protocol): + """A unified facade offering a uniform interface for different compression libraries.""" + + encoding: ClassVar[str] + """The encoding of the compression.""" + + def __init__( + self, buffer: BytesIO, compression_encoding: CompressionEncoding | str, config: CompressionConfig + ) -> None: + """Initialize ``CompressionFacade``. + + Args: + buffer: A bytes IO buffer to write the compressed data into. + compression_encoding: The compression encoding used. + config: The app compression config. + """ + ... + + def write(self, body: bytes) -> None: + """Write compressed bytes. + + Args: + body: Message body to process + + Returns: + None + """ + ... + + def close(self) -> None: + """Close the compression stream. + + Returns: + None + """ + ... diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/gzip_facade.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/gzip_facade.py new file mode 100644 index 0000000..b10ef73 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/gzip_facade.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from gzip import GzipFile +from typing import TYPE_CHECKING, Literal + +from litestar.enums import CompressionEncoding +from litestar.middleware.compression.facade import CompressionFacade + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + + +class GzipCompression(CompressionFacade): + __slots__ = ("compressor", "buffer", "compression_encoding") + + encoding = CompressionEncoding.GZIP + + def __init__( + self, buffer: BytesIO, compression_encoding: Literal[CompressionEncoding.GZIP] | str, config: CompressionConfig + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + self.compressor = GzipFile(mode="wb", fileobj=buffer, compresslevel=config.gzip_compress_level) + + def write(self, body: bytes) -> None: + self.compressor.write(body) + self.compressor.flush() + + def close(self) -> None: + self.compressor.close() diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/compression/middleware.py b/venv/lib/python3.11/site-packages/litestar/middleware/compression/middleware.py new file mode 100644 index 0000000..7ea7853 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/compression/middleware.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from io import BytesIO +from typing import TYPE_CHECKING, Any, Literal + +from litestar.datastructures import Headers, MutableScopeHeaders +from litestar.enums import CompressionEncoding, ScopeType +from litestar.middleware.base import AbstractMiddleware +from litestar.middleware.compression.gzip_facade import GzipCompression +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.config.compression import CompressionConfig + from litestar.middleware.compression.facade import CompressionFacade + from litestar.types import ( + ASGIApp, + HTTPResponseStartEvent, + Message, + Receive, + Scope, + Send, + ) + + try: + from brotli import Compressor + except ImportError: + Compressor = Any + + +class CompressionMiddleware(AbstractMiddleware): + """Compression Middleware Wrapper. + + This is a wrapper allowing for generic compression configuration / handler middleware + """ + + def __init__(self, app: ASGIApp, config: CompressionConfig) -> None: + """Initialize ``CompressionMiddleware`` + + Args: + app: The ``next`` ASGI app to call. + config: An instance of CompressionConfig. + """ + super().__init__( + app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP} + ) + self.config = config + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + accept_encoding = Headers.from_scope(scope).get("accept-encoding", "") + config = self.config + + if config.compression_facade.encoding in accept_encoding: + await self.app( + scope, + receive, + self.create_compression_send_wrapper( + send=send, compression_encoding=config.compression_facade.encoding, scope=scope + ), + ) + return + + if config.gzip_fallback and CompressionEncoding.GZIP in accept_encoding: + await self.app( + scope, + receive, + self.create_compression_send_wrapper( + send=send, compression_encoding=CompressionEncoding.GZIP, scope=scope + ), + ) + return + + await self.app(scope, receive, send) + + def create_compression_send_wrapper( + self, + send: Send, + compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP] | str, + scope: Scope, + ) -> Send: + """Wrap ``send`` to handle brotli compression. + + Args: + send: The ASGI send function. + compression_encoding: The compression encoding used. + scope: The ASGI connection scope + + Returns: + An ASGI send function. + """ + bytes_buffer = BytesIO() + + facade: CompressionFacade + # We can't use `self.config.compression_facade` directly if the compression is `gzip` since + # it may be being used as a fallback. + if compression_encoding == CompressionEncoding.GZIP: + facade = GzipCompression(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config) + else: + facade = self.config.compression_facade( + buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config + ) + + initial_message: HTTPResponseStartEvent | None = None + started = False + + connection_state = ScopeState.from_scope(scope) + + async def send_wrapper(message: Message) -> None: + """Handle and compresses the HTTP Message with brotli. + + Args: + message (Message): An ASGI Message. + """ + nonlocal started + nonlocal initial_message + + if message["type"] == "http.response.start": + initial_message = message + return + + if initial_message is not None and value_or_default(connection_state.is_cached, False): + await send(initial_message) + await send(message) + return + + if initial_message and message["type"] == "http.response.body": + body = message["body"] + more_body = message.get("more_body") + + if not started: + started = True + if more_body: + headers = MutableScopeHeaders(initial_message) + headers["Content-Encoding"] = compression_encoding + headers.extend_header_value("vary", "Accept-Encoding") + del headers["Content-Length"] + connection_state.response_compressed = True + + facade.write(body) + + message["body"] = bytes_buffer.getvalue() + bytes_buffer.seek(0) + bytes_buffer.truncate() + await send(initial_message) + await send(message) + + elif len(body) >= self.config.minimum_size: + facade.write(body) + facade.close() + body = bytes_buffer.getvalue() + + headers = MutableScopeHeaders(initial_message) + headers["Content-Encoding"] = compression_encoding + headers["Content-Length"] = str(len(body)) + headers.extend_header_value("vary", "Accept-Encoding") + message["body"] = body + connection_state.response_compressed = True + + await send(initial_message) + await send(message) + + else: + await send(initial_message) + await send(message) + + else: + facade.write(body) + if not more_body: + facade.close() + + message["body"] = bytes_buffer.getvalue() + + bytes_buffer.seek(0) + bytes_buffer.truncate() + + if not more_body: + bytes_buffer.close() + + await send(message) + + return send_wrapper diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/cors.py b/venv/lib/python3.11/site-packages/litestar/middleware/cors.py new file mode 100644 index 0000000..6c4de31 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/cors.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.datastructures import Headers, MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.middleware.base import AbstractMiddleware + +__all__ = ("CORSMiddleware",) + + +if TYPE_CHECKING: + from litestar.config.cors import CORSConfig + from litestar.types import ASGIApp, Message, Receive, Scope, Send + + +class CORSMiddleware(AbstractMiddleware): + """CORS Middleware.""" + + __slots__ = ("config",) + + def __init__(self, app: ASGIApp, config: CORSConfig) -> None: + """Middleware that adds CORS validation to the application. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of :class:`CORSConfig <litestar.config.cors.CORSConfig>` + """ + super().__init__(app=app, scopes={ScopeType.HTTP}) + self.config = config + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + headers = Headers.from_scope(scope=scope) + if origin := headers.get("origin"): + await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers)) + else: + await self.app(scope, receive, send) + + def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send: + """Wrap ``send`` to ensure that state is not disconnected. + + Args: + has_cookie: Boolean flag dictating if the connection has a cookie set. + origin: The value of the ``Origin`` header. + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + + async def wrapped_send(message: Message) -> None: + if message["type"] == "http.response.start": + message.setdefault("headers", []) + headers = MutableScopeHeaders.from_message(message=message) + headers.update(self.config.simple_headers) + + if (self.config.is_allow_all_origins and has_cookie) or ( + not self.config.is_allow_all_origins and self.config.is_origin_allowed(origin=origin) + ): + headers["Access-Control-Allow-Origin"] = origin + headers["Vary"] = "Origin" + + # We don't want to overwrite this for preflight requests. + allow_headers = headers.get("Access-Control-Allow-Headers") + if not allow_headers and self.config.allow_headers: + headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers))) + + allow_methods = headers.get("Access-Control-Allow-Methods") + if not allow_methods and self.config.allow_methods: + headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods))) + + await send(message) + + return wrapped_send diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py b/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py new file mode 100644 index 0000000..94dd422 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/csrf.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import hashlib +import hmac +import secrets +from secrets import compare_digest +from typing import TYPE_CHECKING, Any + +from litestar.datastructures import MutableScopeHeaders +from litestar.datastructures.cookie import Cookie +from litestar.enums import RequestEncodingType, ScopeType +from litestar.exceptions import PermissionDeniedException +from litestar.middleware._utils import ( + build_exclude_path_pattern, + should_bypass_middleware, +) +from litestar.middleware.base import MiddlewareProtocol +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.config.csrf import CSRFConfig + from litestar.connection import Request + from litestar.types import ( + ASGIApp, + HTTPSendMessage, + Message, + Receive, + Scope, + Scopes, + Send, + ) + +__all__ = ("CSRFMiddleware",) + + +CSRF_SECRET_BYTES = 32 +CSRF_SECRET_LENGTH = CSRF_SECRET_BYTES * 2 + + +def generate_csrf_hash(token: str, secret: str) -> str: + """Generate an HMAC that signs the CSRF token. + + Args: + token: A hashed token. + secret: A secret value. + + Returns: + A CSRF hash. + """ + return hmac.new(secret.encode(), token.encode(), hashlib.sha256).hexdigest() + + +def generate_csrf_token(secret: str) -> str: + """Generate a CSRF token that includes a randomly generated string signed by an HMAC. + + Args: + secret: A secret string. + + Returns: + A unique CSRF token. + """ + token = secrets.token_hex(CSRF_SECRET_BYTES) + token_hash = generate_csrf_hash(token=token, secret=secret) + return token + token_hash + + +class CSRFMiddleware(MiddlewareProtocol): + """CSRF Middleware class. + + This Middleware protects against attacks by setting a CSRF cookie with a token and verifying it in request headers. + """ + + scopes: Scopes = {ScopeType.HTTP} + + def __init__(self, app: ASGIApp, config: CSRFConfig) -> None: + """Initialize ``CSRFMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: The CSRFConfig instance. + """ + self.app = app + self.config = config + self.exclude = build_exclude_path_pattern(exclude=config.exclude) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if scope["type"] != ScopeType.HTTP: + await self.app(scope, receive, send) + return + + request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive) + content_type, _ = request.content_type + csrf_cookie = request.cookies.get(self.config.cookie_name) + existing_csrf_token = request.headers.get(self.config.header_name) + + if not existing_csrf_token and content_type in { + RequestEncodingType.URL_ENCODED, + RequestEncodingType.MULTI_PART, + }: + form = await request.form() + existing_csrf_token = form.get("_csrf_token", None) + + connection_state = ScopeState.from_scope(scope) + if request.method in self.config.safe_methods or should_bypass_middleware( + scope=scope, + scopes=self.scopes, + exclude_opt_key=self.config.exclude_from_csrf_key, + exclude_path_pattern=self.exclude, + ): + token = connection_state.csrf_token = csrf_cookie or generate_csrf_token(secret=self.config.secret) + await self.app(scope, receive, self.create_send_wrapper(send=send, csrf_cookie=csrf_cookie, token=token)) + elif ( + existing_csrf_token is not None + and csrf_cookie is not None + and self._csrf_tokens_match(existing_csrf_token, csrf_cookie) + ): + connection_state.csrf_token = existing_csrf_token + await self.app(scope, receive, send) + else: + raise PermissionDeniedException("CSRF token verification failed") + + def create_send_wrapper(self, send: Send, token: str, csrf_cookie: str | None) -> Send: + """Wrap ``send`` to handle CSRF validation. + + Args: + token: The CSRF token. + send: The ASGI send function. + csrf_cookie: CSRF cookie. + + Returns: + An ASGI send function. + """ + + async def send_wrapper(message: Message) -> None: + """Send function that wraps the original send to inject a cookie. + + Args: + message: An ASGI ``Message`` + + Returns: + None + """ + if csrf_cookie is None and message["type"] == "http.response.start": + message.setdefault("headers", []) + self._set_cookie_if_needed(message=message, token=token) + await send(message) + + return send_wrapper + + def _set_cookie_if_needed(self, message: HTTPSendMessage, token: str) -> None: + headers = MutableScopeHeaders.from_message(message) + cookie = Cookie( + key=self.config.cookie_name, + value=token, + path=self.config.cookie_path, + secure=self.config.cookie_secure, + httponly=self.config.cookie_httponly, + samesite=self.config.cookie_samesite, + domain=self.config.cookie_domain, + ) + headers.add("set-cookie", cookie.to_header(header="")) + + def _decode_csrf_token(self, token: str) -> str | None: + """Decode a CSRF token and validate its HMAC.""" + if len(token) < CSRF_SECRET_LENGTH + 1: + return None + + token_secret = token[:CSRF_SECRET_LENGTH] + existing_hash = token[CSRF_SECRET_LENGTH:] + expected_hash = generate_csrf_hash(token=token_secret, secret=self.config.secret) + return token_secret if compare_digest(existing_hash, expected_hash) else None + + def _csrf_tokens_match(self, request_csrf_token: str, cookie_csrf_token: str) -> bool: + """Take the CSRF tokens from the request and the cookie and verify both are valid and identical.""" + decoded_request_token = self._decode_csrf_token(request_csrf_token) + decoded_cookie_token = self._decode_csrf_token(cookie_csrf_token) + if decoded_request_token is None or decoded_cookie_token is None: + return False + + return compare_digest(decoded_request_token, decoded_cookie_token) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__init__.py new file mode 100644 index 0000000..5328adf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__init__.py @@ -0,0 +1,3 @@ +from litestar.middleware.exceptions.middleware import ExceptionHandlerMiddleware + +__all__ = ("ExceptionHandlerMiddleware",) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c443e00 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/_debug_response.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/_debug_response.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b41fc85 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/_debug_response.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2259206 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/_debug_response.py b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/_debug_response.py new file mode 100644 index 0000000..99e8c87 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/_debug_response.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from html import escape +from inspect import getinnerframes +from pathlib import Path +from traceback import format_exception +from typing import TYPE_CHECKING, Any + +from litestar.enums import MediaType +from litestar.response import Response +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.utils import get_name + +__all__ = ( + "create_debug_response", + "create_exception_html", + "create_frame_html", + "create_html_response_content", + "create_line_html", + "create_plain_text_response_content", + "get_symbol_name", +) + + +if TYPE_CHECKING: + from inspect import FrameInfo + + from litestar.connection import Request + from litestar.types import TypeEncodersMap + +tpl_dir = Path(__file__).parent / "templates" + + +def get_symbol_name(frame: FrameInfo) -> str: + """Return full name of the function that is being executed by the given frame. + + Args: + frame: An instance of [FrameInfo](https://docs.python.org/3/library/inspect.html#inspect.FrameInfo). + + Notes: + - class detection assumes standard names (self and cls) of params. + - if current class name can not be determined only function (method) name will be returned. + - we can not distinguish static methods from ordinary functions at the moment. + + Returns: + A string containing full function name. + """ + + locals_dict = frame.frame.f_locals + # this piece assumes that the code uses standard names "self" and "cls" + # in instance and class methods + instance_or_cls = inst if (inst := locals_dict.get("self")) is not None else locals_dict.get("cls") + + classname = f"{get_name(instance_or_cls)}." if instance_or_cls is not None else "" + + return f"{classname}{frame.function}" + + +def create_line_html( + line: str, + line_no: int, + frame_index: int, + idx: int, +) -> str: + """Produce HTML representation of a line including real line number in the source code. + + Args: + line: A string representing the current line. + line_no: The line number associated with the executed line. + frame_index: Index of the executed line in the code context. + idx: Index of the current line in the code context. + + Returns: + A string containing HTML representation of the given line. + """ + template = '<tr class="{line_class}"><td class="line_no">{line_no}</td><td class="code_line">{line}</td></tr>' + data = { + # line_no - frame_index produces actual line number of the very first line in the frame code context. + # so adding index (aka relative number) of a line in the code context we can calculate its actual number in the source file, + "line_no": line_no - frame_index + idx, + "line": escape(line).replace(" ", " "), + "line_class": "executed-line" if idx == frame_index else "", + } + return template.format(**data) + + +def create_frame_html(frame: FrameInfo, collapsed: bool) -> str: + """Produce HTML representation of the given frame object including filename containing source code and name of the + function being executed. + + Args: + frame: An instance of [FrameInfo](https://docs.python.org/3/library/inspect.html#inspect.FrameInfo). + collapsed: Flag controlling whether frame should be collapsed on the page load. + + Returns: + A string containing HTML representation of the execution frame. + """ + frame_tpl = (tpl_dir / "frame.html").read_text() + + code_lines: list[str] = [ + create_line_html(line, frame.lineno, frame.index or 0, idx) for idx, line in enumerate(frame.code_context or []) + ] + data = { + "file": escape(frame.filename), + "line": frame.lineno, + "symbol_name": escape(get_symbol_name(frame)), + "code": "".join(code_lines), + "frame_class": "collapsed" if collapsed else "", + } + return frame_tpl.format(**data) + + +def create_exception_html(exc: BaseException, line_limit: int) -> str: + """Produce HTML representation of the exception frames. + + Args: + exc: An Exception instance to generate. + line_limit: Number of lines of code context to return, which are centered around the executed line. + + Returns: + A string containing HTML representation of the execution frames related to the exception. + """ + frames = getinnerframes(exc.__traceback__, line_limit) if exc.__traceback__ else [] + result = [create_frame_html(frame=frame, collapsed=idx > 0) for idx, frame in enumerate(reversed(frames))] + return "".join(result) + + +def create_html_response_content(exc: Exception, request: Request, line_limit: int = 15) -> str: + """Given an exception, produces its traceback in HTML. + + Args: + exc: An Exception instance to render debug response from. + request: A :class:`Request <litestar.connection.Request>` instance. + line_limit: Number of lines of code context to return, which are centered around the executed line. + + Returns: + A string containing HTML page with exception traceback. + """ + exception_data: list[str] = [create_exception_html(exc, line_limit)] + cause = exc.__cause__ + while cause: + cause_data = create_exception_html(cause, line_limit) + cause_header = '<h4 class="cause-header">The above exception was caused by</h4>' + cause_error_description = f"<h3><span>{escape(str(cause))}</span></h3>" + cause_error = f"<h4><span>{escape(cause.__class__.__name__)}</span></h4>" + exception_data.append( + f'<div class="cause-wrapper">{cause_header}{cause_error}{cause_error_description}{cause_data}</div>' + ) + cause = cause.__cause__ + + scripts = (tpl_dir / "scripts.js").read_text() + styles = (tpl_dir / "styles.css").read_text() + body_tpl = (tpl_dir / "body.html").read_text() + return body_tpl.format( + scripts=scripts, + styles=styles, + error=f"<span>{escape(exc.__class__.__name__)}</span> on {request.method} {escape(request.url.path)}", + error_description=escape(str(exc)), + exception_data="".join(exception_data), + ) + + +def create_plain_text_response_content(exc: Exception) -> str: + """Given an exception, produces its traceback in plain text. + + Args: + exc: An Exception instance to render debug response from. + + Returns: + A string containing exception traceback. + """ + return "".join(format_exception(type(exc), value=exc, tb=exc.__traceback__)) + + +def create_debug_response(request: Request, exc: Exception) -> Response: + """Create debug response either in plain text or HTML depending on client capabilities. + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + exc: An Exception instance to render debug response from. + + Returns: + A response with a rendered exception traceback. + """ + if MediaType.HTML in request.headers.get("accept", ""): + content: Any = create_html_response_content(exc=exc, request=request) + media_type = MediaType.HTML + elif MediaType.JSON in request.headers.get("accept", ""): + content = {"details": create_plain_text_response_content(exc), "status_code": HTTP_500_INTERNAL_SERVER_ERROR} + media_type = MediaType.JSON + else: + content = create_plain_text_response_content(exc) + media_type = MediaType.TEXT + + return Response( + content=content, + media_type=media_type, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + type_encoders=_get_type_encoders_for_request(request), + ) + + +def _get_type_encoders_for_request(request: Request) -> TypeEncodersMap | None: + try: + return request.route_handler.resolve_type_encoders() + # we might be in a 404, or before we could resolve the handler, so this + # could potentially error out. In this case we fall back on the application + # type encoders + except (KeyError, AttributeError): + return request.app.type_encoders diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/middleware.py b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/middleware.py new file mode 100644 index 0000000..f3ff157 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/middleware.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import pdb # noqa: T100 +from dataclasses import asdict, dataclass, field +from inspect import getmro +from sys import exc_info +from traceback import format_exception +from typing import TYPE_CHECKING, Any, Type, cast + +from litestar.datastructures import Headers +from litestar.enums import MediaType, ScopeType +from litestar.exceptions import HTTPException, LitestarException, WebSocketException +from litestar.middleware.cors import CORSMiddleware +from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response +from litestar.serialization import encode_json +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.utils.deprecation import warn_deprecation + +__all__ = ("ExceptionHandlerMiddleware", "ExceptionResponseContent", "create_exception_response") + + +if TYPE_CHECKING: + from starlette.exceptions import HTTPException as StarletteHTTPException + + from litestar import Response + from litestar.app import Litestar + from litestar.connection import Request + from litestar.logging import BaseLoggingConfig + from litestar.types import ( + ASGIApp, + ExceptionHandler, + ExceptionHandlersMap, + Logger, + Receive, + Scope, + Send, + ) + from litestar.types.asgi_types import WebSocketCloseEvent + + +def get_exception_handler(exception_handlers: ExceptionHandlersMap, exc: Exception) -> ExceptionHandler | None: + """Given a dictionary that maps exceptions and status codes to handler functions, and an exception, returns the + appropriate handler if existing. + + Status codes are given preference over exception type. + + If no status code match exists, each class in the MRO of the exception type is checked and + the first matching handler is returned. + + Finally, if a ``500`` handler is registered, it will be returned for any exception that isn't a + subclass of :class:`HTTPException <litestar.exceptions.HTTPException>`. + + Args: + exception_handlers: Mapping of status codes and exception types to handlers. + exc: Exception Instance to be resolved to a handler. + + Returns: + Optional exception handler callable. + """ + if not exception_handlers: + return None + + default_handler: ExceptionHandler | None = None + if isinstance(exc, HTTPException): + if exception_handler := exception_handlers.get(exc.status_code): + return exception_handler + else: + default_handler = exception_handlers.get(HTTP_500_INTERNAL_SERVER_ERROR) + + return next( + (exception_handlers[cast("Type[Exception]", cls)] for cls in getmro(type(exc)) if cls in exception_handlers), + default_handler, + ) + + +@dataclass +class ExceptionResponseContent: + """Represent the contents of an exception-response.""" + + status_code: int + """Exception status code.""" + detail: str + """Exception details or message.""" + media_type: MediaType | str + """Media type of the response.""" + headers: dict[str, str] | None = field(default=None) + """Headers to attach to the response.""" + extra: dict[str, Any] | list[Any] | None = field(default=None) + """An extra mapping to attach to the exception.""" + + def to_response(self, request: Request | None = None) -> Response: + """Create a response from the model attributes. + + Returns: + A response instance. + """ + from litestar.response import Response + + content: Any = {k: v for k, v in asdict(self).items() if k not in ("headers", "media_type") and v is not None} + + if self.media_type != MediaType.JSON: + content = encode_json(content) + + return Response( + content=content, + headers=self.headers, + status_code=self.status_code, + media_type=self.media_type, + type_encoders=_get_type_encoders_for_request(request) if request is not None else None, + ) + + +def _starlette_exception_handler(request: Request[Any, Any, Any], exc: StarletteHTTPException) -> Response: + return create_exception_response( + request=request, + exc=HTTPException( + detail=exc.detail, + status_code=exc.status_code, + headers=exc.headers, + ), + ) + + +def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -> Response: + """Construct a response from an exception. + + Notes: + - For instances of :class:`HTTPException <litestar.exceptions.HTTPException>` or other exception classes that have a + ``status_code`` attribute (e.g. Starlette exceptions), the status code is drawn from the exception, otherwise + response status is ``HTTP_500_INTERNAL_SERVER_ERROR``. + + Args: + request: The request that triggered the exception. + exc: An exception. + + Returns: + Response: HTTP response constructed from exception details. + """ + headers: dict[str, Any] | None + extra: dict[str, Any] | list | None + + if isinstance(exc, HTTPException): + status_code = exc.status_code + headers = exc.headers + extra = exc.extra + else: + status_code = HTTP_500_INTERNAL_SERVER_ERROR + headers = None + extra = None + + detail = ( + exc.detail + if isinstance(exc, LitestarException) and status_code != HTTP_500_INTERNAL_SERVER_ERROR + else "Internal Server Error" + ) + + try: + media_type = request.route_handler.media_type + except (KeyError, AttributeError): + media_type = MediaType.JSON + + content = ExceptionResponseContent( + status_code=status_code, + detail=detail, + headers=headers, + extra=extra, + media_type=media_type, + ) + return content.to_response(request=request) + + +class ExceptionHandlerMiddleware: + """Middleware used to wrap an ASGIApp inside a try catch block and handle any exceptions raised. + + This used in multiple layers of Litestar. + """ + + def __init__(self, app: ASGIApp, debug: bool | None, exception_handlers: ExceptionHandlersMap) -> None: + """Initialize ``ExceptionHandlerMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + debug: Whether ``debug`` mode is enabled. Deprecated. Debug mode will be inferred from the request scope + exception_handlers: A dictionary mapping status codes and/or exception types to handler functions. + + .. deprecated:: 2.0.0 + The ``debug`` parameter is deprecated. It will be inferred from the request scope + """ + self.app = app + self.exception_handlers = exception_handlers + self.debug = debug + if debug is not None: + warn_deprecation( + "2.0.0", + deprecated_name="debug", + kind="parameter", + info="Debug mode will be inferred from the request scope", + ) + + self._get_debug = self._get_debug_scope if debug is None else lambda *a: debug + + @staticmethod + def _get_debug_scope(scope: Scope) -> bool: + return scope["app"].debug + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI-callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + try: + await self.app(scope, receive, send) + except Exception as e: # noqa: BLE001 + litestar_app = scope["app"] + + if litestar_app.logging_config and (logger := litestar_app.logger): + self.handle_exception_logging(logger=logger, logging_config=litestar_app.logging_config, scope=scope) + + for hook in litestar_app.after_exception: + await hook(e, scope) + + if litestar_app.pdb_on_exception: + pdb.post_mortem() + + if scope["type"] == ScopeType.HTTP: + await self.handle_request_exception( + litestar_app=litestar_app, scope=scope, receive=receive, send=send, exc=e + ) + else: + await self.handle_websocket_exception(send=send, exc=e) + + async def handle_request_exception( + self, litestar_app: Litestar, scope: Scope, receive: Receive, send: Send, exc: Exception + ) -> None: + """Handle exception raised inside 'http' scope routes. + + Args: + litestar_app: The litestar app instance. + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + exc: The caught exception. + + Returns: + None. + """ + + headers = Headers.from_scope(scope=scope) + if litestar_app.cors_config and (origin := headers.get("origin")): + cors_middleware = CORSMiddleware(app=self.app, config=litestar_app.cors_config) + send = cors_middleware.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers) + + exception_handler = get_exception_handler(self.exception_handlers, exc) or self.default_http_exception_handler + request: Request[Any, Any, Any] = litestar_app.request_class(scope=scope, receive=receive, send=send) + response = exception_handler(request, exc) + await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send) + + @staticmethod + async def handle_websocket_exception(send: Send, exc: Exception) -> None: + """Handle exception raised inside 'websocket' scope routes. + + Args: + send: The ASGI send function. + exc: The caught exception. + + Returns: + None. + """ + code = 4000 + HTTP_500_INTERNAL_SERVER_ERROR + reason = "Internal Server Error" + if isinstance(exc, WebSocketException): + code = exc.code + reason = exc.detail + elif isinstance(exc, LitestarException): + reason = exc.detail + + event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason} + await send(event) + + def default_http_exception_handler(self, request: Request, exc: Exception) -> Response[Any]: + """Handle an HTTP exception by returning the appropriate response. + + Args: + request: An HTTP Request instance. + exc: The caught exception. + + Returns: + An HTTP response. + """ + status_code = exc.status_code if isinstance(exc, HTTPException) else HTTP_500_INTERNAL_SERVER_ERROR + if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self._get_debug_scope(request.scope): + return create_debug_response(request=request, exc=exc) + return create_exception_response(request=request, exc=exc) + + def handle_exception_logging(self, logger: Logger, logging_config: BaseLoggingConfig, scope: Scope) -> None: + """Handle logging - if the litestar app has a logging config in place. + + Args: + logger: A logger instance. + logging_config: Logging Config instance. + scope: The ASGI connection scope. + + Returns: + None + """ + if ( + logging_config.log_exceptions == "always" + or (logging_config.log_exceptions == "debug" and self._get_debug_scope(scope)) + ) and logging_config.exception_logging_handler: + logging_config.exception_logging_handler(logger, scope, format_exception(*exc_info())) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/body.html b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/body.html new file mode 100644 index 0000000..1c6705c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/body.html @@ -0,0 +1,20 @@ +<!doctype html> + +<html lang="en"> + <head> + <meta charset="utf-8" /> + <style type="text/css"> + {styles} + </style> + <title>Litestar exception page</title> + </head> + <body> + <h4>{error}</h4> + <h3><span>{error_description}</span></h3> + {exception_data} + <script type="text/javascript"> + // prettier-ignore + {scripts} // NOSONAR + </script> + </body> +</html> diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/frame.html b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/frame.html new file mode 100644 index 0000000..2ead8dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/frame.html @@ -0,0 +1,12 @@ +<div class="frame {frame_class}"> + <div class="frame-name"> + <span class="expander">â–¼</span> + <span class="breakable">{file}</span> in <span>{symbol_name}</span> at line + <span>{line}</span> + </div> + <div class="code-snippet-wrapper"> + <table role="presentation" class="code-snippet"> + {code} + </table> + </div> +</div> diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/scripts.js b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/scripts.js new file mode 100644 index 0000000..014a256 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/scripts.js @@ -0,0 +1,27 @@ +const expanders = document.querySelectorAll(".frame .expander"); + +for (const expander of expanders) { + expander.addEventListener("click", (evt) => { + const currentSnippet = evt.currentTarget.closest(".frame"); + const snippetWrapper = currentSnippet.querySelector( + ".code-snippet-wrapper", + ); + if (currentSnippet.classList.contains("collapsed")) { + snippetWrapper.style.height = `${snippetWrapper.scrollHeight}px`; + currentSnippet.classList.remove("collapsed"); + } else { + currentSnippet.classList.add("collapsed"); + snippetWrapper.style.height = "0px"; + } + }); +} + +// init height for non-collapsed code snippets so animation will be show +// their first collapse +const nonCollapsedSnippets = document.querySelectorAll( + ".frame:not(.collapsed) .code-snippet-wrapper", +); + +for (const snippet of nonCollapsedSnippets) { + snippet.style.height = `${snippet.scrollHeight}px`; +} diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/styles.css b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/styles.css new file mode 100644 index 0000000..6b98b89 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/exceptions/templates/styles.css @@ -0,0 +1,121 @@ +:root { + --code-background-color: #f5f5f5; + --code-background-color-dark: #b8b8b8; + --code-color: #1d2534; + --code-color-light: #546996; + --code-font-family: Consolas, monospace; + --header-color: #303b55; + --warn-color: hsl(356, 92%, 60%); + --text-font-family: -apple-system, BlinkMacSystemFont, Helvetica, Arial, + sans-serif; +} + +html { + font-size: 20px; +} + +body { + font-family: var(--text-font-family); + font-size: 0.8rem; +} + +h1, +h2, +h3, +h4 { + color: var(--header-color); +} + +h4 { + font-size: 1rem; +} + +h3 { + font-size: 1.35rem; +} + +h2 { + font-size: 1.83rem; +} + +h3 span, +h4 span { + color: var(--warn-color); +} + +.frame { + background-color: var(--code-background-color); + border-radius: 0.2rem; + margin-bottom: 20px; +} + +.frame-name { + border-bottom: 1px solid var(--code-color-light); + padding: 10px 16px; +} + +.frame.collapsed .frame-name { + border-bottom: none; +} + +.frame-name span { + font-weight: 700; +} + +span.expander { + display: inline-block; + margin-right: 10px; + cursor: pointer; + transition: transform 0.33s ease-in-out; +} + +.frame.collapsed span.expander { + transform: rotate(-90deg); +} + +.frame-name span.breakable { + word-break: break-all; +} + +.code-snippet-wrapper { + height: auto; + overflow-y: hidden; + transition: height 0.33s ease-in-out; +} + +.frame.collapsed .code-snippet-wrapper { + height: 0; +} + +.code-snippet { + margin: 10px 16px; + border-spacing: 0 0; + color: var(--code-color); + font-family: var(--code-font-family); + font-size: 0.68rem; +} + +.code-snippet td { + padding: 0; + text-align: left; +} + +td.line_no { + color: var(--code-color-light); + min-width: 4ch; + padding-right: 20px; + text-align: right; + user-select: none; +} + +td.code_line { + width: 99%; +} + +tr.executed-line { + background-color: var(--code-background-color-dark); +} + +.cause-wrapper { + margin-top: 50px; +} diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/logging.py b/venv/lib/python3.11/site-packages/litestar/middleware/logging.py new file mode 100644 index 0000000..0094f10 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/logging.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Iterable + +from litestar.constants import ( + HTTP_RESPONSE_BODY, + HTTP_RESPONSE_START, +) +from litestar.data_extractors import ( + ConnectionDataExtractor, + RequestExtractorField, + ResponseDataExtractor, + ResponseExtractorField, +) +from litestar.enums import ScopeType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import encode_json +from litestar.utils.empty import value_or_default +from litestar.utils.scope import get_serializer_from_scope +from litestar.utils.scope.state import ScopeState + +__all__ = ("LoggingMiddleware", "LoggingMiddlewareConfig") + + +if TYPE_CHECKING: + from litestar.connection import Request + from litestar.types import ( + ASGIApp, + Logger, + Message, + Receive, + Scope, + Send, + Serializer, + ) + +try: + from structlog.types import BindableLogger + + structlog_installed = True +except ImportError: + BindableLogger = object # type: ignore[assignment, misc] + structlog_installed = False + + +class LoggingMiddleware(AbstractMiddleware): + """Logging middleware.""" + + __slots__ = ("config", "logger", "request_extractor", "response_extractor", "is_struct_logger") + + logger: Logger + + def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None: + """Initialize ``LoggingMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of LoggingMiddlewareConfig. + """ + super().__init__( + app=app, scopes={ScopeType.HTTP}, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key + ) + self.is_struct_logger = structlog_installed + self.config = config + + self.request_extractor = ConnectionDataExtractor( + extract_body="body" in self.config.request_log_fields, + extract_client="client" in self.config.request_log_fields, + extract_content_type="content_type" in self.config.request_log_fields, + extract_cookies="cookies" in self.config.request_log_fields, + extract_headers="headers" in self.config.request_log_fields, + extract_method="method" in self.config.request_log_fields, + extract_path="path" in self.config.request_log_fields, + extract_path_params="path_params" in self.config.request_log_fields, + extract_query="query" in self.config.request_log_fields, + extract_scheme="scheme" in self.config.request_log_fields, + obfuscate_cookies=self.config.request_cookies_to_obfuscate, + obfuscate_headers=self.config.request_headers_to_obfuscate, + parse_body=self.is_struct_logger, + parse_query=self.is_struct_logger, + skip_parse_malformed_body=True, + ) + self.response_extractor = ResponseDataExtractor( + extract_body="body" in self.config.response_log_fields, + extract_headers="headers" in self.config.response_log_fields, + extract_status_code="status_code" in self.config.response_log_fields, + obfuscate_cookies=self.config.response_cookies_to_obfuscate, + obfuscate_headers=self.config.response_headers_to_obfuscate, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if not hasattr(self, "logger"): + self.logger = scope["app"].get_logger(self.config.logger_name) + self.is_struct_logger = structlog_installed and repr(self.logger).startswith("<BoundLoggerLazyProxy") + + if self.config.response_log_fields: + send = self.create_send_wrapper(scope=scope, send=send) + + if self.config.request_log_fields: + await self.log_request(scope=scope, receive=receive) + + await self.app(scope, receive, send) + + async def log_request(self, scope: Scope, receive: Receive) -> None: + """Extract request data and log the message. + + Args: + scope: The ASGI connection scope. + receive: ASGI receive callable + + Returns: + None + """ + extracted_data = await self.extract_request_data(request=scope["app"].request_class(scope, receive)) + self.log_message(values=extracted_data) + + def log_response(self, scope: Scope) -> None: + """Extract the response data and log the message. + + Args: + scope: The ASGI connection scope. + + Returns: + None + """ + extracted_data = self.extract_response_data(scope=scope) + self.log_message(values=extracted_data) + + def log_message(self, values: dict[str, Any]) -> None: + """Log a message. + + Args: + values: Extract values to log. + + Returns: + None + """ + message = values.pop("message") + if self.is_struct_logger: + self.logger.info(message, **values) + else: + value_strings = [f"{key}={value}" for key, value in values.items()] + log_message = f"{message}: {', '.join(value_strings)}" + self.logger.info(log_message) + + def _serialize_value(self, serializer: Serializer | None, value: Any) -> Any: + if not self.is_struct_logger and isinstance(value, (dict, list, tuple, set)): + value = encode_json(value, serializer) + return value.decode("utf-8", errors="backslashreplace") if isinstance(value, bytes) else value + + async def extract_request_data(self, request: Request) -> dict[str, Any]: + """Create a dictionary of values for the message. + + Args: + request: A :class:`Request <litestar.connection.Request>` instance. + + Returns: + An dict. + """ + + data: dict[str, Any] = {"message": self.config.request_log_message} + serializer = get_serializer_from_scope(request.scope) + + extracted_data = await self.request_extractor.extract(connection=request, fields=self.config.request_log_fields) + + for key in self.config.request_log_fields: + data[key] = self._serialize_value(serializer, extracted_data.get(key)) + return data + + def extract_response_data(self, scope: Scope) -> dict[str, Any]: + """Extract data from the response. + + Args: + scope: The ASGI connection scope. + + Returns: + An dict. + """ + data: dict[str, Any] = {"message": self.config.response_log_message} + serializer = get_serializer_from_scope(scope) + connection_state = ScopeState.from_scope(scope) + extracted_data = self.response_extractor( + messages=( + connection_state.log_context.pop(HTTP_RESPONSE_START), + connection_state.log_context.pop(HTTP_RESPONSE_BODY), + ), + ) + response_body_compressed = value_or_default(connection_state.response_compressed, False) + for key in self.config.response_log_fields: + value: Any + value = extracted_data.get(key) + if key == "body" and response_body_compressed: + if self.config.include_compressed_body: + data[key] = value + continue + data[key] = self._serialize_value(serializer, value) + return data + + def create_send_wrapper(self, scope: Scope, send: Send) -> Send: + """Create a ``send`` wrapper, which handles logging response data. + + Args: + scope: The ASGI connection scope. + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + connection_state = ScopeState.from_scope(scope) + + async def send_wrapper(message: Message) -> None: + if message["type"] == HTTP_RESPONSE_START: + connection_state.log_context[HTTP_RESPONSE_START] = message + elif message["type"] == HTTP_RESPONSE_BODY: + connection_state.log_context[HTTP_RESPONSE_BODY] = message + self.log_response(scope=scope) + await send(message) + + return send_wrapper + + +@dataclass +class LoggingMiddlewareConfig: + """Configuration for ``LoggingMiddleware``""" + + exclude: str | list[str] | None = field(default=None) + """List of paths to exclude from logging.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable logging for a particular route.""" + include_compressed_body: bool = field(default=False) + """Include body of compressed response in middleware. If `"body"` not set in. + :attr:`response_log_fields <LoggingMiddlewareConfig.response_log_fields>` this config value is ignored. + """ + logger_name: str = field(default="litestar") + """Name of the logger to retrieve using `app.get_logger("<name>")`.""" + request_cookies_to_obfuscate: set[str] = field(default_factory=lambda: {"session"}) + """Request cookie keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + request_headers_to_obfuscate: set[str] = field(default_factory=lambda: {"Authorization", "X-API-KEY"}) + """Request header keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + response_cookies_to_obfuscate: set[str] = field(default_factory=lambda: {"session"}) + """Response cookie keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + response_headers_to_obfuscate: set[str] = field(default_factory=lambda: {"Authorization", "X-API-KEY"}) + """Response header keys to obfuscate. + + Obfuscated values are replaced with '*****'. + """ + request_log_message: str = field(default="HTTP Request") + """Log message to prepend when logging a request.""" + response_log_message: str = field(default="HTTP Response") + """Log message to prepend when logging a response.""" + request_log_fields: Iterable[RequestExtractorField] = field( + default=( + "path", + "method", + "content_type", + "headers", + "cookies", + "query", + "path_params", + "body", + ) + ) + """Fields to extract and log from the request. + + Notes: + - The order of fields in the iterable determines the order of the log message logged out. + Thus, re-arranging the log-message is as simple as changing the iterable. + - To turn off logging of requests, use and empty iterable. + """ + response_log_fields: Iterable[ResponseExtractorField] = field( + default=( + "status_code", + "cookies", + "headers", + "body", + ) + ) + """Fields to extract and log from the response. The order of fields in the iterable determines the order of the log + message logged out. + + Notes: + - The order of fields in the iterable determines the order of the log message logged out. + Thus, re-arranging the log-message is as simple as changing the iterable. + - To turn off logging of responses, use and empty iterable. + """ + middleware_class: type[LoggingMiddleware] = field(default=LoggingMiddleware) + """Middleware class to use. + + Should be a subclass of [litestar.middleware.LoggingMiddleware]. + """ + + def __post_init__(self) -> None: + """Override default Pydantic type conversion for iterables. + + Args: + value: An iterable + + Returns: + The `value` argument cast as a tuple. + """ + if not isinstance(self.response_log_fields, Iterable): + raise ImproperlyConfiguredException("response_log_fields must be a valid Iterable") + + if not isinstance(self.request_log_fields, Iterable): + raise ImproperlyConfiguredException("request_log_fields must be a valid Iterable") + + self.response_log_fields = tuple(self.response_log_fields) + self.request_log_fields = tuple(self.request_log_fields) + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from litestar import Litestar, Request, get + from litestar.logging import LoggingConfig + from litestar.middleware.logging import LoggingMiddlewareConfig + + logging_config = LoggingConfig() + + logging_middleware_config = LoggingMiddlewareConfig() + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar( + route_handlers=[my_handler], + logging_config=logging_config, + middleware=[logging_middleware_config.middleware], + ) + + Returns: + An instance of DefineMiddleware including ``self`` as the config kwarg value. + """ + return DefineMiddleware(self.middleware_class, config=self) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py new file mode 100644 index 0000000..cd767ba --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/rate_limit.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from time import time +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +from litestar.datastructures import MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import TooManyRequestsException +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import decode_json, encode_json +from litestar.utils import ensure_async_callable + +__all__ = ("CacheObject", "RateLimitConfig", "RateLimitMiddleware") + + +if TYPE_CHECKING: + from typing import Awaitable + + from litestar import Litestar + from litestar.connection import Request + from litestar.stores.base import Store + from litestar.types import ASGIApp, Message, Receive, Scope, Send, SyncOrAsyncUnion + + +DurationUnit = Literal["second", "minute", "hour", "day"] + +DURATION_VALUES: dict[DurationUnit, int] = {"second": 1, "minute": 60, "hour": 3600, "day": 86400} + + +@dataclass +class CacheObject: + """Representation of a cached object's metadata.""" + + __slots__ = ("history", "reset") + + history: list[int] + reset: int + + +class RateLimitMiddleware(AbstractMiddleware): + """Rate-limiting middleware.""" + + __slots__ = ("app", "check_throttle_handler", "max_requests", "unit", "request_quota", "config") + + def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None: + """Initialize ``RateLimitMiddleware``. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of RateLimitConfig. + """ + super().__init__( + app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP} + ) + self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler) + self.config = config + self.max_requests: int = config.rate_limit[1] + self.unit: DurationUnit = config.rate_limit[0] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + app = scope["app"] + request: Request[Any, Any, Any] = app.request_class(scope) + store = self.config.get_store_from_app(app) + if await self.should_check_request(request=request): + key = self.cache_key_from_request(request=request) + cache_object = await self.retrieve_cached_history(key, store) + if len(cache_object.history) >= self.max_requests: + raise TooManyRequestsException( + headers=self.create_response_headers(cache_object=cache_object) + if self.config.set_rate_limit_headers + else None + ) + await self.set_cached_history(key=key, cache_object=cache_object, store=store) + if self.config.set_rate_limit_headers: + send = self.create_send_wrapper(send=send, cache_object=cache_object) + + await self.app(scope, receive, send) # pyright: ignore + + def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send: + """Create a ``send`` function that wraps the original send to inject response headers. + + Args: + send: The ASGI send function. + cache_object: A StorageObject instance. + + Returns: + Send wrapper callable. + """ + + async def send_wrapper(message: Message) -> None: + """Wrap the ASGI ``Send`` callable. + + Args: + message: An ASGI ``Message`` + + Returns: + None + """ + if message["type"] == "http.response.start": + message.setdefault("headers", []) + headers = MutableScopeHeaders(message) + for key, value in self.create_response_headers(cache_object=cache_object).items(): + headers.add(key, value) + await send(message) + + return send_wrapper + + def cache_key_from_request(self, request: Request[Any, Any, Any]) -> str: + """Get a cache-key from a ``Request`` + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + A cache key. + """ + host = request.client.host if request.client else "anonymous" + identifier = request.headers.get("X-Forwarded-For") or request.headers.get("X-Real-IP") or host + route_handler = request.scope["route_handler"] + if getattr(route_handler, "is_mount", False): + identifier += "::mount" + + if getattr(route_handler, "is_static", False): + identifier += "::static" + + return f"{type(self).__name__}::{identifier}" + + async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject: + """Retrieve a list of time stamps for the given duration unit. + + Args: + key: Cache key. + store: A :class:`Store <.stores.base.Store>` + + Returns: + An :class:`CacheObject`. + """ + duration = DURATION_VALUES[self.unit] + now = int(time()) + cached_string = await store.get(key) + if cached_string: + cache_object = CacheObject(**decode_json(value=cached_string)) + if cache_object.reset <= now: + return CacheObject(history=[], reset=now + duration) + + while cache_object.history and cache_object.history[-1] <= now - duration: + cache_object.history.pop() + return cache_object + + return CacheObject(history=[], reset=now + duration) + + async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None: + """Store history extended with the current timestamp in cache. + + Args: + key: Cache key. + cache_object: A :class:`CacheObject`. + store: A :class:`Store <.stores.base.Store>` + + Returns: + None + """ + cache_object.history = [int(time()), *cache_object.history] + await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit]) + + async def should_check_request(self, request: Request[Any, Any, Any]) -> bool: + """Return a boolean indicating if a request should be checked for rate limiting. + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + Boolean dictating whether the request should be checked for rate-limiting. + """ + if self.check_throttle_handler: + return await self.check_throttle_handler(request) + return True + + def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]: + """Create ratelimit response headers. + + Notes: + * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>_` + + Args: + cache_object:A :class:`CacheObject`. + + Returns: + A dict of http headers. + """ + remaining_requests = str( + len(cache_object.history) - self.max_requests if len(cache_object.history) <= self.max_requests else 0 + ) + + return { + self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}", + self.config.rate_limit_limit_header_key: str(self.max_requests), + self.config.rate_limit_remaining_header_key: remaining_requests, + self.config.rate_limit_reset_header_key: str(int(time()) - cache_object.reset), + } + + +@dataclass +class RateLimitConfig: + """Configuration for ``RateLimitMiddleware``""" + + rate_limit: tuple[DurationUnit, int] + """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5).""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the rate limiting middleware.""" + exclude_opt_key: str | None = field(default=None) + """An identifier to use on routes to disable rate limiting for a particular route.""" + check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None) + """Handler callable that receives the request instance, returning a boolean dictating whether or not the request + should be checked for rate limiting. + """ + middleware_class: type[RateLimitMiddleware] = field(default=RateLimitMiddleware) + """The middleware class to use.""" + set_rate_limit_headers: bool = field(default=True) + """Boolean dictating whether to set the rate limit headers on the response.""" + rate_limit_policy_header_key: str = field(default="RateLimit-Policy") + """Key to use for the rate limit policy header.""" + rate_limit_remaining_header_key: str = field(default="RateLimit-Remaining") + """Key to use for the rate limit remaining header.""" + rate_limit_reset_header_key: str = field(default="RateLimit-Reset") + """Key to use for the rate limit reset header.""" + rate_limit_limit_header_key: str = field(default="RateLimit-Limit") + """Key to use for the rate limit limit header.""" + store: str = "rate_limit" + """Name of the :class:`Store <.stores.base.Store>` to use""" + + def __post_init__(self) -> None: + if self.check_throttle_handler: + self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore[arg-type] + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from litestar import Litestar, Request, get + from litestar.middleware.rate_limit import RateLimitConfig + + # limit to 10 requests per minute, excluding the schema path + throttle_config = RateLimitConfig(rate_limit=("minute", 10), exclude=["/schema"]) + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware]) + + Returns: + An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the + config kwarg value. + """ + return DefineMiddleware(self.middleware_class, config=self) + + def get_store_from_app(self, app: Litestar) -> Store: + """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance.""" + return app.stores.get(self.store) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/response_cache.py b/venv/lib/python3.11/site-packages/litestar/middleware/response_cache.py new file mode 100644 index 0000000..62dcde6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/response_cache.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from msgspec.msgpack import encode as encode_msgpack + +from litestar import Request +from litestar.constants import HTTP_RESPONSE_BODY, HTTP_RESPONSE_START +from litestar.enums import ScopeType +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +from .base import AbstractMiddleware + +if TYPE_CHECKING: + from litestar.config.response_cache import ResponseCacheConfig + from litestar.handlers import HTTPRouteHandler + from litestar.types import ASGIApp, HTTPScope, Message, Receive, Scope, Send + +__all__ = ["ResponseCacheMiddleware"] + + +class ResponseCacheMiddleware(AbstractMiddleware): + def __init__(self, app: ASGIApp, config: ResponseCacheConfig) -> None: + self.config = config + super().__init__(app=app, scopes={ScopeType.HTTP}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + route_handler = cast("HTTPRouteHandler", scope["route_handler"]) + + expires_in: int | None = None + if route_handler.cache is True: + expires_in = self.config.default_expiration + elif route_handler.cache is not False and isinstance(route_handler.cache, int): + expires_in = route_handler.cache + + connection_state = ScopeState.from_scope(scope) + + messages: list[Message] = [] + + async def wrapped_send(message: Message) -> None: + if not value_or_default(connection_state.is_cached, False): + if message["type"] == HTTP_RESPONSE_START: + do_cache = connection_state.do_cache = self.config.cache_response_filter( + cast("HTTPScope", scope), message["status"] + ) + if do_cache: + messages.append(message) + elif value_or_default(connection_state.do_cache, False): + messages.append(message) + + if messages and message["type"] == HTTP_RESPONSE_BODY and not message["more_body"]: + key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope)) + store = self.config.get_store_from_app(scope["app"]) + await store.set(key, encode_msgpack(messages), expires_in=expires_in) + await send(message) + + await self.app(scope, receive, wrapped_send) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__init__.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/__init__.py new file mode 100644 index 0000000..1ca9c17 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__init__.py @@ -0,0 +1,3 @@ +from .base import SessionMiddleware + +__all__ = ("SessionMiddleware",) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8748ce3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..68a8b9c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/client_side.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/client_side.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..692f54c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/client_side.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/server_side.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/server_side.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bd2373c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/__pycache__/server_side.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/base.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/base.py new file mode 100644 index 0000000..a823848 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/base.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Literal, + TypeVar, + cast, +) + +from litestar.connection import ASGIConnection +from litestar.enums import ScopeType +from litestar.middleware.base import AbstractMiddleware, DefineMiddleware +from litestar.serialization import decode_json, encode_json +from litestar.utils import get_serializer_from_scope + +__all__ = ("BaseBackendConfig", "BaseSessionBackend", "SessionMiddleware") + + +if TYPE_CHECKING: + from litestar.types import ASGIApp, Message, Receive, Scope, Scopes, ScopeSession, Send + +ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +ConfigT = TypeVar("ConfigT", bound="BaseBackendConfig") +BaseSessionBackendT = TypeVar("BaseSessionBackendT", bound="BaseSessionBackend") + + +class BaseBackendConfig(ABC, Generic[BaseSessionBackendT]): # pyright: ignore + """Configuration for Session middleware backends.""" + + _backend_class: type[BaseSessionBackendT] # pyright: ignore + + key: str + """Key to use for the cookie inside the header, e.g. ``session=<data>`` where ``session`` is the cookie key and + ``<data>`` is the session data. + + Notes: + - If a session cookie exceeds 4KB in size it is split. In this case the key will be of the format + ``session-{segment number}``. + + """ + max_age: int + """Maximal age of the cookie before its invalidated.""" + scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET} + """Scopes for the middleware - options are ``http`` and ``websocket`` with the default being both""" + path: str + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``'/'``. + """ + domain: str | None + """Domain for which the cookie is valid.""" + secure: bool + """Https is required for the cookie.""" + httponly: bool + """Forbids javascript to access the cookie via 'Document.cookie'.""" + samesite: Literal["lax", "strict", "none"] + """Controls whether or not a cookie is sent with cross-site requests. + + Defaults to ``lax``. + """ + exclude: str | list[str] | None + """A pattern or list of patterns to skip in the session middleware.""" + exclude_opt_key: str + """An identifier to use on routes to disable the session middleware for a particular route.""" + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from os import urandom + + from litestar import Litestar, Request, get + from litestar.middleware.sessions.cookie_backend import CookieBackendConfig + + session_config = CookieBackendConfig(secret=urandom(16)) + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar(route_handlers=[my_handler], middleware=[session_config.middleware]) + + + Returns: + An instance of DefineMiddleware including ``self`` as the config kwarg value. + """ + return DefineMiddleware(SessionMiddleware, backend=self._backend_class(config=self)) + + +class BaseSessionBackend(ABC, Generic[ConfigT]): + """Abstract session backend defining the interface between a storage mechanism and the application + :class:`SessionMiddleware`. + + This serves as the base class for all client- and server-side backends + """ + + __slots__ = ("config",) + + def __init__(self, config: ConfigT) -> None: + """Initialize ``BaseSessionBackend`` + + Args: + config: A instance of a subclass of ``BaseBackendConfig`` + """ + self.config = config + + @staticmethod + def serialize_data(data: ScopeSession, scope: Scope | None = None) -> bytes: + """Serialize data into bytes for storage in the backend. + + Args: + data: Session data of the current scope. + scope: A scope, if applicable, from which to extract a serializer. + + Notes: + - The serializer will be extracted from ``scope`` or fall back to + :func:`default_serializer <.serialization.default_serializer>` + + Returns: + ``data`` serialized as bytes. + """ + serializer = get_serializer_from_scope(scope) if scope else None + return encode_json(data, serializer) + + @staticmethod + def deserialize_data(data: Any) -> dict[str, Any]: + """Deserialize data into a dictionary for use in the application scope. + + Args: + data: Data to be deserialized + + Returns: + Deserialized data as a dictionary + """ + return cast("dict[str, Any]", decode_json(value=data)) + + @abstractmethod + def get_session_id(self, connection: ASGIConnection) -> str | None: + """Try to fetch session id from connection ScopeState. If one does not exist, generate one. + + Args: + connection: Originating ASGIConnection containing the scope + + Returns: + Session id str or None if the concept of a session id does not apply. + """ + + @abstractmethod + async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: + """Store the necessary information in the outgoing ``Message`` + + Args: + scope_session: Current session to store + message: Outgoing send-message + connection: Originating ASGIConnection containing the scope + + Returns: + None + """ + + @abstractmethod + async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: + """Load session data from a connection and return it as a dictionary to be used in the current application + scope. + + Args: + connection: An ASGIConnection instance + + Returns: + The session data + + Notes: + - This should not modify the connection's scope. The data returned by this + method will be stored in the application scope by the middleware + + """ + + +class SessionMiddleware(AbstractMiddleware, Generic[BaseSessionBackendT]): + """Litestar session middleware for storing session data.""" + + def __init__(self, app: ASGIApp, backend: BaseSessionBackendT) -> None: + """Initialize ``SessionMiddleware`` + + Args: + app: An ASGI application + backend: A :class:`BaseSessionBackend` instance used to store and retrieve session data + """ + + super().__init__( + app=app, + exclude=backend.config.exclude, + exclude_opt_key=backend.config.exclude_opt_key, + scopes=backend.config.scopes, + ) + self.backend = backend + + def create_send_wrapper(self, connection: ASGIConnection) -> Callable[[Message], Awaitable[None]]: + """Create a wrapper for the ASGI send function, which handles setting the cookies on the outgoing response. + + Args: + connection: ASGIConnection + + Returns: + None + """ + + async def wrapped_send(message: Message) -> None: + """Wrap the ``send`` function. + + Declared in local scope to make use of closure values. + + Args: + message: An ASGI message. + + Returns: + None + """ + if message["type"] != "http.response.start": + await connection.send(message) + return + + scope_session = connection.scope.get("session") + + await self.backend.store_in_message(scope_session, message, connection) + await connection.send(message) + + return wrapped_send + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI-callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + connection = ASGIConnection[Any, Any, Any, Any](scope, receive=receive, send=send) + scope["session"] = await self.backend.load_from_connection(connection) + connection._connection_state.session_id = self.backend.get_session_id(connection) # pyright: ignore [reportGeneralTypeIssues] + + await self.app(scope, receive, self.create_send_wrapper(connection)) diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/client_side.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/client_side.py new file mode 100644 index 0000000..f709410 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/client_side.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import binascii +import contextlib +import re +import time +from base64 import b64decode, b64encode +from dataclasses import dataclass, field +from os import urandom +from typing import TYPE_CHECKING, Any, Literal + +from litestar.datastructures import MutableScopeHeaders +from litestar.datastructures.cookie import Cookie +from litestar.enums import ScopeType +from litestar.exceptions import ( + ImproperlyConfiguredException, + MissingDependencyException, +) +from litestar.serialization import decode_json, encode_json +from litestar.types import Empty, Scopes +from litestar.utils.dataclass import extract_dataclass_items + +from .base import ONE_DAY_IN_SECONDS, BaseBackendConfig, BaseSessionBackend + +__all__ = ("ClientSideSessionBackend", "CookieBackendConfig") + + +try: + from cryptography.exceptions import InvalidTag + from cryptography.hazmat.primitives.ciphers.aead import AESGCM +except ImportError as e: + raise MissingDependencyException("cryptography") from e + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.types import Message, Scope, ScopeSession + +NONCE_SIZE = 12 +CHUNK_SIZE = 4096 - 64 +AAD = b"additional_authenticated_data=" + + +class ClientSideSessionBackend(BaseSessionBackend["CookieBackendConfig"]): + """Cookie backend for SessionMiddleware.""" + + __slots__ = ("aesgcm", "cookie_re") + + def __init__(self, config: CookieBackendConfig) -> None: + """Initialize ``ClientSideSessionBackend``. + + Args: + config: SessionCookieConfig instance. + """ + super().__init__(config) + self.aesgcm = AESGCM(config.secret) + self.cookie_re = re.compile(rf"{self.config.key}(?:-\d+)?") + + def dump_data(self, data: Any, scope: Scope | None = None) -> list[bytes]: + """Given serializable data, including pydantic models and numpy types, dump it into a bytes string, encrypt, + encode and split it into chunks of the desirable size. + + Args: + data: Data to serialize, encrypt, encode and chunk. + scope: The ASGI connection scope. + + Notes: + - The returned list is composed of a chunks of a single base64 encoded + string that is encrypted using AES-CGM. + + Returns: + List of encoded bytes string of a maximum length equal to the ``CHUNK_SIZE`` constant. + """ + serialized = self.serialize_data(data, scope) + associated_data = encode_json({"expires_at": round(time.time()) + self.config.max_age}) + nonce = urandom(NONCE_SIZE) + encrypted = self.aesgcm.encrypt(nonce, serialized, associated_data=associated_data) + encoded = b64encode(nonce + encrypted + AAD + associated_data) + return [encoded[i : i + CHUNK_SIZE] for i in range(0, len(encoded), CHUNK_SIZE)] + + def load_data(self, data: list[bytes]) -> dict[str, Any]: + """Given a list of strings, decodes them into the session object. + + Args: + data: A list of strings derived from the request's session cookie(s). + + Returns: + A deserialized session value. + """ + decoded = b64decode(b"".join(data)) + nonce = decoded[:NONCE_SIZE] + aad_starts_from = decoded.find(AAD) + associated_data = decoded[aad_starts_from:].replace(AAD, b"") if aad_starts_from != -1 else None + if associated_data and decode_json(value=associated_data)["expires_at"] > round(time.time()): + encrypted_session = decoded[NONCE_SIZE:aad_starts_from] + decrypted = self.aesgcm.decrypt(nonce, encrypted_session, associated_data=associated_data) + return self.deserialize_data(decrypted) + return {} + + def get_cookie_keys(self, connection: ASGIConnection) -> list[str]: + """Return a list of cookie-keys from the connection if they match the session-cookie pattern. + + Args: + connection: An ASGIConnection instance + + Returns: + A list of session-cookie keys + """ + return sorted(key for key in connection.cookies if self.cookie_re.fullmatch(key)) + + def _create_session_cookies(self, data: list[bytes], cookie_params: dict[str, Any] | None = None) -> list[Cookie]: + """Create a list of cookies containing the session data. + If the data is split into multiple cookies, the key will be of the format ``session-{segment number}``, + however if only one cookie is needed, the key will be ``session``. + """ + if cookie_params is None: + cookie_params = dict( + extract_dataclass_items( + self.config, + exclude_none=True, + include={f for f in Cookie.__dict__ if f not in ("key", "secret")}, + ) + ) + + if len(data) == 1: + return [ + Cookie( + value=data[0].decode("utf-8"), + key=self.config.key, + **cookie_params, + ) + ] + + return [ + Cookie( + value=datum.decode("utf-8"), + key=f"{self.config.key}-{i}", + **cookie_params, + ) + for i, datum in enumerate(data) + ] + + async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: + """Store data from ``scope_session`` in ``Message`` in the form of cookies. If the contents of ``scope_session`` + are too large to fit a single cookie, it will be split across several cookies, following the naming scheme of + ``<cookie key>-<n>``. If the session is empty or shrinks, cookies will be cleared by setting their value to + ``"null"`` + + Args: + scope_session: Current session to store + message: Outgoing send-message + connection: Originating ASGIConnection containing the scope + + Returns: + None + """ + + scope = connection.scope + headers = MutableScopeHeaders.from_message(message) + cookie_keys = self.get_cookie_keys(connection) + + if scope_session and scope_session is not Empty: + data = self.dump_data(scope_session, scope=scope) + cookie_params = dict( + extract_dataclass_items( + self.config, + exclude_none=True, + include={f for f in Cookie.__dict__ if f not in ("key", "secret")}, + ) + ) + for cookie in self._create_session_cookies(data, cookie_params): + headers.add("Set-Cookie", cookie.to_header(header="")) + # Cookies with the same key overwrite the earlier cookie with that key. To expire earlier session + # cookies, first check how many session cookies will not be overwritten in this upcoming response. + # If leftover cookies are greater than or equal to 1, that means older session cookies have to be + # expired and their names are in cookie_keys. + cookies_to_clear = cookie_keys[len(data) :] if len(cookie_keys) - len(data) > 0 else [] + else: + cookies_to_clear = cookie_keys + + for cookie_key in cookies_to_clear: + cookie_params = dict( + extract_dataclass_items( + self.config, + exclude_none=True, + include={f for f in Cookie.__dict__ if f not in ("key", "secret", "max_age")}, + ) + ) + headers.add( + "Set-Cookie", + Cookie(value="null", key=cookie_key, expires=0, **cookie_params).to_header(header=""), + ) + + async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: + """Load session data from a connection's session-cookies and return it as a dictionary. + + Args: + connection: Originating ASGIConnection + + Returns: + The session data + """ + if cookie_keys := self.get_cookie_keys(connection): + data = [connection.cookies[key].encode("utf-8") for key in cookie_keys] + # If these exceptions occur, the session must remain empty so do nothing. + with contextlib.suppress(InvalidTag, binascii.Error): + return self.load_data(data) + return {} + + def get_session_id(self, connection: ASGIConnection) -> str | None: + return None + + +@dataclass +class CookieBackendConfig(BaseBackendConfig[ClientSideSessionBackend]): # pyright: ignore + """Configuration for [SessionMiddleware] middleware.""" + + _backend_class = ClientSideSessionBackend + + secret: bytes + """A secret key to use for generating an encryption key. + + Must have a length of 16 (128 bits), 24 (192 bits) or 32 (256 bits) characters. + """ + key: str = field(default="session") + """Key to use for the cookie inside the header, e.g. ``session=<data>`` where ``session`` is the cookie key and + ``<data>`` is the session data. + + Notes: + - If a session cookie exceeds 4KB in size it is split. In this case the key will be of the format + ``session-{segment number}``. + + """ + max_age: int = field(default=ONE_DAY_IN_SECONDS * 14) + """Maximal age of the cookie before its invalidated.""" + scopes: Scopes = field(default_factory=lambda: {ScopeType.HTTP, ScopeType.WEBSOCKET}) + """Scopes for the middleware - options are ``http`` and ``websocket`` with the default being both""" + path: str = field(default="/") + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``'/'``. + """ + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool = field(default=False) + """Https is required for the cookie.""" + httponly: bool = field(default=True) + """Forbids javascript to access the cookie via 'Document.cookie'.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. + + Defaults to ``lax``. + """ + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the session middleware.""" + exclude_opt_key: str = field(default="skip_session") + """An identifier to use on routes to disable the session middleware for a particular route.""" + + def __post_init__(self) -> None: + if len(self.key) < 1 or len(self.key) > 256: + raise ImproperlyConfiguredException("key must be a string with a length between 1-256") + if self.max_age < 1: + raise ImproperlyConfiguredException("max_age must be greater than 0") + if len(self.secret) not in {16, 24, 32}: + raise ImproperlyConfiguredException("secret length must be 16 (128 bit), 24 (192 bit) or 32 (256 bit)") diff --git a/venv/lib/python3.11/site-packages/litestar/middleware/session/server_side.py b/venv/lib/python3.11/site-packages/litestar/middleware/session/server_side.py new file mode 100644 index 0000000..91708ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/middleware/session/server_side.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import secrets +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from litestar.datastructures import Cookie, MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.middleware.session.base import ONE_DAY_IN_SECONDS, BaseBackendConfig, BaseSessionBackend +from litestar.types import Empty, Message, Scopes, ScopeSession +from litestar.utils.dataclass import extract_dataclass_items + +__all__ = ("ServerSideSessionBackend", "ServerSideSessionConfig") + + +if TYPE_CHECKING: + from litestar import Litestar + from litestar.connection import ASGIConnection + from litestar.stores.base import Store + + +class ServerSideSessionBackend(BaseSessionBackend["ServerSideSessionConfig"]): + """Base class for server-side backends. + + Implements :class:`BaseSessionBackend` and defines and interface which subclasses can + implement to facilitate the storage of session data. + """ + + def __init__(self, config: ServerSideSessionConfig) -> None: + """Initialize ``ServerSideSessionBackend`` + + Args: + config: A subclass of ``ServerSideSessionConfig`` + """ + super().__init__(config=config) + + async def get(self, session_id: str, store: Store) -> bytes | None: + """Retrieve data associated with ``session_id``. + + Args: + session_id: The session-ID + store: Store to retrieve the session data from + + Returns: + The session data, if existing, otherwise ``None``. + """ + max_age = int(self.config.max_age) if self.config.max_age is not None else None + return await store.get(session_id, renew_for=max_age if self.config.renew_on_access else None) + + async def set(self, session_id: str, data: bytes, store: Store) -> None: + """Store ``data`` under the ``session_id`` for later retrieval. + + If there is already data associated with ``session_id``, replace + it with ``data`` and reset its expiry time + + Args: + session_id: The session-ID + data: Serialized session data + store: Store to save the session data in + + Returns: + None + """ + expires_in = int(self.config.max_age) if self.config.max_age is not None else None + await store.set(session_id, data, expires_in=expires_in) + + async def delete(self, session_id: str, store: Store) -> None: + """Delete the data associated with ``session_id``. Fails silently if no such session-ID exists. + + Args: + session_id: The session-ID + store: Store to delete the session data from + + Returns: + None + """ + await store.delete(session_id) + + def get_session_id(self, connection: ASGIConnection) -> str: + """Try to fetch session id from the connection. If one does not exist, generate one. + + If a session ID already exists in the cookies, it is returned. + If there is no ID in the cookies but one in the connection state, then the session exists but has not yet + been returned to the user. + Otherwise, a new session must be created. + + Args: + connection: Originating ASGIConnection containing the scope + Returns: + Session id str or None if the concept of a session id does not apply. + """ + session_id = connection.cookies.get(self.config.key) + if not session_id or session_id == "null": + session_id = connection.get_session_id() + if not session_id: + session_id = self.generate_session_id() + return session_id + + def generate_session_id(self) -> str: + """Generate a new session-ID, with + n=:attr:`session_id_bytes <ServerSideSessionConfig.session_id_bytes>` random bytes. + + Returns: + A session-ID + """ + return secrets.token_hex(self.config.session_id_bytes) + + async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: + """Store the necessary information in the outgoing ``Message`` by setting a cookie containing the session-ID. + + If the session is empty, a null-cookie will be set. Otherwise, the serialised + data will be stored using :meth:`set <ServerSideSessionBackend.set>`, under the current session-id. If no session-ID + exists, a new ID will be generated using :meth:`generate_session_id <ServerSideSessionBackend.generate_session_id>`. + + Args: + scope_session: Current session to store + message: Outgoing send-message + connection: Originating ASGIConnection containing the scope + + Returns: + None + """ + scope = connection.scope + store = self.config.get_store_from_app(scope["app"]) + headers = MutableScopeHeaders.from_message(message) + session_id = self.get_session_id(connection) + + cookie_params = dict(extract_dataclass_items(self.config, exclude_none=True, include=Cookie.__dict__.keys())) + + if scope_session is Empty: + await self.delete(session_id, store=store) + headers.add( + "Set-Cookie", + Cookie(value="null", key=self.config.key, expires=0, **cookie_params).to_header(header=""), + ) + else: + serialised_data = self.serialize_data(scope_session, scope) + await self.set(session_id=session_id, data=serialised_data, store=store) + headers.add( + "Set-Cookie", Cookie(value=session_id, key=self.config.key, **cookie_params).to_header(header="") + ) + + async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: + """Load session data from a connection and return it as a dictionary to be used in the current application + scope. + + The session-ID will be gathered from a cookie with the key set in + :attr:`BaseBackendConfig.key`. If a cookie is found, its value will be used as the session-ID and data associated + with this ID will be loaded using :meth:`get <ServerSideSessionBackend.get>`. + If no cookie was found or no data was loaded from the store, this will return an + empty dictionary. + + Args: + connection: An ASGIConnection instance + + Returns: + The current session data + """ + if session_id := connection.cookies.get(self.config.key): + store = self.config.get_store_from_app(connection.scope["app"]) + data = await self.get(session_id, store=store) + if data is not None: + return self.deserialize_data(data) + return {} + + +@dataclass +class ServerSideSessionConfig(BaseBackendConfig[ServerSideSessionBackend]): # pyright: ignore + """Base configuration for server side backends.""" + + _backend_class = ServerSideSessionBackend + + session_id_bytes: int = field(default=32) + """Number of bytes used to generate a random session-ID.""" + renew_on_access: bool = field(default=False) + """Renew expiry times of sessions when they're being accessed""" + key: str = field(default="session") + """Key to use for the cookie inside the header, e.g. ``session=<data>`` where ``session`` is the cookie key and + ``<data>`` is the session data. + + Notes: + - If a session cookie exceeds 4KB in size it is split. In this case the key will be of the format + ``session-{segment number}``. + + """ + max_age: int = field(default=ONE_DAY_IN_SECONDS * 14) + """Maximal age of the cookie before its invalidated.""" + scopes: Scopes = field(default_factory=lambda: {ScopeType.HTTP, ScopeType.WEBSOCKET}) + """Scopes for the middleware - options are ``http`` and ``websocket`` with the default being both""" + path: str = field(default="/") + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``'/'``. + """ + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool = field(default=False) + """Https is required for the cookie.""" + httponly: bool = field(default=True) + """Forbids javascript to access the cookie via 'Document.cookie'.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. Defaults to ``lax``.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the session middleware.""" + exclude_opt_key: str = field(default="skip_session") + """An identifier to use on routes to disable the session middleware for a particular route.""" + store: str = "sessions" + """Name of the :class:`Store <.stores.base.Store>` to use""" + + def __post_init__(self) -> None: + if len(self.key) < 1 or len(self.key) > 256: + raise ImproperlyConfiguredException("key must be a string with a length between 1-256") + if self.max_age < 1: + raise ImproperlyConfiguredException("max_age must be greater than 0") + + def get_store_from_app(self, app: Litestar) -> Store: + """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance""" + return app.stores.get(self.store) diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/__init__.py b/venv/lib/python3.11/site-packages/litestar/openapi/__init__.py new file mode 100644 index 0000000..8cc83d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/__init__.py @@ -0,0 +1,5 @@ +from .config import OpenAPIConfig +from .controller import OpenAPIController +from .datastructures import ResponseSpec + +__all__ = ("OpenAPIController", "OpenAPIConfig", "ResponseSpec") diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..510eb6f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..16411a1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/controller.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/controller.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f153bae --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/controller.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/datastructures.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/datastructures.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..042a57f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/__pycache__/datastructures.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/config.py b/venv/lib/python3.11/site-packages/litestar/openapi/config.py new file mode 100644 index 0000000..c935693 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/config.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING, Literal + +from litestar._openapi.utils import default_operation_id_creator +from litestar.openapi.controller import OpenAPIController +from litestar.openapi.spec import ( + Components, + Contact, + ExternalDocumentation, + Info, + License, + OpenAPI, + PathItem, + Reference, + SecurityRequirement, + Server, + Tag, +) +from litestar.utils.path import normalize_path + +__all__ = ("OpenAPIConfig",) + + +if TYPE_CHECKING: + from litestar.types.callable_types import OperationIDCreator + + +@dataclass +class OpenAPIConfig: + """Configuration for OpenAPI. + + To enable OpenAPI schema generation and serving, pass an instance of this class to the + :class:`Litestar <.app.Litestar>` constructor using the ``openapi_config`` kwargs. + """ + + title: str + """Title of API documentation.""" + version: str + """API version, e.g. '1.0.0'.""" + + create_examples: bool = field(default=False) + """Generate examples using the polyfactory library.""" + random_seed: int = 10 + """The random seed used when creating the examples to ensure deterministic generation of examples.""" + openapi_controller: type[OpenAPIController] = field(default_factory=lambda: OpenAPIController) + """Controller for generating OpenAPI routes. + + Must be subclass of :class:`OpenAPIController <litestar.openapi.controller.OpenAPIController>`. + """ + contact: Contact | None = field(default=None) + """API contact information, should be an :class:`Contact <litestar.openapi.spec.contact.Contact>` instance.""" + description: str | None = field(default=None) + """API description.""" + external_docs: ExternalDocumentation | None = field(default=None) + """Links to external documentation. + + Should be an instance of :class:`ExternalDocumentation <litestar.openapi.spec.external_documentation.ExternalDocumentation>`. + """ + license: License | None = field(default=None) + """API Licensing information. + + Should be an instance of :class:`License <litestar.openapi.spec.license.License>`. + """ + security: list[SecurityRequirement] | None = field(default=None) + """API Security requirements information. + + Should be an instance of + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>`. + """ + components: Components | list[Components] = field(default_factory=Components) + """API Components information. + + Should be an instance of :class:`Components <litestar.openapi.spec.components.Components>` or a list thereof. + """ + servers: list[Server] = field(default_factory=lambda: [Server(url="/")]) + """A list of :class:`Server <litestar.openapi.spec.server.Server>` instances.""" + summary: str | None = field(default=None) + """A summary text.""" + tags: list[Tag] | None = field(default=None) + """A list of :class:`Tag <litestar.openapi.spec.tag.Tag>` instances.""" + terms_of_service: str | None = field(default=None) + """URL to page that contains terms of service.""" + use_handler_docstrings: bool = field(default=False) + """Draw operation description from route handler docstring if not otherwise provided.""" + webhooks: dict[str, PathItem | Reference] | None = field(default=None) + """A mapping of key to either :class:`PathItem <litestar.openapi.spec.path_item.PathItem>` or. + + :class:`Reference <litestar.openapi.spec.reference.Reference>` objects. + """ + root_schema_site: Literal["redoc", "swagger", "elements", "rapidoc"] = "redoc" + """The static schema generator to use for the "root" path of `/schema/`.""" + enabled_endpoints: set[str] = field( + default_factory=lambda: { + "redoc", + "swagger", + "elements", + "rapidoc", + "openapi.json", + "openapi.yaml", + "openapi.yml", + "oauth2-redirect.html", + } + ) + """A set of the enabled documentation sites and schema download endpoints.""" + operation_id_creator: OperationIDCreator = default_operation_id_creator + """A callable that generates unique operation ids""" + path: str | None = field(default=None) + """Base path for the OpenAPI documentation endpoints.""" + + def __post_init__(self) -> None: + if self.path: + self.path = normalize_path(self.path) + self.openapi_controller = type("OpenAPIController", (self.openapi_controller,), {"path": self.path}) + + def to_openapi_schema(self) -> OpenAPI: + """Return an ``OpenAPI`` instance from the values stored in ``self``. + + Returns: + An instance of :class:`OpenAPI <litestar.openapi.spec.open_api.OpenAPI>`. + """ + + if isinstance(self.components, list): + merged_components = Components() + for components in self.components: + for key in (f.name for f in fields(components)): + if value := getattr(components, key, None): + merged_value_dict = getattr(merged_components, key, {}) or {} + merged_value_dict.update(value) + setattr(merged_components, key, merged_value_dict) + + self.components = merged_components + + return OpenAPI( + external_docs=self.external_docs, + security=self.security, + components=deepcopy(self.components), # deepcopy prevents mutation of the config's components + servers=self.servers, + tags=self.tags, + webhooks=self.webhooks, + info=Info( + title=self.title, + version=self.version, + description=self.description, + contact=self.contact, + license=self.license, + summary=self.summary, + terms_of_service=self.terms_of_service, + ), + paths={}, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/controller.py b/venv/lib/python3.11/site-packages/litestar/openapi/controller.py new file mode 100644 index 0000000..ac03d4c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/controller.py @@ -0,0 +1,604 @@ +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Any, Callable, Final, Literal + +from yaml import dump as dump_yaml + +from litestar.constants import OPENAPI_NOT_INITIALIZED +from litestar.controller import Controller +from litestar.enums import MediaType, OpenAPIMediaType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers import get +from litestar.response.base import ASGIResponse +from litestar.serialization import encode_json +from litestar.serialization.msgspec_hooks import decode_json +from litestar.status_codes import HTTP_404_NOT_FOUND + +__all__ = ("OpenAPIController",) + + +if TYPE_CHECKING: + from litestar.connection.request import Request + from litestar.openapi.spec.open_api import OpenAPI + +_OPENAPI_JSON_ROUTER_NAME: Final = "__litestar_openapi_json" + + +class OpenAPIController(Controller): + """Controller for OpenAPI endpoints.""" + + path: str = "/schema" + """Base path for the OpenAPI documentation endpoints.""" + style: str = "body { margin: 0; padding: 0 }" + """Base styling of the html body.""" + redoc_version: str = "next" + """Redoc version to download from the CDN.""" + swagger_ui_version: str = "5.1.3" + """SwaggerUI version to download from the CDN.""" + stoplight_elements_version: str = "7.7.18" + """StopLight Elements version to download from the CDN.""" + rapidoc_version: str = "9.3.4" + """RapiDoc version to download from the CDN.""" + favicon_url: str = "" + """URL to download a favicon from.""" + redoc_google_fonts: bool = True + """Download google fonts via CDN. + + Should be set to ``False`` when not using a CDN. + """ + redoc_js_url: str = f"https://cdn.jsdelivr.net/npm/redoc@{redoc_version}/bundles/redoc.standalone.js" + """Download url for the Redoc JS bundle.""" + swagger_css_url: str = f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{swagger_ui_version}/swagger-ui.css" + """Download url for the Swagger UI CSS bundle.""" + swagger_ui_bundle_js_url: str = ( + f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{swagger_ui_version}/swagger-ui-bundle.js" + ) + """Download url for the Swagger UI JS bundle.""" + swagger_ui_standalone_preset_js_url: str = ( + f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{swagger_ui_version}/swagger-ui-standalone-preset.js" + ) + """Download url for the Swagger Standalone Preset JS bundle.""" + swagger_ui_init_oauth: dict[Any, Any] | bytes = {} + """ + JSON to initialize Swagger UI OAuth2 by calling the `initOAuth` method. + + Refer to the following URL for details: + `Swagger-UI <https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/>`_. + """ + stoplight_elements_css_url: str = ( + f"https://unpkg.com/@stoplight/elements@{stoplight_elements_version}/styles.min.css" + ) + """Download url for the Stoplight Elements CSS bundle.""" + stoplight_elements_js_url: str = ( + f"https://unpkg.com/@stoplight/elements@{stoplight_elements_version}/web-components.min.js" + ) + """Download url for the Stoplight Elements JS bundle.""" + rapidoc_js_url: str = f"https://unpkg.com/rapidoc@{rapidoc_version}/dist/rapidoc-min.js" + """Download url for the RapiDoc JS bundle.""" + + # internal + _dumped_json_schema: str = "" + _dumped_yaml_schema: bytes = b"" + # until swagger-ui supports v3.1.* of OpenAPI officially, we need to modify the schema for it and keep it + # separate from the redoc version of the schema, which is unmodified. + dto = None + return_dto = None + + @staticmethod + def get_schema_from_request(request: Request[Any, Any, Any]) -> OpenAPI: + """Return the OpenAPI pydantic model from the request instance. + + Args: + request: A :class:`Litestar <.connection.Request>` instance. + + Returns: + An :class:`OpenAPI <litestar.openapi.spec.open_api.OpenAPI>` instance. + """ + return request.app.openapi_schema + + def should_serve_endpoint(self, request: Request[Any, Any, Any]) -> bool: + """Verify that the requested path is within the enabled endpoints in the openapi_config. + + Args: + request: To be tested if endpoint enabled. + + Returns: + A boolean. + + Raises: + ImproperlyConfiguredException: If the application ``openapi_config`` attribute is ``None``. + """ + if not request.app.openapi_config: # pragma: no cover + raise ImproperlyConfiguredException("Litestar has not been instantiated with an OpenAPIConfig") + + asgi_root_path = set(filter(None, request.scope.get("root_path", "").split("/"))) + full_request_path = set(filter(None, request.url.path.split("/"))) + request_path = full_request_path.difference(asgi_root_path) + root_path = set(filter(None, self.path.split("/"))) + + config = request.app.openapi_config + + if request_path == root_path and config.root_schema_site in config.enabled_endpoints: + return True + + return bool(request_path & config.enabled_endpoints) + + @property + def favicon(self) -> str: + """Return favicon ``<link>`` tag, if applicable. + + Returns: + A ``<link>`` tag if ``self.favicon_url`` is not empty, otherwise returns a placeholder meta tag. + """ + return f"<link rel='icon' type='image/x-icon' href='{self.favicon_url}'>" if self.favicon_url else "<meta/>" + + @cached_property + def render_methods_map( + self, + ) -> dict[Literal["redoc", "swagger", "elements", "rapidoc"], Callable[[Request], bytes]]: + """Map render method names to render methods. + + Returns: + A mapping of string keys to render methods. + """ + return { + "redoc": self.render_redoc, + "swagger": self.render_swagger_ui, + "elements": self.render_stoplight_elements, + "rapidoc": self.render_rapidoc, + } + + @get( + path=["/openapi.yaml", "openapi.yml"], + media_type=OpenAPIMediaType.OPENAPI_YAML, + include_in_schema=False, + sync_to_thread=False, + ) + def retrieve_schema_yaml(self, request: Request[Any, Any, Any]) -> ASGIResponse: + """Return the OpenAPI schema as YAML with an ``application/vnd.oai.openapi`` Content-Type header. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A Response instance with the YAML object rendered into a string. + """ + if self.should_serve_endpoint(request): + if not self._dumped_json_schema: + schema_json = decode_json(self._get_schema_as_json(request)) + schema_yaml = dump_yaml(schema_json, default_flow_style=False) + self._dumped_yaml_schema = schema_yaml.encode("utf-8") + return ASGIResponse(body=self._dumped_yaml_schema, media_type=OpenAPIMediaType.OPENAPI_YAML) + return ASGIResponse(body=b"", status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get( + path="/openapi.json", + media_type=OpenAPIMediaType.OPENAPI_JSON, + include_in_schema=False, + sync_to_thread=False, + name=_OPENAPI_JSON_ROUTER_NAME, + ) + def retrieve_schema_json(self, request: Request[Any, Any, Any]) -> ASGIResponse: + """Return the OpenAPI schema as JSON with an ``application/vnd.oai.openapi+json`` Content-Type header. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A Response instance with the JSON object rendered into a string. + """ + if self.should_serve_endpoint(request): + return ASGIResponse( + body=self._get_schema_as_json(request), + media_type=OpenAPIMediaType.OPENAPI_JSON, + ) + return ASGIResponse(body=b"", status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get(path="/", include_in_schema=False, sync_to_thread=False) + def root(self, request: Request[Any, Any, Any]) -> ASGIResponse: + """Render a static documentation site. + + The site to be rendered is based on the ``root_schema_site`` value set in the application's + :class:`OpenAPIConfig <.openapi.OpenAPIConfig>`. Defaults to ``redoc``. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A response with the rendered site defined in root_schema_site. + + Raises: + ImproperlyConfiguredException: If the application ``openapi_config`` attribute is ``None``. + """ + config = request.app.openapi_config + if not config: # pragma: no cover + raise ImproperlyConfiguredException(OPENAPI_NOT_INITIALIZED) + + render_method = self.render_methods_map[config.root_schema_site] + + if self.should_serve_endpoint(request): + return ASGIResponse(body=render_method(request), media_type=MediaType.HTML) + return ASGIResponse(body=self.render_404_page(), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get(path="/swagger", include_in_schema=False, sync_to_thread=False) + def swagger_ui(self, request: Request[Any, Any, Any]) -> ASGIResponse: + """Route handler responsible for rendering Swagger-UI. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A response with a rendered swagger documentation site + """ + if self.should_serve_endpoint(request): + return ASGIResponse(body=self.render_swagger_ui(request), media_type=MediaType.HTML) + return ASGIResponse(body=self.render_404_page(), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get(path="/elements", media_type=MediaType.HTML, include_in_schema=False, sync_to_thread=False) + def stoplight_elements(self, request: Request[Any, Any, Any]) -> ASGIResponse: + """Route handler responsible for rendering StopLight Elements. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A response with a rendered stoplight elements documentation site + """ + if self.should_serve_endpoint(request): + return ASGIResponse(body=self.render_stoplight_elements(request), media_type=MediaType.HTML) + return ASGIResponse(body=self.render_404_page(), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get(path="/redoc", media_type=MediaType.HTML, include_in_schema=False, sync_to_thread=False) + def redoc(self, request: Request[Any, Any, Any]) -> ASGIResponse: # pragma: no cover + """Route handler responsible for rendering Redoc. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A response with a rendered redoc documentation site + """ + if self.should_serve_endpoint(request): + return ASGIResponse(body=self.render_redoc(request), media_type=MediaType.HTML) + return ASGIResponse(body=self.render_404_page(), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get(path="/rapidoc", media_type=MediaType.HTML, include_in_schema=False, sync_to_thread=False) + def rapidoc(self, request: Request[Any, Any, Any]) -> ASGIResponse: + if self.should_serve_endpoint(request): + return ASGIResponse(body=self.render_rapidoc(request), media_type=MediaType.HTML) + return ASGIResponse(body=self.render_404_page(), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + @get(path="/oauth2-redirect.html", media_type=MediaType.HTML, include_in_schema=False, sync_to_thread=False) + def swagger_ui_oauth2_redirect(self, request: Request[Any, Any, Any]) -> ASGIResponse: # pragma: no cover + """Route handler responsible for rendering oauth2-redirect.html page for Swagger-UI. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A response with a rendered oauth2-redirect.html page for Swagger-UI. + """ + if self.should_serve_endpoint(request): + return ASGIResponse(body=self.render_swagger_ui_oauth2_redirect(request), media_type=MediaType.HTML) + return ASGIResponse(body=self.render_404_page(), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML) + + def render_swagger_ui_oauth2_redirect(self, request: Request[Any, Any, Any]) -> bytes: + """Render an HTML oauth2-redirect.html page for Swagger-UI. + + Notes: + - override this method to customize the template. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A rendered html string. + """ + return rb"""<!doctype html> + <html lang="en-US"> + <head> + <title>Swagger UI: OAuth2 Redirect</title> + </head> + <body> + <script> + 'use strict'; + function run () { + var oauth2 = window.opener.swaggerUIRedirectOauth2; + var sentState = oauth2.state; + var redirectUrl = oauth2.redirectUrl; + var isValid, qp, arr; + + if (/code|token|error/.test(window.location.hash)) { + qp = window.location.hash.substring(1).replace('?', '&'); + } else { + qp = location.search.substring(1); + } + + arr = qp.split("&"); + arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';}); + qp = qp ? JSON.parse('{' + arr.join() + '}', + function (key, value) { + return key === "" ? value : decodeURIComponent(value); + } + ) : {}; + + isValid = qp.state === sentState; + + if (( + oauth2.auth.schema.get("flow") === "accessCode" || + oauth2.auth.schema.get("flow") === "authorizationCode" || + oauth2.auth.schema.get("flow") === "authorization_code" + ) && !oauth2.auth.code) { + if (!isValid) { + oauth2.errCb({ + authId: oauth2.auth.name, + source: "auth", + level: "warning", + message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server." + }); + } + + if (qp.code) { + delete oauth2.state; + oauth2.auth.code = qp.code; + oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl}); + } else { + let oauthErrorMsg; + if (qp.error) { + oauthErrorMsg = "["+qp.error+"]: " + + (qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") + + (qp.error_uri ? "More info: "+qp.error_uri : ""); + } + + oauth2.errCb({ + authId: oauth2.auth.name, + source: "auth", + level: "error", + message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server." + }); + } + } else { + oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl}); + } + window.close(); + } + + if (document.readyState !== 'loading') { + run(); + } else { + document.addEventListener('DOMContentLoaded', function () { + run(); + }); + } + </script> + </body> + </html>""" + + def render_swagger_ui(self, request: Request[Any, Any, Any]) -> bytes: + """Render an HTML page for Swagger-UI. + + Notes: + - override this method to customize the template. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A rendered html string. + """ + schema = self.get_schema_from_request(request) + + head = f""" + <head> + <title>{schema.info.title}</title> + {self.favicon} + <meta charset="utf-8"/> + <meta name="viewport" content="width=device-width, initial-scale=1"> + <link href="{self.swagger_css_url}" rel="stylesheet"> + <script src="{self.swagger_ui_bundle_js_url}" crossorigin></script> + <script src="{self.swagger_ui_standalone_preset_js_url}" crossorigin></script> + <style>{self.style}</style> + </head> + """ + + body = f""" + <body> + <div id='swagger-container'/> + <script type="text/javascript"> + const ui = SwaggerUIBundle({{ + spec: {self._get_schema_as_json(request)}, + dom_id: '#swagger-container', + deepLinking: true, + showExtensions: true, + showCommonExtensions: true, + presets: [ + SwaggerUIBundle.presets.apis, + SwaggerUIBundle.SwaggerUIStandalonePreset + ], + }}) + ui.initOAuth({encode_json(self.swagger_ui_init_oauth).decode('utf-8')}) + </script> + </body> + """ + + return f""" + <!DOCTYPE html> + <html> + {head} + {body} + </html> + """.encode() + + def render_stoplight_elements(self, request: Request[Any, Any, Any]) -> bytes: + """Render an HTML page for StopLight Elements. + + Notes: + - override this method to customize the template. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A rendered html string. + """ + schema = self.get_schema_from_request(request) + head = f""" + <head> + <title>{schema.info.title}</title> + {self.favicon} + <meta charset="utf-8"/> + <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> + <link rel="stylesheet" href="{self.stoplight_elements_css_url}"> + <script src="{self.stoplight_elements_js_url}" crossorigin></script> + <style>{self.style}</style> + </head> + """ + + body = f""" + <body> + <elements-api + apiDescriptionUrl="{request.app.route_reverse(_OPENAPI_JSON_ROUTER_NAME)}" + router="hash" + layout="sidebar" + /> + </body> + """ + + return f""" + <!DOCTYPE html> + <html> + {head} + {body} + </html> + """.encode() + + def render_rapidoc(self, request: Request[Any, Any, Any]) -> bytes: # pragma: no cover + schema = self.get_schema_from_request(request) + + head = f""" + <head> + <title>{schema.info.title}</title> + {self.favicon} + <meta charset="utf-8"/> + <meta name="viewport" content="width=device-width, initial-scale=1"> + <script src="{self.rapidoc_js_url}" crossorigin></script> + <style>{self.style}</style> + </head> + """ + + body = f""" + <body> + <rapi-doc spec-url="{request.app.route_reverse(_OPENAPI_JSON_ROUTER_NAME)}" /> + </body> + """ + + return f""" + <!DOCTYPE html> + <html> + {head} + {body} + </html> + """.encode() + + def render_redoc(self, request: Request[Any, Any, Any]) -> bytes: # pragma: no cover + """Render an HTML page for Redoc. + + Notes: + - override this method to customize the template. + + Args: + request: + A :class:`Request <.connection.Request>` instance. + + Returns: + A rendered html string. + """ + schema = self.get_schema_from_request(request) + + head = f""" + <head> + <title>{schema.info.title}</title> + {self.favicon} + <meta charset="utf-8"/> + <meta name="viewport" content="width=device-width, initial-scale=1"> + """ + + if self.redoc_google_fonts: + head += """ + <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet"> + """ + + head += f""" + <script src="{self.redoc_js_url}" crossorigin></script> + <style> + {self.style} + </style> + </head> + """ + + body = f""" + <body> + <div id='redoc-container'/> + <script type="text/javascript"> + Redoc.init( + {self._get_schema_as_json(request)}, + undefined, + document.getElementById('redoc-container') + ) + </script> + </body> + """ + + return f""" + <!DOCTYPE html> + <html> + {head} + {body} + </html> + """.encode() + + def render_404_page(self) -> bytes: + """Render an HTML 404 page. + + Returns: + A rendered html string. + """ + + return f""" + <!DOCTYPE html> + <html> + <head> + <title>404 Not found</title> + {self.favicon} + <meta charset="utf-8"/> + <meta name="viewport" content="width=device-width, initial-scale=1"> + <style> + {self.style} + </style> + </head> + <body> + <h1>Error 404</h1> + </body> + </html> + """.encode() + + def _get_schema_as_json(self, request: Request) -> str: + """Get the schema encoded as a JSON string.""" + + if not self._dumped_json_schema: + schema = self.get_schema_from_request(request).to_schema() + json_encoded_schema = encode_json(schema, request.route_handler.default_serializer) + self._dumped_json_schema = json_encoded_schema.decode("utf-8") + + return self._dumped_json_schema diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/datastructures.py b/venv/lib/python3.11/site-packages/litestar/openapi/datastructures.py new file mode 100644 index 0000000..5796a48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/datastructures.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.enums import MediaType + +if TYPE_CHECKING: + from litestar.openapi.spec import Example + from litestar.types import DataContainerType + + +__all__ = ("ResponseSpec",) + + +@dataclass +class ResponseSpec: + """Container type of additional responses.""" + + data_container: DataContainerType | None + """A model that describes the content of the response.""" + generate_examples: bool = field(default=True) + """Generate examples for the response content.""" + description: str = field(default="Additional response") + """A description of the response.""" + media_type: MediaType = field(default=MediaType.JSON) + """Response media type.""" + examples: list[Example] | None = field(default=None) + """A list of Example models.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__init__.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__init__.py new file mode 100644 index 0000000..438c351 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__init__.py @@ -0,0 +1,68 @@ +from .base import BaseSchemaObject +from .callback import Callback +from .components import Components +from .contact import Contact +from .discriminator import Discriminator +from .encoding import Encoding +from .enums import OpenAPIFormat, OpenAPIType +from .example import Example +from .external_documentation import ExternalDocumentation +from .header import OpenAPIHeader +from .info import Info +from .license import License +from .link import Link +from .media_type import OpenAPIMediaType +from .oauth_flow import OAuthFlow +from .oauth_flows import OAuthFlows +from .open_api import OpenAPI +from .operation import Operation +from .parameter import Parameter +from .path_item import PathItem +from .paths import Paths +from .reference import Reference +from .request_body import RequestBody +from .response import OpenAPIResponse +from .responses import Responses +from .schema import Schema +from .security_requirement import SecurityRequirement +from .security_scheme import SecurityScheme +from .server import Server +from .server_variable import ServerVariable +from .tag import Tag +from .xml import XML + +__all__ = ( + "BaseSchemaObject", + "Callback", + "Components", + "Contact", + "Discriminator", + "Encoding", + "Example", + "ExternalDocumentation", + "Info", + "License", + "Link", + "OAuthFlow", + "OAuthFlows", + "OpenAPI", + "OpenAPIFormat", + "OpenAPIHeader", + "OpenAPIMediaType", + "OpenAPIResponse", + "OpenAPIType", + "Operation", + "Parameter", + "PathItem", + "Paths", + "Reference", + "RequestBody", + "Responses", + "Schema", + "SecurityRequirement", + "SecurityScheme", + "Server", + "ServerVariable", + "Tag", + "XML", +) diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..682acac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7f9658f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/callback.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/callback.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8fb170e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/callback.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/components.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/components.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..98c83be --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/components.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/contact.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/contact.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..53b94c2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/contact.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/discriminator.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/discriminator.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5a4b1bc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/discriminator.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/encoding.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/encoding.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..58ab620 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/encoding.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/enums.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/enums.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9e328e3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/enums.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/example.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/example.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4f031b1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/example.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/external_documentation.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/external_documentation.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a0f445b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/external_documentation.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/header.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/header.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1596fd8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/header.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/info.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/info.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..45c4fdd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/info.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/license.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/license.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..98c60d4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/license.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/link.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/link.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b30775a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/link.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/media_type.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/media_type.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..47c85af --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/media_type.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/oauth_flow.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/oauth_flow.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a6a8564 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/oauth_flow.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/oauth_flows.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/oauth_flows.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..04555f8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/oauth_flows.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/open_api.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/open_api.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1926726 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/open_api.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/operation.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/operation.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1c3ddc2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/operation.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/parameter.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/parameter.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e860dc9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/parameter.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/path_item.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/path_item.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..cf33d1c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/path_item.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/paths.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/paths.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4dc2e6e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/paths.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/reference.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/reference.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..66b4e2c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/reference.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/request_body.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/request_body.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5788ebc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/request_body.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/response.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/response.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..364f342 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/response.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/responses.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/responses.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..db08130 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/responses.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/schema.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6af9ca3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/schema.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/security_requirement.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/security_requirement.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5aa7c2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/security_requirement.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/security_scheme.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/security_scheme.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a2b1045 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/security_scheme.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/server.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/server.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..26b70b3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/server.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/server_variable.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/server_variable.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..46ab0cd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/server_variable.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/tag.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/tag.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7597e16 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/tag.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/xml.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/xml.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..97cacd9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/__pycache__/xml.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py new file mode 100644 index 0000000..69cd3f3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, fields, is_dataclass +from enum import Enum +from typing import Any + +__all__ = ("BaseSchemaObject",) + + +def _normalize_key(key: str) -> str: + if key.endswith("_in"): + return "in" + if key.startswith("schema_"): + return key.split("_")[1] + if "_" in key: + components = key.split("_") + return components[0] + "".join(component.title() for component in components[1:]) + return "$ref" if key == "ref" else key + + +def _normalize_value(value: Any) -> Any: + if isinstance(value, BaseSchemaObject): + return value.to_schema() + if is_dataclass(value): + return {_normalize_value(k): _normalize_value(v) for k, v in asdict(value).items() if v is not None} + if isinstance(value, dict): + return {_normalize_value(k): _normalize_value(v) for k, v in value.items() if v is not None} + if isinstance(value, list): + return [_normalize_value(v) for v in value] + return value.value if isinstance(value, Enum) else value + + +@dataclass +class BaseSchemaObject: + """Base class for schema spec objects""" + + def to_schema(self) -> dict[str, Any]: + """Transform the spec dataclass object into a string keyed dictionary. This method traverses all nested values + recursively. + """ + result: dict[str, Any] = {} + + for field in fields(self): + value = _normalize_value(getattr(self, field.name, None)) + + if value is not None: + if "alias" in field.metadata: + if not isinstance(field.metadata["alias"], str): + raise TypeError('metadata["alias"] must be a str') + key = field.metadata["alias"] + else: + key = _normalize_key(field.name) + + result[key] = value + + return result diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/callback.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/callback.py new file mode 100644 index 0000000..c4bb129 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/callback.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Union + +if TYPE_CHECKING: + from litestar.openapi.spec.path_item import PathItem + from litestar.openapi.spec.reference import Reference + + +Callback = Dict[str, Union["PathItem", "Reference"]] +"""A map of possible out-of band callbacks related to the parent operation. Each value in the map is a +`Path Item Object <https://spec.openapis.org/oas/v3.1.0#pathItemObject>`_ that describes a set of requests that may be +initiated by the API provider and the expected responses. The key value used to identify the path item object is an +expression, evaluated at runtime, that identifies a URL to use for the callback operation. + +Patterned Fields + +{expression}: 'PathItem' = ... + +A Path Item Object used to define a callback request and expected responses. + +A `complete example <https://github.com/OAI/OpenAPI-Specification/blob/main/examples/v3.1/webhook-example.yaml>`_ is +available. +""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/components.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/components.py new file mode 100644 index 0000000..b11da78 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/components.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("Components",) + + +if TYPE_CHECKING: + from litestar.openapi.spec.callback import Callback + from litestar.openapi.spec.example import Example + from litestar.openapi.spec.header import OpenAPIHeader + from litestar.openapi.spec.link import Link + from litestar.openapi.spec.parameter import Parameter + from litestar.openapi.spec.path_item import PathItem + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.request_body import RequestBody + from litestar.openapi.spec.response import OpenAPIResponse + from litestar.openapi.spec.schema import Schema + from litestar.openapi.spec.security_scheme import SecurityScheme + + +@dataclass +class Components(BaseSchemaObject): + """Holds a set of reusable objects for different aspects of the OAS. + + All objects defined within the components object will have no effect + on the API unless they are explicitly referenced from properties + outside the components object. + """ + + schemas: dict[str, Schema] = field(default_factory=dict) + """An object to hold reusable + `Schema Objects <https://spec.openapis.org/oas/v3.1.0#schemaObject>`_""" + + responses: dict[str, OpenAPIResponse | Reference] | None = None + """An object to hold reusable + `Response Objects <https://spec.openapis.org/oas/v3.1.0#responseObject>`_""" + + parameters: dict[str, Parameter | Reference] | None = None + """An object to hold reusable + `Parameter Objects <https://spec.openapis.org/oas/v3.1.0#parameterObject>`_""" + + examples: dict[str, Example | Reference] | None = None + """An object to hold reusable + `Example Objects <https://spec.openapis.org/oas/v3.1.0#exampleObject>`_""" + + request_bodies: dict[str, RequestBody | Reference] | None = None + """An object to hold reusable + `Request Body Objects <https://spec.openapis.org/oas/v3.1.0#requestBodyObject>`_""" + + headers: dict[str, OpenAPIHeader | Reference] | None = None + """An object to hold reusable + `Header Objects <https://spec.openapis.org/oas/v3.1.0#headerObject>`_""" + + security_schemes: dict[str, SecurityScheme | Reference] | None = None + """An object to hold reusable + `Security Scheme Objects <https://spec.openapis.org/oas/v3.1.0#securitySchemeObject>`_""" + + links: dict[str, Link | Reference] | None = None + """An object to hold reusable + `Link Objects <https://spec.openapis.org/oas/v3.1.0#linkObject>`_""" + + callbacks: dict[str, Callback | Reference] | None = None + """An object to hold reusable + `Callback Objects <https://spec.openapis.org/oas/v3.1.0#callbackObject>`_""" + + path_items: dict[str, PathItem | Reference] | None = None + """An object to hold reusable + `Path Item Object <https://spec.openapis.org/oas/v3.1.0#pathItemObject>`_""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/contact.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/contact.py new file mode 100644 index 0000000..b816288 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/contact.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("Contact",) + + +@dataclass +class Contact(BaseSchemaObject): + """Contact information for the exposed API.""" + + name: str | None = None + """The identifying name of the contact person/organization.""" + + url: str | None = None + """The URL pointing to the contact information. MUST be in the form of a URL.""" + + email: str | None = None + """The email address of the contact person/organization. MUST be in the form of an email address.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/discriminator.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/discriminator.py new file mode 100644 index 0000000..3c292ce --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/discriminator.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("Discriminator",) + + +@dataclass(unsafe_hash=True) +class Discriminator(BaseSchemaObject): + """When request bodies or response payloads may be one of a number of different schemas, a ``discriminator`` + object can be used to aid in serialization, deserialization, and validation. + + The discriminator is a specific object in a schema which is used to inform the consumer of the specification of an + alternative schema based on the value associated with it. + + When using the discriminator, _inline_ schemas will not be considered. + """ + + property_name: str + """**REQUIRED**. The name of the property in the payload that will hold the discriminator value.""" + + mapping: dict[str, str] | None = None + """An object to hold mappings between payload values and schema names or references.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/encoding.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/encoding.py new file mode 100644 index 0000000..2a469c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/encoding.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.header import OpenAPIHeader + from litestar.openapi.spec.reference import Reference + +__all__ = ("Encoding",) + + +@dataclass +class Encoding(BaseSchemaObject): + """A single encoding definition applied to a single schema property.""" + + content_type: str | None = None + """The Content-Type for encoding a specific property. Default value depends n the property type: + + - for ``object``: ``application/json`` + - for ``array``: the default is defined based on the inner type + - for all other cases the default is ``application/octet-stream``. + + The value can be a specific media type (e.g. ``application/json``), a wildcard media type (e.g. ``image/*``), or a + comma-separated list of the two types. + """ + + headers: dict[str, OpenAPIHeader | Reference] | None = None + """A map allowing additional information to be provided as headers, for example ``Content-Disposition``. + + ``Content-Type`` is described separately and SHALL be ignored in this section. This property SHALL be ignored if the + request body media type is not a ``multipart``. + """ + + style: str | None = None + """Describes how a specific property value will be serialized depending on its type. + + See `Parameter Object <https://spec.openapis.org/oas/v3.1.0#parameterObject>`_ for details on the + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ property. The behavior follows the same values as + ``query`` parameters, including default values. This property SHALL be ignored if the request body media type is not + ``application/x-www-form-urlencoded`` or ``multipart/form-data``. If a value is explicitly defined, then the value + of `contentType <https://spec.openapis.org/oas/v3.1.0#encodingContentType>`_ (implicit or explicit) SHALL be + ignored. + """ + + explode: bool = False + """When this is true, property values of type ``array`` or ``object`` generate separate parameters for each value of + the array, or key-value-pair of the map. + + For other types of properties this property has no effect. When + `style <https://spec.openapis.org/oas/v3.1.0#encodingStyle>`_ is ``form``, the default value is ``True``. For all + other styles, the default value is ``False``. This property SHALL be ignored if the request body media type is not + ``application/x-www-form-urlencoded`` or ``multipart/form-data``. If a value is explicitly defined, then the value + of `contentType <https://spec.openapis.org/oas/v3.1.0#encodingContentType>`_ (implicit or explicit) SHALL be + ignored. + """ + + allow_reserved: bool = False + """Determines whether the parameter value SHOULD allow reserved characters, as defined by :rfc:`3986` + (``:/?#[]@!$&'()*+,;=``) to be included without percent-encoding. + + This property SHALL be ignored if the request body media type s not ``application/x-www-form-urlencoded`` or + ``multipart/form-data``. If a value is explicitly defined, then the value of + `contentType <https://spec.openapis.org/oas/v3.1.0#encodingContentType>`_ (implicit or explicit) SHALL be ignored. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/enums.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/enums.py new file mode 100644 index 0000000..da9adea --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/enums.py @@ -0,0 +1,41 @@ +from enum import Enum + +__all__ = ("OpenAPIFormat", "OpenAPIType") + + +class OpenAPIFormat(str, Enum): + """Formats extracted from: https://datatracker.ietf.org/doc/html/draft-bhutton-json-schema-validation-00#page-13""" + + DATE = "date" + DATE_TIME = "date-time" + TIME = "time" + DURATION = "duration" + URL = "url" + EMAIL = "email" + IDN_EMAIL = "idn-email" + HOST_NAME = "hostname" + IDN_HOST_NAME = "idn-hostname" + IPV4 = "ipv4" + IPV6 = "ipv6" + URI = "uri" + URI_REFERENCE = "uri-reference" + URI_TEMPLATE = "uri-template" + JSON_POINTER = "json-pointer" + RELATIVE_JSON_POINTER = "relative-json-pointer" + IRI = "iri-reference" + IRI_REFERENCE = "iri-reference" # noqa: PIE796 + UUID = "uuid" + REGEX = "regex" + BINARY = "binary" + + +class OpenAPIType(str, Enum): + """An OopenAPI type.""" + + ARRAY = "array" + BOOLEAN = "boolean" + INTEGER = "integer" + NULL = "null" + NUMBER = "number" + OBJECT = "object" + STRING = "string" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/example.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/example.py new file mode 100644 index 0000000..414452e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/example.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from litestar.openapi.spec.base import BaseSchemaObject + + +@dataclass +class Example(BaseSchemaObject): + summary: str | None = None + """Short description for the example.""" + + description: str | None = None + """Long description for the example. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + value: Any | None = None + """Embedded literal example. + + The ``value`` field and ``externalValue`` field are mutually exclusive. To represent examples of media types that + cannot naturally represented in JSON or YAML, use a string value to contain the example, escaping where necessary. + """ + + external_value: str | None = None + """A URL that points to the literal example. This provides the capability to reference examples that cannot easily + be included in JSON or YAML documents. + + The ``value`` field and ``externalValue`` field are mutually exclusive. See the rules for resolving + `Relative References <https://spec.openapis.org/oas/v3.1.0#relativeReferencesURI>`_. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/external_documentation.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/external_documentation.py new file mode 100644 index 0000000..f11b90d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/external_documentation.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("ExternalDocumentation",) + + +@dataclass +class ExternalDocumentation(BaseSchemaObject): + """Allows referencing an external resource for extended documentation.""" + + url: str + """**REQUIRED**. The URL for the target documentation. Value MUST be in the form of a URL.""" + + description: str | None = None + """A short description of the target documentation. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/header.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/header.py new file mode 100644 index 0000000..006ff22 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/header.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.example import Example + from litestar.openapi.spec.media_type import OpenAPIMediaType + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.schema import Schema + +__all__ = ("OpenAPIHeader",) + + +@dataclass +class OpenAPIHeader(BaseSchemaObject): + """The Header Object follows the structure of the [Parameter + Object](https://spec.openapis.org/oas/v3.1.0#parameterObject) with the + following changes: + + 1. ``name`` MUST NOT be specified, it is given in the corresponding ``headers`` map. + 2. ``in`` MUST NOT be specified, it is implicitly in ``header``. + 3. All traits that are affected by the location MUST be applicable to a location of ``header`` + (for example, `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__). + """ + + schema: Schema | Reference | None = None + """The schema defining the type used for the parameter.""" + + name: Literal[""] = "" + """MUST NOT be specified, it is given in the corresponding ``headers`` map.""" + + param_in: Literal["header"] = "header" + """MUST NOT be specified, it is implicitly in ``header``.""" + + description: str | None = None + """A brief description of the parameter. This could contain examples of + use. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + required: bool = False + """Determines whether this parameter is mandatory. + + If the `parameter location <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_ is ``"path"``, this property is + **REQUIRED** and its value MUST be ``True``. Otherwise, the property MAY be included and its default value is + ``False``. + """ + + deprecated: bool = False + """Specifies that a parameter is deprecated and SHOULD be transitioned out of usage. Default value is ``False``.""" + + allow_empty_value: bool = False + """Sets the ability to pass empty-valued parameters. This is valid only for ``query`` parameters and allows sending + a parameter with an empty value. Default value is ``False``. If + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ is used, and if behavior is ``n/a`` (cannot be + serialized), the value of ``allowEmptyValue`` SHALL be ignored. Use of this property is NOT RECOMMENDED, as it is + likely to be removed in a later revision. + + The rules for serialization of the parameter are specified in one of two ways.For simpler scenarios, a + `schema <https://spec.openapis.org/oas/v3.1.0#parameterSchema>`_ and + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ can describe the structure and syntax of the + parameter. + """ + + style: str | None = None + """Describes how the parameter value will be serialized depending on the + type of the parameter value. Default values (based on value of ``in``): + + - for ``query`` - ``form``; + - for ``path`` - ``simple``; + - for ``header`` - ``simple``; + - for ``cookie`` - ``form``. + """ + + explode: bool | None = None + """When this is true, parameter values of type ``array`` or ``object`` generate separate parameters for each value + of the array or key-value pair of the map. + + For other types of parameters this property has no effect.When + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ is ``form``, the default value is ``True``. For all + other styles, the default value is ``False``. + """ + + allow_reserved: bool = False + """Determines whether the parameter value SHOULD allow reserved characters, as defined by. :rfc:`3986` + (``:/?#[]@!$&'()*+,;=``) to be included without percent-encoding. + + This property only applies to parameters with an ``in`` value of ``query``. The default value is ``False``. + """ + + example: Any | None = None + """Example of the parameter's potential value. + + The example SHOULD match the specified schema and encoding properties if present. The ``example`` field is mutually + exclusive of the ``examples`` field. Furthermore, if referencing a ``schema`` that contains an example, the + ``example`` value SHALL _override_ the example provided by the schema. To represent examples of media types that + cannot naturally be represented in JSON or YAML, a string value can contain the example with escaping where + necessary. + """ + + examples: dict[str, Example | Reference] | None = None + """Examples of the parameter's potential value. Each example SHOULD contain a value in the correct format as + specified in the parameter encoding. The ``examples`` field is mutually exclusive of the ``example`` field. + Furthermore, if referencing a ``schema`` that contains an example, the ``examples`` value SHALL _override_ the + example provided by the schema. + + For more complex scenarios, the `content <https://spec.openapis.org/oas/v3.1.0#parameterContent>`_ property can + define the media type and schema of the parameter. A parameter MUST contain either a ``schema`` property, or a + ``content`` property, but not both. When ``example`` or ``examples`` are provided in conjunction with the ``schema`` + object, the example MUST follow the prescribed serialization strategy for the parameter. + """ + + content: dict[str, OpenAPIMediaType] | None = None + """A map containing the representations for the parameter. + + The key is the media type and the value describes it. The map MUST only contain one entry. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/info.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/info.py new file mode 100644 index 0000000..1d858db --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/info.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.contact import Contact + from litestar.openapi.spec.license import License + +__all__ = ("Info",) + + +@dataclass +class Info(BaseSchemaObject): + """The object provides metadata about the API. + + The metadata MAY be used by the clients if needed, and MAY be presented in editing or documentation generation tools + for convenience. + """ + + title: str + """ + **REQUIRED**. The title of the API. + """ + + version: str + """ + **REQUIRED**. The version of the OpenAPI document which is distinct from the + `OpenAPI Specification version <https://spec.openapis.org/oas/v3.1.0#oasVersion>`_ or the API implementation version + """ + + summary: str | None = None + """A short summary of the API.""" + + description: str | None = None + """A description of the API. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + terms_of_service: str | None = None + """A URL to the Terms of Service for the API. MUST be in the form of a URL.""" + + contact: Contact | None = None + """The contact information for the exposed API.""" + + license: License | None = None + """The license information for the exposed API.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/license.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/license.py new file mode 100644 index 0000000..b779bb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/license.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("License",) + + +@dataclass +class License(BaseSchemaObject): + """License information for the exposed API.""" + + name: str + """**REQUIRED**. The license name used for the API.""" + + identifier: str | None = None + """An + `SPDX <https://spdx.github.io/spdx-spec/v2.3/SPDX-license-list/#a1-licenses-with-short-identifiers>`_ license expression for the API. + + The ``identifier`` field is mutually exclusive of the ``url`` field. + """ + + url: str | None = None + """A URL to the license used for the API. + + This MUST be in the form of a URL. The ``url`` field is mutually exclusive of the ``identifier`` field. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/link.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/link.py new file mode 100644 index 0000000..78c4f85 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/link.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.server import Server + +__all__ = ("Link",) + + +@dataclass +class Link(BaseSchemaObject): + """The ``Link object`` represents a possible design-time link for a response. The presence of a link does not + guarantee the caller's ability to successfully invoke it, rather it provides a known relationship and traversal + mechanism between responses and other operations. + + Unlike _dynamic_ links (i.e. links provided **in** the response payload), the OAS linking mechanism does not require + link information in the runtime response. + + For computing links, and providing instructions to execute them, a + `runtime expression <https://spec.openapis.org/oas/v3.1.0#runtimeExpression>`_ is used for accessing values in an + operation and using them as parameters while invoking the linked operation. + """ + + operation_ref: str | None = None + """A relative or absolute URI reference to an OAS operation. + + This field is mutually exclusive of the ``operationId`` field, and MUST point to an + `Operation Object <https://spec.openapis.org/oas/v3.1.0#operationObject>`_. Relative ``operationRef`` values MAY be + used to locate an existing `Operation Object <https://spec.openapis.org/oas/v3.1.0#operationObject>`_ in the OpenAPI + definition. See the rules for resolving + `Relative References <https://spec.openapis.org/oas/v3.1.0#relativeReferencesURI>`_ + """ + + operation_id: str | None = None + """The name of an _existing_, resolvable OAS operation, as defined with a unique ``operationId``. + + This field is mutually exclusive of the ``operationRef`` field. + """ + + parameters: dict[str, Any] | None = None + """A map representing parameters to pass to an operation as specified with ``operationId`` or identified via + ``operationRef``. The key is the parameter name to be used, whereas the value can be a constant or an expression to + be evaluated and passed to the linked operation. + + The parameter name can be qualified using the + `parameter location <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_ ``[{in}.]{name}`` for operations that use + the same parameter name in different locations (e.g. path.id). + """ + + request_body: Any | None = None + """A literal value or + `{expression} <https://spec.openapis.org/oas/v3.1.0#runtimeExpression>`_ to use as a request body when calling the + target operation.""" + + description: str | None = None + """A description of the link. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + server: Server | None = None + """A server object to be used by the target operation.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/media_type.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/media_type.py new file mode 100644 index 0000000..3e83fb5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/media_type.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.encoding import Encoding + from litestar.openapi.spec.example import Example + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.schema import Schema + +__all__ = ("OpenAPIMediaType",) + + +@dataclass +class OpenAPIMediaType(BaseSchemaObject): + """Each Media Type Object provides schema and examples for the media type identified by its key.""" + + schema: Reference | Schema | None = None + """The schema defining the content of the request, response, or parameter.""" + + example: Any | None = None + """Example of the media type. + + The example object SHOULD be in the correct format as specified by the media type. + + The ``example`` field is mutually exclusive of the ``examples`` field. + + Furthermore, if referencing a ``schema`` which contains an example, the ``example`` value SHALL _override_ the + example provided by the schema. + """ + + examples: dict[str, Example | Reference] | None = None + """Examples of the media type. + + Each example object SHOULD match the media type and specified schema if present. + + The ``examples`` field is mutually exclusive of the ``example`` field. + + Furthermore, if referencing a ``schema`` which contains an example, the ``examples`` value SHALL _override_ the + example provided by the schema. + """ + + encoding: dict[str, Encoding] | None = None + """A map between a property name and its encoding information. + + The key, being the property name, MUST exist in the schema as a property. The encoding object SHALL only apply to + ``requestBody`` objects when the media type is ``multipart`` or ``application/x-www-form-urlencoded``. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/oauth_flow.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/oauth_flow.py new file mode 100644 index 0000000..c322efc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/oauth_flow.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("OAuthFlow",) + + +@dataclass +class OAuthFlow(BaseSchemaObject): + """Configuration details for a supported OAuth Flow.""" + + authorization_url: str | None = None + """ + **REQUIRED** for ``oauth2`` ("implicit", "authorizationCode"). The authorization URL to be used for this flow. This + MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. + """ + + token_url: str | None = None + """ + **REQUIRED** for ``oauth2`` ("password", "clientCredentials", "authorizationCode"). The token URL to be used for + this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. + """ + + refresh_url: str | None = None + """The URL to be used for obtaining refresh tokens. + + This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. + """ + + scopes: dict[str, str] | None = None + """ + **REQUIRED** for ``oauth2``. The available scopes for the OAuth2 security scheme. A map between the scope name and a + short description for it the map MAY be empty. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/oauth_flows.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/oauth_flows.py new file mode 100644 index 0000000..920d095 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/oauth_flows.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.oauth_flow import OAuthFlow + +__all__ = ("OAuthFlows",) + + +@dataclass +class OAuthFlows(BaseSchemaObject): + """Allows configuration of the supported OAuth Flows.""" + + implicit: OAuthFlow | None = None + """Configuration for the OAuth Implicit flow.""" + + password: OAuthFlow | None = None + """Configuration for the OAuth Resource Owner Password flow.""" + + client_credentials: OAuthFlow | None = None + """Configuration for the OAuth Client Credentials flow. Previously called ``application`` in OpenAPI 2.0.""" + + authorization_code: OAuthFlow | None = None + """Configuration for the OAuth Authorization Code flow. Previously called ``accessCode`` in OpenAPI 2.0.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/open_api.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/open_api.py new file mode 100644 index 0000000..55465fa --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/open_api.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject +from litestar.openapi.spec.components import Components +from litestar.openapi.spec.server import Server + +if TYPE_CHECKING: + from litestar.openapi.spec.external_documentation import ExternalDocumentation + from litestar.openapi.spec.info import Info + from litestar.openapi.spec.path_item import PathItem + from litestar.openapi.spec.paths import Paths + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.security_requirement import SecurityRequirement + from litestar.openapi.spec.tag import Tag + +__all__ = ("OpenAPI",) + + +@dataclass +class OpenAPI(BaseSchemaObject): + """Root OpenAPI document.""" + + info: Info + """ + **REQUIRED**. Provides metadata about the API. The metadata MAY be used by tooling as required. + """ + + openapi: str = "3.1.0" + """ + **REQUIRED**. This string MUST be the + `version number <https://spec.openapis.org/oas/v3.1.0#versions>`_ of the OpenAPI Specification that the OpenAPI + document uses. The ``openapi`` field SHOULD be used by tooling to interpret the OpenAPI document. This is *not* + related to the API `info.version <https://spec.openapis.org/oas/v3.1.0#infoVersion>`_ string. + """ + + json_schema_dialect: str | None = None + """The default value for the ``$schema`` keyword within + `Schema Objects <https://spec.openapis.org/oas/v3.1.0#schemaObject>`_ contained within this OAS document. + + This MUST be in the form of a URI. + """ + + servers: list[Server] = field(default_factory=lambda x: [Server(url="/")]) # type: ignore[misc, arg-type] + """An array of Server Objects, which provide connectivity information to a target server. + + If the ``servers`` property is not provided, or is an empty array, the default value would be a + `Server Object <https://spec.openapis.org/oas/v3.1.0#serverObject>`_ with a + `url <https://spec.openapis.org/oas/v3.1.0#serverUrl>`_ value of ``/``. + """ + + paths: Paths | None = None + """The available paths and operations for the API.""" + + webhooks: dict[str, PathItem | Reference] | None = None + """The incoming webhooks that MAY be received as part of this API and that the API consumer MAY choose to implement. + + Closely related to the ``callbacks`` feature, this section describes requests initiated other than by an API call, + for example by an out of band registration. The key name is a unique string to refer to each webhook, while the + (optionally referenced) Path Item Object describes a request that may be initiated by the API provider and the + expected responses. An + `example <https://github.com/OAI/OpenAPI-Specification/blob/main/examples/v3.1/webhook-example.yaml>`_ is available. + """ + + components: Components = field(default_factory=Components) + """An element to hold various schemas for the document.""" + + security: list[SecurityRequirement] | None = None + """A declaration of which security mechanisms can be used across the API. + + The list of values includes alternative security requirement objects that can be used. Only one of the security + requirement objects need to be satisfied to authorize a request. Individual operations can override this definition. + To make security optional, an empty security requirement ( ``{}`` ) can be included in the array. + """ + + tags: list[Tag] | None = None + """A list of tags used by the document with additional metadata. + + The order of the tags can be used to reflect on their order by the parsing tools. Not all tags that are used by the + `Operation Object <https://spec.openapis.org/oas/v3.1.0#operationObject>`_ must be declared. The tags that are not + declared MAY be organized randomly or based on the tools' logic. Each tag name in the list MUST be unique. + """ + + external_docs: ExternalDocumentation | None = None + """Additional external documentation.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/operation.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/operation.py new file mode 100644 index 0000000..ab80181 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/operation.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.callback import Callback + from litestar.openapi.spec.external_documentation import ExternalDocumentation + from litestar.openapi.spec.parameter import Parameter + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.request_body import RequestBody + from litestar.openapi.spec.responses import Responses + from litestar.openapi.spec.security_requirement import SecurityRequirement + from litestar.openapi.spec.server import Server + +__all__ = ("Operation",) + + +@dataclass +class Operation(BaseSchemaObject): + """Describes a single API operation on a path.""" + + tags: list[str] | None = None + """A list of tags for API documentation control. + + Tags can be used for logical grouping of operations by resources or any other qualifier. + """ + + summary: str | None = None + """A short summary of what the operation does.""" + + description: str | None = None + """A verbose explanation of the operation behavior. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + external_docs: ExternalDocumentation | None = None + """Additional external documentation for this operation.""" + + operation_id: str | None = None + """Unique string used to identify the operation. + + The id MUST be unique among all operations described in the API. The operationId value is **case-sensitive**. Tools + and libraries MAY use the operationId to uniquely identify an operation, therefore, it is RECOMMENDED to follow + common programming naming conventions. + """ + + parameters: list[Parameter | Reference] | None = None + """A list of parameters that are applicable for this operation. + + If a parameter is already defined at the `Path Item <https://spec.openapis.org/oas/v3.1.0#pathItemParameters>`_, + the new definition will override it but can never remove it. The list MUST NOT include duplicated parameters. A + unique parameter is defined by a combination of a `name <https://spec.openapis.org/oas/v3.1.0#parameterName>`_ and + `location <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_. The list can use the + `Reference Object <https://spec.openapis.org/oas/v3.1.0#referenceObject>`_ to link to parameters that are defined at + the `OpenAPI Object's components/parameters <https://spec.openapis.org/oas/v3.1.0#componentsParameters>`_. + """ + + request_body: RequestBody | Reference | None = None + """The request body applicable for this operation. + + The ``requestBody`` is fully supported in HTTP methods where the HTTP 1.1 specification + :rfc:`7231` has explicitly defined semantics for request bodies. In other cases where the HTTP spec is vague (such + as `GET <https://tools.ietf.org/html/rfc7231#section-4.3.1>`_, + `HEAD <https://tools.ietf.org/html/rfc7231#section-4.3.2>`_ and + `DELETE <https://tools.ietf.org/html/rfc7231#section-4.3.5>`_, ``requestBody`` is permitted but does not have + well-defined semantics and SHOULD be avoided if possible. + """ + + responses: Responses | None = None + """The list of possible responses as they are returned from executing this operation.""" + + callbacks: dict[str, Callback | Reference] | None = None + """A map of possible out-of band callbacks related to the parent operation. + + The key is a unique identifier for the Callback Object. Each value in the map is a + `Callback Object <https://spec.openapis.org/oas/v3.1.0#callbackObject>`_ that describes a request that may be + initiated by the API provider and the expected responses. + """ + + deprecated: bool = False + """Declares this operation to be deprecated. + + Consumers SHOULD refrain from usage of the declared operation. Default value is ``False``. + """ + + security: list[SecurityRequirement] | None = None + """A declaration of which security mechanisms can be used for this operation. + + The list of values includes alternative security requirement objects that can be used. Only one of the security + requirement objects need to be satisfied to authorize a request. To make security optional, an empty security + requirement (``{}``) can be included in the array. This definition overrides any declared top-level + `security <https://spec.openapis.org/oas/v3.1.0#oasSecurity>`_. To remove a top-level security declaration, an empty + array can be used. + """ + + servers: list[Server] | None = None + """An alternative ``server`` array to service this operation. + + If an alternative ``server`` object is specified at the Path Item Object or Root level, it will be overridden by + this value. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/parameter.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/parameter.py new file mode 100644 index 0000000..74a100f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/parameter.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Mapping + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.example import Example + from litestar.openapi.spec.media_type import OpenAPIMediaType + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.schema import Schema + +__all__ = ("Parameter",) + + +@dataclass +class Parameter(BaseSchemaObject): + """Describes a single operation parameter. + + A unique parameter is defined by a combination of a `name <https://spec.openapis.org/oas/v3.1.0#parameterName>`_ and + `location <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_. + """ + + name: str + """ + **REQUIRED**. The name of the parameter. + Parameter names are *case sensitive*. + + - If `in <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_ is ``"path"``, the ``name`` field MUST correspond to a + template expression occurring within the `path <https://spec.openapis.org/oas/v3.1.0#pathsPath>`_ field in the + `Paths Object <https://spec.openapis.org/oas/v3.1.0#pathsObject>`_. See + `Path Templating <https://spec.openapis.org/oas/v3.1.0#pathTemplating>`_ for further information. + - If `in <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_ is ``"header"`` and the ``name`` field is + ``"Accept"``, ``"Content-Type"`` or ``"Authorization"``, the parameter definition SHALL be ignored. + - For all other cases, the ``name`` corresponds to the parameter name used by the + `in <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_ property. + """ + + param_in: str + """ + **REQUIRED**. The location of the parameter. Possible values are + ``"query"``, ``"header"``, ``"path"`` or ``"cookie"``. + """ + + schema: Schema | Reference | None = None + """The schema defining the type used for the parameter.""" + + description: str | None = None + """A brief description of the parameter. This could contain examples of + use. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + required: bool = False + """Determines whether this parameter is mandatory. + + If the `parameter location <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_ is ``"path"``, this property is + **REQUIRED** and its value MUST be ``True``. Otherwise, the property MAY be included and its default value is + ``False``. + """ + + deprecated: bool = False + """Specifies that a parameter is deprecated and SHOULD be transitioned out of usage. + + Default value is ``False``. + """ + + allow_empty_value: bool = False + """Sets the ability to pass empty-valued parameters. This is valid only for ``query`` parameters and allows sending + a parameter with an empty value. Default value is ``False``. If + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ is used, and if behavior is ``n/a`` (cannot be + serialized), the value of ``allowEmptyValue`` SHALL be ignored. Use of this property is NOT RECOMMENDED, as it is + likely to be removed in a later revision. + + The rules for serialization of the parameter are specified in one of two ways. For simpler scenarios, a + `schema <https://spec.openapis.org/oas/v3.1.0#parameterSchema>`_ and + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ can describe the structure and syntax of the + parameter. + """ + + style: str | None = None + """Describes how the parameter value will be serialized depending on the ype of the parameter value. Default values + (based on value of ``in``): + + - for ``query`` - ``form`` + - for ``path`` - ``simple`` + - for ``header`` - ``simple`` + - for ``cookie`` - ``form`` + """ + + explode: bool | None = None + """When this is true, parameter values of type ``array`` or ``object`` generate separate parameters for each value + of the array or key-value pair of the map. + + For other types of parameters this property has no effect. When + `style <https://spec.openapis.org/oas/v3.1.0#parameterStyle>`__ is ``form``, the default value is ``True``. For all + other styles, the default value is ``False``. + """ + + allow_reserved: bool = False + """Determines whether the parameter value SHOULD allow reserved characters, as defined by. + + :rfc:`3986` ``:/?#[]@!$&'()*+,;=`` to be included without percent-encoding. + + This property only applies to parameters with an ``in`` value of ``query``. The default value is ``False``. + """ + + example: Any | None = None + """Example of the parameter's potential value. + + The example SHOULD match the specified schema and encoding properties if present. The ``example`` field is mutually + exclusive of the ``examples`` field. Furthermore, if referencing a ``schema`` that contains an example, the + ``example`` value SHALL _override_ the example provided by the schema. To represent examples of media types that + cannot naturally be represented in JSON or YAML, a string value can contain the example with escaping where + necessary. + """ + + examples: Mapping[str, Example | Reference] | None = None + """Examples of the parameter's potential value. Each example SHOULD contain a value in the correct format as + specified in the parameter encoding. The ``examples`` field is mutually exclusive of the ``example`` field. + Furthermore, if referencing a ``schema`` that contains an example, the ``examples`` value SHALL _override_ the + example provided by the schema. + + For more complex scenarios, the `content <https://spec.openapis.org/oas/v3.1.0#parameterContent>`_ property can + define the media type and schema of the parameter. A parameter MUST contain either a ``schema`` property, or a + ``content`` property, but not both. When ``example`` or ``examples`` are provided in conjunction with the + ``schema`` object, the example MUST follow the prescribed serialization strategy for the parameter. + """ + + content: dict[str, OpenAPIMediaType] | None = None + """A map containing the representations for the parameter. + + The key is the media type and the value describes it. The map MUST only contain one entry. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/path_item.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/path_item.py new file mode 100644 index 0000000..17005c5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/path_item.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.operation import Operation + from litestar.openapi.spec.parameter import Parameter + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.server import Server + +__all__ = ("PathItem",) + + +@dataclass +class PathItem(BaseSchemaObject): + """Describes the operations available on a single path. + + A Path Item MAY be empty, due to `ACL constraints <https://spec.openapis.org/oas/v3.1.0#securityFiltering>`_. The + path itself is still exposed to the documentation viewer, but they will not know which operations and parameters are + available. + """ + + ref: str | None = None + """Allows for an external definition of this path item. The referenced structure MUST be in the format of a + `Path Item Object <https://spec.openapis.org/oas/v3.1.0#pathItemObject>`. + + In case a Path Item Object field appears both in the defined object and the referenced object, the behavior is + undefined. See the rules for resolving + `Relative References <https://spec.openapis.org/oas/v3.1.0#relativeReferencesURI>`_. + """ + + summary: str | None = None + """An optional, string summary, intended to apply to all operations in this path.""" + + description: str | None = None + """An optional, string description, intended to apply to all operations in this path. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + get: Operation | None = None + """A definition of a GET operation on this path.""" + + put: Operation | None = None + """A definition of a PUT operation on this path.""" + + post: Operation | None = None + """A definition of a POST operation on this path.""" + + delete: Operation | None = None + """A definition of a DELETE operation on this path.""" + + options: Operation | None = None + """A definition of a OPTIONS operation on this path.""" + + head: Operation | None = None + """A definition of a HEAD operation on this path.""" + + patch: Operation | None = None + """A definition of a PATCH operation on this path.""" + + trace: Operation | None = None + """A definition of a TRACE operation on this path.""" + + servers: list[Server] | None = None + """An alternative ``server`` array to service all operations in this path.""" + + parameters: list[Parameter | Reference] | None = None + """A list of parameters that are applicable for all the operations described under this path. These parameters can + be overridden at the operation level, but cannot be removed there. The list MUST NOT include duplicated parameters. + A unique parameter is defined by a combination of a `name <https://spec.openapis.org/oas/v3.1.0#parameterName>`_ and + `location <https://spec.openapis.org/oas/v3.1.0#parameterIn>`_. The list can use the + `Reference Object <https://spec.openapis.org/oas/v3.1.0#referenceObject>`_ to link to parameters that are defined at + the `OpenAPI Object's components/parameters <https://spec.openapis.org/oas/v3.1.0#componentsParameters>`_. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/paths.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/paths.py new file mode 100644 index 0000000..682664e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/paths.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict + +if TYPE_CHECKING: + from litestar.openapi.spec import PathItem + +Paths = Dict[str, "PathItem"] +"""Holds the relative paths to the individual endpoints and their operations. The path is appended to the URL from the. + +`Server Object <https://spec.openapis.org/oas/v3.1.0#serverObject>`_ in order to construct the full URL. + +The Paths MAY be empty, due to +`Access Control List (ACL) constraints <https://spec.openapis.org/oas/v3.1.0#securityFiltering>`_. + +Patterned Fields + +/{path}: PathItem + +A relative path to an individual endpoint. The field name MUST begin with a forward slash (``/``). The path is +**appended** (no relative URL resolution) to the expanded URL from the +`Server Object <https://spec.openapis.org/oas/v3.1.0#serverObject>`_ 's ``url`` field in order to construct the full +URL. `Path templating <https://spec.openapis.org/oas/v3.1.0#pathTemplating>`_ is allowed. When matching URLs, concrete +(non-templated) paths would be matched before their templated counterparts. Templated paths with the same hierarchy but +different templated names MUST NOT exist as they are identical. In case of ambiguous matching, it's up to the tooling to +decide which one to use. +""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/reference.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/reference.py new file mode 100644 index 0000000..5ee2a95 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/reference.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("Reference",) + + +@dataclass +class Reference(BaseSchemaObject): + """A simple object to allow referencing other components in the OpenAPI document, internally and externally. + + The ``$ref`` string value contains a URI `RFC3986 <https://tools.ietf.org/html/rfc3986>`_ , which identifies the + location of the value being referenced. + + See the rules for resolving `Relative References <https://spec.openapis.org/oas/v3.1.0#relativeReferencesURI>`_. + """ + + ref: str + """**REQUIRED**. The reference identifier. This MUST be in the form of a URI.""" + + summary: str | None = None + """A short summary which by default SHOULD override that of the referenced component. + + If the referenced object-type does not allow a ``summary`` field, then this field has no effect. + """ + + description: str | None = None + """A description which by default SHOULD override that of the referenced component. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. If the referenced + object-type does not allow a ``description`` field, then this field has no effect. + """ + + @property + def value(self) -> str: + return self.ref.split("/")[-1] diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/request_body.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/request_body.py new file mode 100644 index 0000000..5e4e195 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/request_body.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.media_type import OpenAPIMediaType + +__all__ = ("RequestBody",) + + +@dataclass +class RequestBody(BaseSchemaObject): + """Describes a single request body.""" + + content: dict[str, OpenAPIMediaType] + """ + **REQUIRED**. The content of the request body. + The key is a media type or `media type range <https://tools.ietf.org/html/rfc7231#appendix-D>`_ and the value + describes it. + + For requests that match multiple keys, only the most specific key is applicable. e.g. ``text/plain`` overrides + ``text/*`` + """ + + description: str | None = None + """A brief description of the request body. This could contain examples of use. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + required: bool = False + """Determines if the request body is required in the request. + + Defaults to ``False``. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/response.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/response.py new file mode 100644 index 0000000..236bc40 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/response.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.header import OpenAPIHeader + from litestar.openapi.spec.link import Link + from litestar.openapi.spec.media_type import OpenAPIMediaType + from litestar.openapi.spec.reference import Reference + + +__all__ = ("OpenAPIResponse",) + + +@dataclass +class OpenAPIResponse(BaseSchemaObject): + """Describes a single response from an API Operation, including design-time, static ``links`` to operations based on + the response. + """ + + description: str + """**REQUIRED**. A short description of the response. + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + headers: dict[str, OpenAPIHeader | Reference] | None = None + """Maps a header name to its definition. + `RFC7230 <https://tools.ietf.org/html/rfc7230#page-22>`_ states header names are case insensitive. + If a response header is defined with the name ``Content-Type``, it SHALL be ignored. + """ + + content: dict[str, OpenAPIMediaType] | None = None + """A map containing descriptions of potential response payloads. The key is a media type or + `media type range <https://tools.ietf.org/html/rfc7231#appendix-D>`_ and the value describes it. + + For responses that match multiple keys, only the most specific key is applicable. e.g. ``text/plain`` overrides + ``text/*`` + """ + + links: dict[str, Link | Reference] | None = None + """A map of operations links that can be followed from the response. + + The key of the map is a short name for the link, following the naming constraints of the names for + `Component Objects <https://spec.openapis.org/oas/v3.1.0#componentsObject>`_. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/responses.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/responses.py new file mode 100644 index 0000000..0cff680 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/responses.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Union + +if TYPE_CHECKING: + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.response import OpenAPIResponse + +Responses = Dict[str, Union["OpenAPIResponse", "Reference"]] +"""A container for the expected responses of an operation. The container maps a +HTTP response code to the expected response. + +The documentation is not necessarily expected to cover all possible HTTP response codes because they may not be known in +advance. However, documentation is expected to cover a successful operation response and any known errors. + +The ``default`` MAY be used as a default response object for all HTTP codes hat are not covered individually by the +specification. + +The ``Responses Object`` MUST contain at least one response code, and it SHOULD be the response for a successful +operation call. + +Fixed Fields + +default: ``Optional[Union[Response, Reference]]`` + +The documentation of responses other than the ones declared for specific HTTP response codes. Use this field to cover +undeclared responses. A `Reference Object <https://spec.openapis.org/oas/v3.1.0#referenceObject>`_ can link to a +response that the `OpenAPI Object's components/responses <https://spec.openapis.org/oas/v3.1.0#componentsResponses>`_ +section defines. + +Patterned Fields +{httpStatusCode}: ``Optional[Union[Response, Reference]]`` + +Any `HTTP status code <https://spec.openapis.org/oas/v3.1.0#httpCodes>`_ can be used as the property name, but only one +property per code, to describe the expected response for that HTTP status code. + +A `Reference Object <https://spec.openapis.org/oas/v3.1.0#referenceObject>`_ can link to a response that is defined in +the `OpenAPI Object's components/responses <https://spec.openapis.org/oas/v3.1.0#componentsResponses>`_ section. This +field MUST be enclosed in quotation marks (for example, ``200``) for compatibility between JSON and YAML. To define a +range of response codes, this field MAY contain the uppercase wildcard character ``X``. For example, ``2XX`` represents +all response codes between ``[200-299]``. Only the following range definitions are allowed: ``1XX``, ``2XX``, ``3XX``, +``4XX``, and ``5XX``. If a response is defined using an explicit code, the explicit code definition takes precedence +over the range definition for that code. +""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/schema.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/schema.py new file mode 100644 index 0000000..4be2b7c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/schema.py @@ -0,0 +1,652 @@ +from __future__ import annotations + +from dataclasses import dataclass, fields, is_dataclass +from typing import TYPE_CHECKING, Any, Hashable, Mapping, Sequence + +from litestar.openapi.spec.base import BaseSchemaObject +from litestar.utils.predicates import is_non_string_sequence + +if TYPE_CHECKING: + from litestar.openapi.spec.discriminator import Discriminator + from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType + from litestar.openapi.spec.external_documentation import ExternalDocumentation + from litestar.openapi.spec.reference import Reference + from litestar.openapi.spec.xml import XML + from litestar.types import DataclassProtocol + +__all__ = ("Schema", "SchemaDataContainer") + + +def _recursive_hash(value: Hashable | Sequence | Mapping | DataclassProtocol | type[DataclassProtocol]) -> int: + if isinstance(value, Mapping): + hash_value = 0 + for k, v in value.items(): + if k != "examples": + hash_value += hash(k) + hash_value += _recursive_hash(v) + return hash_value + if is_dataclass(value): + hash_value = hash(type(value).__name__) + for field in fields(value): + if field.name != "examples": + hash_value += hash(field.name) + hash_value += _recursive_hash(getattr(value, field.name, None)) + return hash_value + if is_non_string_sequence(value): + return sum(_recursive_hash(v) for v in value) + return hash(value) if isinstance(value, Hashable) else 0 + + +@dataclass +class Schema(BaseSchemaObject): + """The Schema Object allows the definition of input and output data types. These types can be objects, but also + primitives and arrays. This object is a superset of the + `JSON Schema Specification Draft 2020-12 <https://tools.ietf.org/html/draft-bhutton-json-schema-00>`_. + + For more information about the properties, see + `JSON Schema Core <https://tools.ietf.org/html/draft-wright-json-schema-00>`_ and + `JSON Schema Validation <https://tools.ietf.org/html/draft-wright-json-schema-validation-00>`_. + + Unless stated otherwise, the property definitions follow those of JSON Schema and do not add any additional + semantics. Where JSON Schema indicates that behavior is defined by the application (e.g. for annotations), OAS also + defers the definition of semantics to the application consuming the OpenAPI document. + + The following properties are taken directly from the + `JSON Schema Core <https://tools.ietf.org/html/draft-wright-json-schema-00>`_ and follow the same specifications. + """ + + all_of: Sequence[Reference | Schema] | None = None + """This keyword's value MUST be a non-empty array. Each item of the array MUST be a valid JSON Schema. + + An instance validates successfully against this keyword if it validates successfully against all schemas defined by + this keyword's value. + """ + + any_of: Sequence[Reference | Schema] | None = None + """This keyword's value MUST be a non-empty array. Each item of the array MUST be a valid JSON Schema. + + An instance validates successfully against this keyword if it validates successfully against at least one schema + defined by this keyword's value. Note that when annotations are being collected, all subschemas MUST be examined so + that annotations are collected from each subschema that validates successfully. + """ + + one_of: Sequence[Reference | Schema] | None = None + """This keyword's value MUST be a non-empty array. Each item of the array MUST be a valid JSON Schema. + + An instance validates successfully against this keyword if it validates successfully against exactly one schema + defined by this keyword's value. + """ + + schema_not: Reference | Schema | None = None + """This keyword's value MUST be a valid JSON Schema. + + An instance is valid against this keyword if it fails to validate successfully against the schema defined by this + keyword. + """ + + schema_if: Reference | Schema | None = None + """This keyword's value MUST be a valid JSON Schema. + + This validation outcome of this keyword's subschema has no direct effect on the overall validation result. Rather, + it controls which of the "then" or "else" keywords are evaluated. + + Instances that successfully validate against this keyword's subschema MUST also be valid against the subschema + value of the "then" keyword, if present. + + Instances that fail to validate against this keyword's subschema MUST also be valid against the subschema value of + the "else" keyword, if present. + + If annotations (Section 7.7) are being collected, they are collected rom this keyword's subschema in the usual way, + including when the keyword is present without either "then" or "else". + """ + + then: Reference | Schema | None = None + """This keyword's value MUST be a valid JSON Schema. + + When "if" is present, and the instance successfully validates against its subschema, then validation succeeds + against this keyword if the instance also successfully validates against this keyword's subschema. + + This keyword has no effect when "if" is absent, or when the instance fails to validate against its subschema. + Implementations MUST NOT evaluate the instance against this keyword, for either validation or annotation collection + purposes, in such cases. + """ + + schema_else: Reference | Schema | None = None + """This keyword's value MUST be a valid JSON Schema. + + When "if" is present, and the instance fails to validate against its subschema, then validation succeeds against + this keyword if the instance successfully validates against this keyword's subschema. + + This keyword has no effect when "if" is absent, or when the instance successfully validates against its subschema. + Implementations MUST NOT evaluate the instance against this keyword, for either validation or annotation collection + purposes, in such cases. + """ + + dependent_schemas: dict[str, Reference | Schema] | None = None + """This keyword specifies subschemas that are evaluated if the instance is + an object and contains a certain property. + + This keyword's value MUST be an object. Each value in the object MUST be a valid JSON Schema. + + If the object key is a property in the instance, the entire instance must validate against the subschema. Its use is + dependent on the presence of the property. + + Omitting this keyword has the same behavior as an empty object. + """ + + prefix_items: Sequence[Reference | Schema] | None = None + """The value of "prefixItems" MUST be a non-empty array of valid JSON Schemas. + + Validation succeeds if each element of the instance validates against the schema at the same position, if any. + This keyword does not constrain the length of the array. If the array is longer than this keyword's value, this + keyword validates only the prefix of matching length. + + This keyword produces an annotation value which is the largest index to which this keyword applied a subschema. + he value MAY be a boolean true if a subschema was applied to every index of the instance, such as is produced by the + "items" keyword. This annotation affects the behavior of "items" and "unevaluatedItems". + + Omitting this keyword has the same assertion behavior as an empty array. + """ + + items: Reference | Schema | None = None + """The value of "items" MUST be a valid JSON Schema. + + This keyword applies its subschema to all instance elements at indexes greater than the length of the "prefixItems" + array in the same schema object, as reported by the annotation result of that "prefixItems" keyword. If no such + annotation result exists, "items" applies its subschema to all instance array elements. [[CREF11: Note that the + behavior of "items" without "prefixItems" is identical to that of the schema form of "items" in prior drafts. When + "prefixItems" is present, the behavior of "items" is identical to the former "additionalItems" keyword. ]] + + If the "items" subschema is applied to any positions within the instance array, it produces an annotation result of + boolean true, indicating that all remaining array elements have been evaluated against this keyword's subschema. + + Omitting this keyword has the same assertion behavior as an empty schema. + + Implementations MAY choose to implement or optimize this keyword in another way that produces the same effect, such + as by directly checking for the presence and size of a "prefixItems" array. Implementations that do not support + annotation collection MUST do so. + """ + + contains: Reference | Schema | None = None + """The value of this keyword MUST be a valid JSON Schema. + + An array instance is valid against "contains" if at least one of its elements is valid against the given schema. + The subschema MUST be applied to every array element even after the first match has been found, in order to collect + annotations for use by other keywords. This is to ensure that all possible annotations are collected. + + Logically, the validation result of applying the value subschema to each item in the array MUST be ORed with + "false", resulting in an overall validation result. + + This keyword produces an annotation value which is an array of the indexes to which this keyword validates + successfully when applying its subschema, in ascending order. The value MAY be a boolean "true" if the subschema + validates successfully when applied to every index of the instance. The annotation MUST be present if the instance + array to which this keyword's schema applies is empty. + """ + + properties: dict[str, Reference | Schema] | None = None + """The value of "properties" MUST be an object. Each value of this object MUST be a valid JSON Schema. + + Validation succeeds if, for each name that appears in both the instance and as a name within this keyword's value, + the child instance for that name successfully validates against the corresponding schema. + + The annotation result of this keyword is the set of instance property names matched by this keyword. + + Omitting this keyword has the same assertion behavior as an empty object. + """ + + pattern_properties: dict[str, Reference | Schema] | None = None + """The value of "patternProperties" MUST be an object. Each property name of this object SHOULD be a valid + regular expression, according to the ECMA-262 regular expression dialect. Each property value of this object + MUST be a valid JSON Schema. + + Validation succeeds if, for each instance name that matches any regular expressions that appear as a property name + in this keyword's value, the child instance for that name successfully validates against each schema that + corresponds to a matching regular expression. + + The annotation result of this keyword is the set of instance property names matched by this keyword. + + Omitting this keyword has the same assertion behavior as an empty object. + """ + + additional_properties: Reference | Schema | bool | None = None + """The value of "additionalProperties" MUST be a valid JSON Schema. + + The behavior of this keyword depends on the presence and annotation results of "properties" and "patternProperties" + within the same schema object. Validation with "additionalProperties" applies only to the child values of instance + names that do not appear in the annotation results of either "properties" or "patternProperties". + + For all such properties, validation succeeds if the child instance validates against the "additionalProperties" + schema. + + The annotation result of this keyword is the set of instance property names validated by this keyword's subschema. + + Omitting this keyword has the same assertion behavior as an empty schema. + + Implementations MAY choose to implement or optimize this keyword in another way that produces the same effect, such + as by directly checking the names in "properties" and the patterns in "patternProperties" against the instance + property set. Implementations that do not support annotation collection MUST do so. + """ + + property_names: Reference | Schema | None = None + """The value of "propertyNames" MUST be a valid JSON Schema. + + If the instance is an object, this keyword validates if every property name in the instance validates against the + provided schema. Note the property name that the schema is testing will always be a string. + + Omitting this keyword has the same behavior as an empty schema. + """ + + unevaluated_items: Reference | Schema | None = None + """The value of "unevaluatedItems" MUST be a valid JSON Schema. + + The behavior of this keyword depends on the annotation results of adjacent keywords that apply to the instance + location being validated. Specifically, the annotations from "prefixItems" items", and "contains", which can come + from those keywords when they are adjacent to the "unevaluatedItems" keyword. Those three annotations, as well as + "unevaluatedItems", can also result from any and all adjacent in-place applicator (Section 10.2) keywords. This + includes but is not limited to the in-place applicators defined in this document. + + If no relevant annotations are present, the "unevaluatedItems" subschema MUST be applied to all locations in the + array. If a boolean true value is present from any of the relevant annotations, unevaluatedItems" MUST be ignored. + Otherwise, the subschema MUST be applied to any index greater than the largest annotation value for "prefixItems", + which does not appear in any annotation value for + "contains". + + This means that "prefixItems", "items", "contains", and all in-place applicators MUST be evaluated before this + keyword can be evaluated. Authors of extension keywords MUST NOT define an in-place applicator that would need to be + evaluated after this keyword. + + If the "unevaluatedItems" subschema is applied to any positions within the instance array, it produces an annotation + result of boolean true, analogous to the behavior of "items". + + Omitting this keyword has the same assertion behavior as an empty schema. + """ + + unevaluated_properties: Reference | Schema | None = None + """The value of "unevaluatedProperties" MUST be a valid JSON Schema. + + The behavior of this keyword depends on the annotation results of adjacent keywords that apply to the instance + location being validated. Specifically, the annotations from "properties", "patternProperties", and + "additionalProperties", which can come from those keywords when they are adjacent to the "unevaluatedProperties" + keyword. Those three annotations, as well as "unevaluatedProperties", can also result from any and all adjacent + in-place applicator (Section 10.2) keywords. This includes but is not limited to the in-place applicators defined + in this document. + + Validation with "unevaluatedProperties" applies only to the child values of instance names that do not appear in + the "properties", "patternProperties", "additionalProperties", or "unevaluatedProperties" annotation results that + apply to the instance location being validated. + + For all such properties, validation succeeds if the child instance validates against the "unevaluatedProperties" + schema. + + This means that "properties", "patternProperties", "additionalProperties", and all in-place applicators MUST be + evaluated before this keyword can be evaluated. Authors of extension keywords MUST NOT define an in-place + applicator that would need to be evaluated after this keyword. + + The annotation result of this keyword is the set of instance property names validated by this keyword's subschema. + + Omitting this keyword has the same assertion behavior as an empty schema. + + The following properties are taken directly from the + `JSON Schema Validation <https://tools.ietf.org/html/draft-wright-json-schema-validation-00>`_ and follow the same + specifications: + """ + + type: OpenAPIType | Sequence[OpenAPIType] | None = None + """The value of this keyword MUST be either a string or an array. If it is an array, elements of the array MUST be + strings and MUST be unique. + + String values MUST be one of the six primitive types (``"null"``, ``"boolean"``, ``"object"``, ``"array"``, + ``"number"``, and ``"string"``), or ``"integer"`` which matches any number with a zero fractional part. + + An instance validates if and only if the instance is in any of the sets listed for this keyword. + """ + + enum: Sequence[Any] | None = None + """The value of this keyword MUST be an array. This array SHOULD have at least one element. Elements in the array + SHOULD be unique. + + An instance validates successfully against this keyword if its value is equal to one of the elements in this + keyword's array value. + + Elements in the array might be of any type, including null. + """ + + const: Any | None = None + """The value of this keyword MAY be of any type, including null. + + Use of this keyword is functionally equivalent to an "enum" (Section 6.1.2) with a single value. + + An instance validates successfully against this keyword if its value is equal to the value of the keyword. + """ + + multiple_of: float | None = None + """The value of "multipleOf" MUST be a number, strictly greater than 0. + + A numeric instance is only valid if division by this keyword's value results in an integer. + """ + + maximum: float | None = None + """The value of "maximum" MUST be a number, representing an inclusive upper limit for a numeric instance. + + If the instance is a number, then this keyword validates only if the instance is less than or exactly equal to + "maximum". + """ + + exclusive_maximum: float | None = None + """The value of "exclusiveMaximum" MUST be a number, representing an exclusive upper limit for a numeric instance. + + If the instance is a number, then the instance is valid only if it has a value strictly less than (not equal to) + "exclusiveMaximum". + """ + + minimum: float | None = None + """The value of "minimum" MUST be a number, representing an inclusive lower limit for a numeric instance. + + If the instance is a number, then this keyword validates only if the instance is greater than or exactly equal to + "minimum". + """ + + exclusive_minimum: float | None = None + """The value of "exclusiveMinimum" MUST be a number, representing an exclusive lower limit for a numeric instance. + + If the instance is a number, then the instance is valid only if it has a value strictly greater than (not equal to) + "exclusiveMinimum". + """ + + max_length: int | None = None + """The value of this keyword MUST be a non-negative integer. + + A string instance is valid against this keyword if its length is less than, or equal to, the value of this keyword. + + The length of a string instance is defined as the number of its characters as defined by :rfc:`8259`. + """ + + min_length: int | None = None + """The value of this keyword MUST be a non-negative integer. + + A string instance is valid against this keyword if its length is greater than, or equal to, the value of this + keyword. + + The length of a string instance is defined as the number of its characters as defined by :rfc:`8259`. + + Omitting this keyword has the same behavior as a value of 0. + """ + + pattern: str | None = None + """The value of this keyword MUST be a string. This string SHOULD be a valid regular expression, according to the + ECMA-262 regular expression dialect. + + A string instance is considered valid if the regular expression matches the instance successfully. Recall: regular + expressions are not implicitly anchored. + """ + + max_items: int | None = None + """The value of this keyword MUST be a non-negative integer. + + An array instance is valid against "maxItems" if its size is less than, or equal to, the value of this keyword. + """ + + min_items: int | None = None + """The value of this keyword MUST be a non-negative integer. + + An array instance is valid against "minItems" if its size is greater than, or equal to, the value of this keyword. + + Omitting this keyword has the same behavior as a value of 0. + """ + + unique_items: bool | None = None + """The value of this keyword MUST be a boolean. + + If this keyword has boolean value false, the instance validates successfully. If it has boolean value true, the + instance validates successfully if all of its elements are unique. + + Omitting this keyword has the same behavior as a value of false. + """ + + max_contains: int | None = None + """The value of this keyword MUST be a non-negative integer. + + If "contains" is not present within the same schema object, then this keyword has no effect. + + An instance array is valid against "maxContains" in two ways, depending on the form of the annotation result of an + adjacent "contains" [json-schema] keyword. The first way is if the annotation result is an array and the length of + that array is less than or equal to the "maxContains" value. The second way is if the annotation result is a + boolean "true" and the instance array length is less than r equal to the "maxContains" value. + """ + + min_contains: int | None = None + """The value of this keyword MUST be a non-negative integer. + + If "contains" is not present within the same schema object, then this keyword has no effect. + + An instance array is valid against "minContains" in two ways, depending on the form of the annotation result of an + adjacent "contains" [json-schema] keyword. The first way is if the annotation result is an array and the length of + that array is greater than or equal to the "minContains" value. The second way is if the annotation result is a + boolean "true" and the instance array length is greater than or equal to the "minContains" value. + + A value of 0 is allowed, but is only useful for setting a range of occurrences from 0 to the value of "maxContains". + A value of 0 with no "maxContains" causes "contains" to always pass validation. + + Omitting this keyword has the same behavior as a value of 1. + """ + + max_properties: int | None = None + """The value of this keyword MUST be a non-negative integer. + + An object instance is valid against "maxProperties" if its number of properties is less than, or equal to, the value + of this keyword. + """ + + min_properties: int | None = None + """The value of this keyword MUST be a non-negative integer. + + An object instance is valid against "minProperties" if its number of properties is greater than, or equal to, the + value of this keyword. + + Omitting this keyword has the same behavior as a value of 0. + """ + + required: Sequence[str] | None = None + """The value of this keyword MUST be an array. Elements of this array, if any, MUST be strings, and MUST be unique. + + An object instance is valid against this keyword if every item in the rray is the name of a property in the instance. + + Omitting this keyword has the same behavior as an empty array. + """ + + dependent_required: dict[str, Sequence[str]] | None = None + """The value of this keyword MUST be an object. Properties in this object, f any, MUST be arrays. Elements in each + array, if any, MUST be strings, and MUST be unique. + + This keyword specifies properties that are required if a specific other property is present. Their requirement is + dependent on the presence of the other property. + + Validation succeeds if, for each name that appears in both the instance and as a name within this keyword's value, + every item in the corresponding array is also the name of a property in the instance. + + Omitting this keyword has the same behavior as an empty object. + """ + + format: OpenAPIFormat | None = None + """From OpenAPI: + + See `Data Type Formats <https://spec.openapis.org/oas/v3.1.0#dataTypeFormat>`_ for further details. While relying on + JSON Schema's defined formats, the OAS offers a few additional predefined formats. + + From JSON Schema: + + Structural validation alone may be insufficient to allow an application to correctly utilize certain values. + The "format" annotation keyword is defined to allow schema authors to convey semantic information for a fixed subset + of values which are accurately described by authoritative resources, be they RFCs or other external specifications. + + The value of this keyword is called a format attribute. It MUST be a string. A format attribute can generally only + validate a given set of instance types. If the type of the instance to validate is not in this set, validation for + this format attribute and instance SHOULD succeed. All format attributes defined in this section apply to strings, + but a format attribute can be specified to apply to any instance types defined in the data model defined in the core + JSON Schema. [json-schema] [[CREF1: Note that the "type" keyword in this specification defines an "integer" type + which is not part of the data model. Therefore a format attribute can be limited to numbers, but not specifically to + integers. However, a numeric format can be used alongside the "type" keyword with a value of "integer", or could be + explicitly defined to always pass if the number is not an integer, which produces essentially the same behavior as + only applying to integers. ]] + """ + + content_encoding: str | None = None + """If the instance value is a string, this property defines that the string SHOULD be interpreted as binary data and + decoded using the encoding named by this property. + + Possible values indicating base 16, 32, and 64 encodings with several variations are listed in :rfc:`4648`. + Additionally, sections 6.7 and 6.8 of :rfc:`2045` provide encodings used in MIME. As "base64" is defined in both + RFCs, the definition from :rfc:`4648` SHOULD be assumed unless the string is specifically intended for use in a + MIME context. Note that all of these encodings result in strings consisting only of 7-bit ASCII characters. + therefore, this keyword has no meaning for strings containing characters outside of that range. + + If this keyword is absent, but "contentMediaType" is present, this indicates that the encoding is the identity + encoding, meaning that no transformation was needed in order to represent the content in a UTF-8 string. + """ + + content_media_type: str | None = None + """If the instance is a string, this property indicates the media type of the contents of the string. If + "contentEncoding" is present, this property describes the decoded string. + + The value of this property MUST be a string, which MUST be a media type, as defined by :rfc:`2046` + """ + + content_schema: Reference | Schema | None = None + """If the instance is a string, and if "contentMediaType" is present, this property contains a schema which + describes the structure of the string. + + This keyword MAY be used with any media type that can be mapped into JSON Schema's data model. + + The value of this property MUST be a valid JSON schema. It SHOULD be ignored if "contentMediaType" is not present. + """ + + title: str | None = None + """The value of "title" MUST be a string. + + The title can be used to decorate a user interface with information about the data produced by this user interface. + A title will preferably be short. + """ + + description: str | None = None + """From OpenAPI: + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + + From JSON Schema: + The value "description" MUST be a string. + + The description can be used to decorate a user interface with information about the data produced by this user + interface. A description will provide explanation about the purpose of the instance described by this schema. + """ + + default: Any | None = None + """There are no restrictions placed on the value of this keyword. When multiple occurrences of this keyword are + applicable to a single sub-instance, implementations SHOULD remove duplicates. + + This keyword can be used to supply a default JSON value associated with a particular schema. It is RECOMMENDED that + a default value be valid against the associated schema. + """ + + deprecated: bool | None = None + """The value of this keyword MUST be a boolean. When multiple occurrences of this keyword are applicable to a + single sub-instance, applications SHOULD consider the instance location to be deprecated if any occurrence specifies + a true value. + + If "deprecated" has a value of boolean true, it indicates that applications SHOULD refrain from usage of the + declared property. It MAY mean the property is going to be removed in the future. + + A root schema containing "deprecated" with a value of true indicates that the entire resource being described MAY be + removed in the future. + + The "deprecated" keyword applies to each instance location to which the schema object containing the keyword + successfully applies. This can result in scenarios where every array item or object property is deprecated even + though the containing array or object is not. + + Omitting this keyword has the same behavior as a value of false. + """ + + read_only: bool | None = None + """The value of "readOnly" MUST be a boolean. When multiple occurrences of this keyword are applicable to a single + sub-instance, the resulting behavior SHOULD be as for a true value if any occurrence specifies a true value, and + SHOULD be as for a false value otherwise. + + If "readOnly" has a value of boolean true, it indicates that the value of the instance is managed exclusively by + the owning authority, and attempts by an application to modify the value of this property are expected to be ignored + or rejected by that owning authority. + + An instance document that is marked as "readOnly" for the entire document MAY be ignored if sent to the owning + authority, or MAY result in an error, at the authority's discretion. + + For example, "readOnly" would be used to mark a database-generated serial number as read-only, while "writeOnly" + would be used to mark a password input field. + + This keyword can be used to assist in user interface instance generation. In particular, an application MAY choose + to use a widget that hides input values as they are typed for write-only fields. + + Omitting these keywords has the same behavior as values of false. + """ + + write_only: bool | None = None + """The value of "writeOnly" MUST be a boolean. When multiple occurrences of this keyword are applicable to a + single sub-instance, the resulting behavior SHOULD be as for a true value if any occurrence specifies a true value, + and SHOULD be as for a false value otherwise. + + If "writeOnly" has a value of boolean true, it indicates that the value is never present when the instance is + retrieved from the owning authority. It can be present when sent to the owning authority to update or create the + document (or the resource it represents), but it will not be included in any updated or newly created version of the + instance. + + An instance document that is marked as "writeOnly" for the entire document MAY be returned as a blank document of + some sort, or MAY produce an error upon retrieval, or have the retrieval request ignored, at the authority's + discretion. + + For example, "readOnly" would be used to mark a database-generated serial number as read-only, while "writeOnly" + would be used to mark a password input field. + + This keyword can be used to assist in user interface instance generation. In particular, an application MAY choose + to use a widget that hides input values as they are typed for write-only fields. + + Omitting these keywords has the same behavior as values of false. + """ + + examples: list[Any] | None = None + """The value of this must be an array containing the example values.""" + + discriminator: Discriminator | None = None + """Adds support for polymorphism. + + The discriminator is an object name that is used to differentiate between other schemas which may satisfy the + payload description. See `Composition and Inheritance <https://spec.openapis.org/oas/v3.1.0#schemaComposition>`_ + for more details. + """ + + xml: XML | None = None + """This MAY be used only on properties schemas. + + It has no effect on root schemas. Adds additional metadata to describe the XML representation of this property. + """ + + external_docs: ExternalDocumentation | None = None + """Additional external documentation for this schema.""" + + example: Any | None = None + """A free-form property to include an example of an instance for this schema. To represent examples that cannot be + naturally represented in JSON or YAML, a string value can be used to contain the example with escaping where + necessary. + + Deprecated: The example property has been deprecated in favor of the JSON Schema examples keyword. Use of example is + discouraged, and later versions of this specification may remove it. + """ + + def __hash__(self) -> int: + return _recursive_hash(self) + + +@dataclass +class SchemaDataContainer(Schema): + """Special class that allows using python data containers, e.g. dataclasses or pydantic models, to represent a + schema + """ + + data_container: Any = None + """A data container instance that will be used to generate the schema.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/security_requirement.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/security_requirement.py new file mode 100644 index 0000000..e3d1394 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/security_requirement.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Dict, List + +SecurityRequirement = Dict[str, List[str]] +"""Lists the required security schemes to execute this operation. The name used for each property MUST correspond to a +security scheme declared in the. + +`Security Schemes <https://spec.openapis.org/oas/v3.1.0#componentsSecuritySchemes>`_ under the +`Components Object <https://spec.openapis.org/oas/v3.1.0#componentsObject>`_. + +Security Requirement Objects that contain multiple schemes require that all schemes MUST be satisfied for a request to +be authorized. This enables support for scenarios where multiple query parameters or HTTP headers are required to convey +security information. + +When a list of Security Requirement Objects is defined on the +`OpenAPI Object <https://spec.openapis.org/oas/v3.1.0#oasObject>`_ or +`Operation Object <https://spec.openapis.org/oas/v3.1.0#operationObject>`_, only one of the Security Requirement +Objects in the list needs to be satisfied to authorize the request. + +Patterned Fields + +{name}: ``List[str]`` +Each name MUST correspond to a security scheme which is declared +in the `Security Schemes <https://spec.openapis.org/oas/v3.1.0#componentsSecuritySchemes>`_ under the +`Components Object <https://spec.openapis.org/oas/v3.1.0#componentsObject>`_. if the security scheme is of type +``"oauth2"`` or ``"openIdConnect"``, then the value is a list of scope names required for the execution, and the list +MAY be empty if authorization does not require a specified scope. For other security scheme types, the array MAY contain +a list of role names which are required for the execution,but are not otherwise defined or exchanged in-band. +""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/security_scheme.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/security_scheme.py new file mode 100644 index 0000000..badc77e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/security_scheme.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.oauth_flows import OAuthFlows + +__all__ = ("SecurityScheme",) + + +@dataclass +class SecurityScheme(BaseSchemaObject): + """Defines a security scheme that can be used by the operations. + + Supported schemes are HTTP authentication, an API key (either as a header, a cookie parameter or as a query + parameter), mutual TLS (use of a client certificate), OAuth2's common flows (implicit, password, client credentials + and authorization code) as defined in :rfc`6749`, and + `OpenID Connect Discovery <https://tools.ietf.org/html/draft-ietf-oauth-discovery-06>`_. + + Please note that as of 2020, the implicit flow is about to be deprecated by + `OAuth 2.0 Security Best Current Practice <https://tools.ietf.org/html/draft-ietf-oauth-security-topics>`_. + Recommended for most use case is Authorization Code Grant flow with PKCE. + """ + + type: Literal["apiKey", "http", "mutualTLS", "oauth2", "openIdConnect"] + """**REQUIRED**. The type of the security scheme.""" + + description: str | None = None + """A description for security scheme. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + name: str | None = None + """ + **REQUIRED** for ``apiKey``. The name of the header, query or cookie parameter to be used. + """ + + security_scheme_in: Literal["query", "header", "cookie"] | None = None + """ + **REQUIRED** for ``apiKey``. The location of the API key. + """ + + scheme: str | None = None + """ + **REQUIRED** for ``http``. The name of the HTTP Authorization scheme to be used in the + authorization header as defined in :rfc:`7235`. + + The values used SHOULD be registered in the + `IANA Authentication Scheme registry <https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml>`_ + """ + + bearer_format: str | None = None + """A hint to the client to identify how the bearer token is formatted. + + Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation + purposes. + """ + + flows: OAuthFlows | None = None + """**REQUIRED** for ``oauth2``. An object containing configuration information for the flow types supported.""" + + open_id_connect_url: str | None = None + """**REQUIRED** for ``openIdConnect``. OpenId Connect URL to discover OAuth2 configuration values. This MUST be in + the form of a URL. The OpenID Connect standard requires the use of TLS. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/server.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/server.py new file mode 100644 index 0000000..2c12fcf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/server.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.server_variable import ServerVariable + +__all__ = ("Server",) + + +@dataclass +class Server(BaseSchemaObject): + """An object representing a Server.""" + + url: str + """ + **REQUIRED**. A URL to the target host. + + This URL supports Server Variables and MAY be relative, to indicate that the host location is relative to the + location where the OpenAPI document is being served. Variable substitutions will be made when a variable is named in + ``{brackets}``. + """ + + description: str | None = None + """An optional string describing the host designated by the URL. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + variables: dict[str, ServerVariable] | None = None + """A map between a variable name and its value. The value is used for substitution in the server's URL template.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/server_variable.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/server_variable.py new file mode 100644 index 0000000..c59c542 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/server_variable.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("ServerVariable",) + + +@dataclass +class ServerVariable(BaseSchemaObject): + """An object representing a Server Variable for server URL template substitution.""" + + default: str + """**REQUIRED**. The default value to use for substitution, which SHALL be sent if an alternate value is _not_ + supplied. Note this behavior is different than the + `Schema Object's <https://spec.openapis.org/oas/v3.1.0#schemaObject>`_ treatment of default values, because in those + cases parameter values are optional. If the `enum <https://spec.openapis.org/oas/v3.1.0#serverVariableEnum>`_ is + defined, the value MUST exist in the enum's values. + """ + + enum: list[str] | None = None + """An enumeration of string values to be used if the substitution options are from a limited set. + + The array SHOULD NOT be empty. + """ + + description: str | None = None + """An optional description for the server variable. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/tag.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/tag.py new file mode 100644 index 0000000..c3e6374 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/tag.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from litestar.openapi.spec.base import BaseSchemaObject + +if TYPE_CHECKING: + from litestar.openapi.spec.external_documentation import ExternalDocumentation + +__all__ = ("Tag",) + + +@dataclass +class Tag(BaseSchemaObject): + """Adds metadata to a single tag that is used by the + `Operation Object <https://spec.openapis.org/oas/v3.1.0#operationObject>`_. + + It is not mandatory to have a Tag Object per tag defined in the Operation Object instances. + """ + + name: str + """**REQUIRED**. The name of the tag.""" + + description: str | None = None + """A short description for the tag. + + `CommonMark syntax <https://spec.commonmark.org/>`_ MAY be used for rich text representation. + """ + + external_docs: ExternalDocumentation | None = None + """Additional external documentation for this tag.""" diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/xml.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/xml.py new file mode 100644 index 0000000..6030998 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/xml.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litestar.openapi.spec.base import BaseSchemaObject + +__all__ = ("XML",) + + +@dataclass() +class XML(BaseSchemaObject): + """A metadata object that allows for more fine-tuned XML model definitions. + + When using arrays, XML element names are *not* inferred (for singular/plural forms) and the ``name`` property SHOULD + be used to add that information. See examples for expected behavior. + """ + + name: str | None = None + """ + Replaces the name of the element/attribute used for the described schema property. When defined within ``items``, it + will affect the name of the individual XML elements within the list. When defined alongside ``type`` being ``array`` + (outside the ``items``), it will affect the wrapping element and only if ``wrapped`` is ``True``. If ``wrapped`` is + ``False``, it will be ignored. + """ + + namespace: str | None = None + """The URI of the namespace definition. Value MUST be in the form of an absolute URI.""" + + prefix: str | None = None + """The prefix to be used for the + `xmlName <https://spec.openapis.org/oas/v3.1.0#xmlName>`_ + """ + + attribute: bool = False + """Declares whether the property definition translates to an attribute instead of an element. Default value is + ``False``. + """ + + wrapped: bool = False + """ + MAY be used only for an array definition. Signifies whether the array is wrapped (for example, + ``<books><book/><book/></books>``) or unwrapped (``<book/><book/>``). Default value is ``False``. The definition + takes effect only when defined alongside ``type`` being ``array`` (outside the ``items``). + """ diff --git a/venv/lib/python3.11/site-packages/litestar/pagination.py b/venv/lib/python3.11/site-packages/litestar/pagination.py new file mode 100644 index 0000000..294a13a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/pagination.py @@ -0,0 +1,342 @@ +# ruff: noqa: UP006,UP007 +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar +from uuid import UUID + +__all__ = ( + "AbstractAsyncClassicPaginator", + "AbstractAsyncCursorPaginator", + "AbstractAsyncOffsetPaginator", + "AbstractSyncClassicPaginator", + "AbstractSyncCursorPaginator", + "AbstractSyncOffsetPaginator", + "ClassicPagination", + "CursorPagination", + "OffsetPagination", +) + + +T = TypeVar("T") +C = TypeVar("C", int, str, UUID) + + +@dataclass +class ClassicPagination(Generic[T]): + """Container for data returned using limit/offset pagination.""" + + __slots__ = ("items", "page_size", "current_page", "total_pages") + + items: List[T] + """List of data being sent as part of the response.""" + page_size: int + """Number of items per page.""" + current_page: int + """Current page number.""" + total_pages: int + """Total number of pages.""" + + +@dataclass +class OffsetPagination(Generic[T]): + """Container for data returned using limit/offset pagination.""" + + __slots__ = ("items", "limit", "offset", "total") + + items: List[T] + """List of data being sent as part of the response.""" + limit: int + """Maximal number of items to send.""" + offset: int + """Offset from the beginning of the query. + + Identical to an index. + """ + total: int + """Total number of items.""" + + +@dataclass +class CursorPagination(Generic[C, T]): + """Container for data returned using cursor pagination.""" + + __slots__ = ("items", "results_per_page", "cursor", "next_cursor") + + items: List[T] + """List of data being sent as part of the response.""" + results_per_page: int + """Maximal number of items to send.""" + cursor: Optional[C] + """Unique ID, designating the last identifier in the given data set. + + This value can be used to request the "next" batch of records. + """ + + +class AbstractSyncClassicPaginator(ABC, Generic[T]): + """Base paginator class for sync classic pagination. + + Implement this class to return paginated result sets using the classic pagination scheme. + """ + + @abstractmethod + def get_total(self, page_size: int) -> int: + """Return the total number of records. + + Args: + page_size: Maximal number of records to return. + + Returns: + An integer. + """ + raise NotImplementedError + + @abstractmethod + def get_items(self, page_size: int, current_page: int) -> list[T]: + """Return a list of items of the given size 'page_size' correlating with 'current_page'. + + Args: + page_size: Maximal number of records to return. + current_page: The current page of results to return. + + Returns: + A list of items. + """ + raise NotImplementedError + + def __call__(self, page_size: int, current_page: int) -> ClassicPagination[T]: + """Return a paginated result set. + + Args: + page_size: Maximal number of records to return. + current_page: The current page of results to return. + + Returns: + A paginated result set. + """ + total_pages = self.get_total(page_size=page_size) + + items = self.get_items(page_size=page_size, current_page=current_page) + + return ClassicPagination[T]( + items=items, total_pages=total_pages, page_size=page_size, current_page=current_page + ) + + +class AbstractAsyncClassicPaginator(ABC, Generic[T]): + """Base paginator class for async classic pagination. + + Implement this class to return paginated result sets using the classic pagination scheme. + """ + + @abstractmethod + async def get_total(self, page_size: int) -> int: + """Return the total number of records. + + Args: + page_size: Maximal number of records to return. + + Returns: + An integer. + """ + raise NotImplementedError + + @abstractmethod + async def get_items(self, page_size: int, current_page: int) -> list[T]: + """Return a list of items of the given size 'page_size' correlating with 'current_page'. + + Args: + page_size: Maximal number of records to return. + current_page: The current page of results to return. + + Returns: + A list of items. + """ + raise NotImplementedError + + async def __call__(self, page_size: int, current_page: int) -> ClassicPagination[T]: + """Return a paginated result set. + + Args: + page_size: Maximal number of records to return. + current_page: The current page of results to return. + + Returns: + A paginated result set. + """ + total_pages = await self.get_total(page_size=page_size) + + items = await self.get_items(page_size=page_size, current_page=current_page) + + return ClassicPagination[T]( + items=items, total_pages=total_pages, page_size=page_size, current_page=current_page + ) + + +class AbstractSyncOffsetPaginator(ABC, Generic[T]): + """Base paginator class for limit / offset pagination. + + Implement this class to return paginated result sets using the limit / offset pagination scheme. + """ + + @abstractmethod + def get_total(self) -> int: + """Return the total number of records. + + Returns: + An integer. + """ + raise NotImplementedError + + @abstractmethod + def get_items(self, limit: int, offset: int) -> list[T]: + """Return a list of items of the given size 'limit' starting from position 'offset'. + + Args: + limit: Maximal number of records to return. + offset: Starting position within the result set (assume index 0 as starting position). + + Returns: + A list of items. + """ + raise NotImplementedError + + def __call__(self, limit: int, offset: int) -> OffsetPagination[T]: + """Return a paginated result set. + + Args: + limit: Maximal number of records to return. + offset: Starting position within the result set (assume index 0 as starting position). + + Returns: + A paginated result set. + """ + total = self.get_total() + + items = self.get_items(limit=limit, offset=offset) + + return OffsetPagination[T](items=items, total=total, offset=offset, limit=limit) + + +class AbstractAsyncOffsetPaginator(ABC, Generic[T]): + """Base paginator class for limit / offset pagination. + + Implement this class to return paginated result sets using the limit / offset pagination scheme. + """ + + @abstractmethod + async def get_total(self) -> int: + """Return the total number of records. + + Returns: + An integer. + """ + raise NotImplementedError + + @abstractmethod + async def get_items(self, limit: int, offset: int) -> list[T]: + """Return a list of items of the given size 'limit' starting from position 'offset'. + + Args: + limit: Maximal number of records to return. + offset: Starting position within the result set (assume index 0 as starting position). + + Returns: + A list of items. + """ + raise NotImplementedError + + async def __call__(self, limit: int, offset: int) -> OffsetPagination[T]: + """Return a paginated result set. + + Args: + limit: Maximal number of records to return. + offset: Starting position within the result set (assume index 0 as starting position). + + Returns: + A paginated result set. + """ + total = await self.get_total() + items = await self.get_items(limit=limit, offset=offset) + + return OffsetPagination[T](items=items, total=total, offset=offset, limit=limit) + + +class AbstractSyncCursorPaginator(ABC, Generic[C, T]): + """Base paginator class for sync cursor pagination. + + Implement this class to return paginated result sets using the cursor pagination scheme. + """ + + @abstractmethod + def get_items(self, cursor: C | None, results_per_page: int) -> tuple[list[T], C | None]: + """Return a list of items of the size 'results_per_page' following the given cursor, if any, + + Args: + cursor: A unique identifier that acts as the 'cursor' after which results should be given. + results_per_page: A maximal number of results to return. + + Returns: + A tuple containing the result set and a new cursor that marks the last record retrieved. + The new cursor can be used to ask for the 'next_cursor' batch of results. + """ + raise NotImplementedError + + def __call__(self, cursor: C | None, results_per_page: int) -> CursorPagination[C, T]: + """Return a paginated result set given an optional cursor (unique ID) and a maximal number of results to return. + + Args: + cursor: A unique identifier that acts as the 'cursor' after which results should be given. + results_per_page: A maximal number of results to return. + + Returns: + A paginated result set. + """ + items, new_cursor = self.get_items(cursor=cursor, results_per_page=results_per_page) + + return CursorPagination[C, T]( + items=items, + results_per_page=results_per_page, + cursor=new_cursor, + ) + + +class AbstractAsyncCursorPaginator(ABC, Generic[C, T]): + """Base paginator class for async cursor pagination. + + Implement this class to return paginated result sets using the cursor pagination scheme. + """ + + @abstractmethod + async def get_items(self, cursor: C | None, results_per_page: int) -> tuple[list[T], C | None]: + """Return a list of items of the size 'results_per_page' following the given cursor, if any, + + Args: + cursor: A unique identifier that acts as the 'cursor' after which results should be given. + results_per_page: A maximal number of results to return. + + Returns: + A tuple containing the result set and a new cursor that marks the last record retrieved. + The new cursor can be used to ask for the 'next_cursor' batch of results. + """ + raise NotImplementedError + + async def __call__(self, cursor: C | None, results_per_page: int) -> CursorPagination[C, T]: + """Return a paginated result set given an optional cursor (unique ID) and a maximal number of results to return. + + Args: + cursor: A unique identifier that acts as the 'cursor' after which results should be given. + results_per_page: A maximal number of results to return. + + Returns: + A paginated result set. + """ + items, new_cursor = await self.get_items(cursor=cursor, results_per_page=results_per_page) + + return CursorPagination[C, T]( + items=items, + results_per_page=results_per_page, + cursor=new_cursor, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/params.py b/venv/lib/python3.11/site-packages/litestar/params.py new file mode 100644 index 0000000..bff010b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/params.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, Hashable, Sequence + +from litestar.enums import RequestEncodingType +from litestar.types import Empty + +__all__ = ( + "Body", + "BodyKwarg", + "Dependency", + "DependencyKwarg", + "KwargDefinition", + "Parameter", + "ParameterKwarg", +) + + +if TYPE_CHECKING: + from litestar.openapi.spec.example import Example + from litestar.openapi.spec.external_documentation import ( + ExternalDocumentation, + ) + + +@dataclass(frozen=True) +class KwargDefinition: + """Data container representing a constrained kwarg.""" + + examples: list[Example] | None = field(default=None) + """A list of Example models.""" + external_docs: ExternalDocumentation | None = field(default=None) + """A url pointing at external documentation for the given parameter.""" + content_encoding: str | None = field(default=None) + """The content encoding of the value. + + Applicable on to string values. See OpenAPI 3.1 for details. + """ + default: Any = field(default=Empty) + """A default value. + + If const is true, this value is required. + """ + title: str | None = field(default=None) + """String value used in the title section of the OpenAPI schema for the given parameter.""" + description: str | None = field(default=None) + """String value used in the description section of the OpenAPI schema for the given parameter.""" + const: bool | None = field(default=None) + """A boolean flag dictating whether this parameter is a constant. + + If True, the value passed to the parameter must equal its default value. This also causes the OpenAPI const field to + be populated with the default value. + """ + gt: float | None = field(default=None) + """Constrict value to be greater than a given float or int. + + Equivalent to exclusiveMinimum in the OpenAPI specification. + """ + ge: float | None = field(default=None) + """Constrict value to be greater or equal to a given float or int. + + Equivalent to minimum in the OpenAPI specification. + """ + lt: float | None = field(default=None) + """Constrict value to be less than a given float or int. + + Equivalent to exclusiveMaximum in the OpenAPI specification. + """ + le: float | None = field(default=None) + """Constrict value to be less or equal to a given float or int. + + Equivalent to maximum in the OpenAPI specification. + """ + multiple_of: float | None = field(default=None) + """Constrict value to a multiple of a given float or int. + + Equivalent to multipleOf in the OpenAPI specification. + """ + min_items: int | None = field(default=None) + """Constrict a set or a list to have a minimum number of items. + + Equivalent to minItems in the OpenAPI specification. + """ + max_items: int | None = field(default=None) + """Constrict a set or a list to have a maximum number of items. + + Equivalent to maxItems in the OpenAPI specification. + """ + min_length: int | None = field(default=None) + """Constrict a string or bytes value to have a minimum length. + + Equivalent to minLength in the OpenAPI specification. + """ + max_length: int | None = field(default=None) + """Constrict a string or bytes value to have a maximum length. + + Equivalent to maxLength in the OpenAPI specification. + """ + pattern: str | None = field(default=None) + """A string representing a regex against which the given string will be matched. + + Equivalent to pattern in the OpenAPI specification. + """ + lower_case: bool | None = field(default=None) + """Constrict a string value to be lower case.""" + upper_case: bool | None = field(default=None) + """Constrict a string value to be upper case.""" + format: str | None = field(default=None) + """Specify the format to which a string value should be converted.""" + enum: Sequence[Any] | None = field(default=None) + """A sequence of valid values.""" + read_only: bool | None = field(default=None) + """A boolean flag dictating whether this parameter is read only.""" + + @property + def is_constrained(self) -> bool: + """Return True if any of the constraints are set.""" + return any( + attr if attr and attr is not Empty else False # type: ignore[comparison-overlap] + for attr in ( + self.gt, + self.ge, + self.lt, + self.le, + self.multiple_of, + self.min_items, + self.max_items, + self.min_length, + self.max_length, + self.pattern, + self.const, + self.lower_case, + self.upper_case, + ) + ) + + +@dataclass(frozen=True) +class ParameterKwarg(KwargDefinition): + """Data container representing a parameter.""" + + annotation: Any = field(default=Empty) + """The field value - `Empty` by default.""" + header: str | None = field(default=None) + """The header parameter key - required for header parameters.""" + cookie: str | None = field(default=None) + """The cookie parameter key - required for cookie parameters.""" + query: str | None = field(default=None) + """The query parameter key for this parameter.""" + required: bool | None = field(default=None) + """A boolean flag dictating whether this parameter is required. + + If set to False, None values will be allowed. Defaults to True. + """ + + def __hash__(self) -> int: # pragma: no cover + """Hash the dataclass in a safe way. + + Returns: + A hash + """ + return sum(hash(v) for v in asdict(self) if isinstance(v, Hashable)) + + +def Parameter( + annotation: Any = Empty, + *, + const: bool | None = None, + content_encoding: str | None = None, + cookie: str | None = None, + default: Any = Empty, + description: str | None = None, + examples: list[Example] | None = None, + external_docs: ExternalDocumentation | None = None, + ge: float | None = None, + gt: float | None = None, + header: str | None = None, + le: float | None = None, + lt: float | None = None, + max_items: int | None = None, + max_length: int | None = None, + min_items: int | None = None, + min_length: int | None = None, + multiple_of: float | None = None, + pattern: str | None = None, + query: str | None = None, + required: bool | None = None, + title: str | None = None, +) -> Any: + """Create an extended parameter kwarg definition. + + Args: + annotation: `Empty` by default. + const: A boolean flag dictating whether this parameter is a constant. If True, the value passed to the parameter + must equal its default value. This also causes the OpenAPI const field + to be populated with the default value. + content_encoding: The content encoding of the value. + Applicable on to string values. See OpenAPI 3.1 for details. + cookie: The cookie parameter key - required for cookie parameters. + default: A default value. If const is true, this value is required. + description: String value used in the description section of the OpenAPI schema for the given parameter. + examples: A list of Example models. + external_docs: A url pointing at external documentation for the given parameter. + ge: Constrict value to be greater or equal to a given float or int. + Equivalent to minimum in the OpenAPI specification. + gt: Constrict value to be greater than a given float or int. + Equivalent to exclusiveMinimum in the OpenAPI specification. + header: The header parameter key - required for header parameters. + le: Constrict value to be less or equal to a given float or int. + Equivalent to maximum in the OpenAPI specification. + lt: Constrict value to be less than a given float or int. + Equivalent to exclusiveMaximum in the OpenAPI specification. + max_items: Constrict a set or a list to have a maximum number of items. + Equivalent to maxItems in the OpenAPI specification. + max_length: Constrict a string or bytes value to have a maximum length. + Equivalent to maxLength in the OpenAPI specification. + min_items: Constrict a set or a list to have a minimum number of items. ֿ + Equivalent to minItems in the OpenAPI specification. + min_length: Constrict a string or bytes value to have a minimum length. + Equivalent to minLength in the OpenAPI specification. + multiple_of: Constrict value to a multiple of a given float or int. + Equivalent to multipleOf in the OpenAPI specification. + pattern: A string representing a regex against which the given string will be matched. + Equivalent to pattern in the OpenAPI specification. + query: The query parameter key for this parameter. + required: A boolean flag dictating whether this parameter is required. + If set to False, None values will be allowed. Defaults to True. + title: String value used in the title section of the OpenAPI schema for the given parameter. + """ + return ParameterKwarg( + annotation=annotation, + header=header, + cookie=cookie, + query=query, + examples=examples, + external_docs=external_docs, + content_encoding=content_encoding, + required=required, + default=default, + title=title, + description=description, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + min_items=min_items, + max_items=max_items, + min_length=min_length, + max_length=max_length, + pattern=pattern, + ) + + +@dataclass(frozen=True) +class BodyKwarg(KwargDefinition): + """Data container representing a request body.""" + + media_type: str | RequestEncodingType = field(default=RequestEncodingType.JSON) + """Media-Type of the body.""" + + multipart_form_part_limit: int | None = field(default=None) + """The maximal number of allowed parts in a multipart/formdata request. This limit is intended to protect from DoS attacks.""" + + def __hash__(self) -> int: # pragma: no cover + """Hash the dataclass in a safe way. + + Returns: + A hash + """ + return sum(hash(v) for v in asdict(self) if isinstance(v, Hashable)) + + +def Body( + *, + const: bool | None = None, + content_encoding: str | None = None, + default: Any = Empty, + description: str | None = None, + examples: list[Example] | None = None, + external_docs: ExternalDocumentation | None = None, + ge: float | None = None, + gt: float | None = None, + le: float | None = None, + lt: float | None = None, + max_items: int | None = None, + max_length: int | None = None, + media_type: str | RequestEncodingType = RequestEncodingType.JSON, + min_items: int | None = None, + min_length: int | None = None, + multipart_form_part_limit: int | None = None, + multiple_of: float | None = None, + pattern: str | None = None, + title: str | None = None, +) -> Any: + """Create an extended request body kwarg definition. + + Args: + const: A boolean flag dictating whether this parameter is a constant. If True, the value passed to the parameter + must equal its default value. This also causes the OpenAPI const field to be + populated with the default value. + content_encoding: The content encoding of the value. Applicable on to string values. + See OpenAPI 3.1 for details. + default: A default value. If const is true, this value is required. + description: String value used in the description section of the OpenAPI schema for the given parameter. + examples: A list of Example models. + external_docs: A url pointing at external documentation for the given parameter. + ge: Constrict value to be greater or equal to a given float or int. + Equivalent to minimum in the OpenAPI specification. + gt: Constrict value to be greater than a given float or int. + Equivalent to exclusiveMinimum in the OpenAPI specification. + le: Constrict value to be less or equal to a given float or int. + Equivalent to maximum in the OpenAPI specification. + lt: Constrict value to be less than a given float or int. + Equivalent to exclusiveMaximum in the OpenAPI specification. + max_items: Constrict a set or a list to have a maximum number of items. + Equivalent to maxItems in the OpenAPI specification. + max_length: Constrict a string or bytes value to have a maximum length. + Equivalent to maxLength in the OpenAPI specification. + media_type: Defaults to RequestEncodingType.JSON. + min_items: Constrict a set or a list to have a minimum number of items. + Equivalent to minItems in the OpenAPI specification. + min_length: Constrict a string or bytes value to have a minimum length. + Equivalent to minLength in the OpenAPI specification. + multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request. + This limit is intended to protect from DoS attacks. + multiple_of: Constrict value to a multiple of a given float or int. + Equivalent to multipleOf in the OpenAPI specification. + pattern: A string representing a regex against which the given string will be matched. + Equivalent to pattern in the OpenAPI specification. + title: String value used in the title section of the OpenAPI schema for the given parameter. + """ + return BodyKwarg( + media_type=media_type, + examples=examples, + external_docs=external_docs, + content_encoding=content_encoding, + default=default, + title=title, + description=description, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + min_items=min_items, + max_items=max_items, + min_length=min_length, + max_length=max_length, + pattern=pattern, + multipart_form_part_limit=multipart_form_part_limit, + ) + + +@dataclass(frozen=True) +class DependencyKwarg: + """Data container representing a dependency.""" + + default: Any = field(default=Empty) + """A default value.""" + skip_validation: bool = field(default=False) + """Flag dictating whether to skip validation.""" + + def __hash__(self) -> int: + """Hash the dataclass in a safe way. + + Returns: + A hash + """ + return sum(hash(v) for v in asdict(self) if isinstance(v, Hashable)) + + +def Dependency(*, default: Any = Empty, skip_validation: bool = False) -> Any: + """Create a dependency kwarg definition. + + Args: + default: A default value to use in case a dependency is not provided. + skip_validation: If `True` provided dependency values are not validated by signature model. + """ + return DependencyKwarg(default=default, skip_validation=skip_validation) diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/__init__.py b/venv/lib/python3.11/site-packages/litestar/plugins/__init__.py new file mode 100644 index 0000000..f093104 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/__init__.py @@ -0,0 +1,23 @@ +from litestar.plugins.base import ( + CLIPlugin, + CLIPluginProtocol, + DIPlugin, + InitPluginProtocol, + OpenAPISchemaPlugin, + OpenAPISchemaPluginProtocol, + PluginProtocol, + PluginRegistry, + SerializationPluginProtocol, +) + +__all__ = ( + "SerializationPluginProtocol", + "DIPlugin", + "CLIPlugin", + "InitPluginProtocol", + "OpenAPISchemaPluginProtocol", + "OpenAPISchemaPlugin", + "PluginProtocol", + "CLIPluginProtocol", + "PluginRegistry", +) diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e547c6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..582d144 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/core.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/core.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d197b7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/core.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/sqlalchemy.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/sqlalchemy.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2b9ab93 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/sqlalchemy.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/structlog.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/structlog.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e575d1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/__pycache__/structlog.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/base.py b/venv/lib/python3.11/site-packages/litestar/plugins/base.py new file mode 100644 index 0000000..afc571e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/base.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import abc +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Iterator, Protocol, TypeVar, Union, cast, runtime_checkable + +if TYPE_CHECKING: + from inspect import Signature + + from click import Group + + from litestar._openapi.schema_generation import SchemaCreator + from litestar.app import Litestar + from litestar.config.app import AppConfig + from litestar.dto import AbstractDTO + from litestar.openapi.spec import Schema + from litestar.routes import BaseRoute + from litestar.typing import FieldDefinition + +__all__ = ( + "SerializationPluginProtocol", + "InitPluginProtocol", + "OpenAPISchemaPluginProtocol", + "OpenAPISchemaPlugin", + "PluginProtocol", + "CLIPlugin", + "CLIPluginProtocol", + "PluginRegistry", + "DIPlugin", +) + + +@runtime_checkable +class InitPluginProtocol(Protocol): + """Protocol used to define plugins that affect the application's init process.""" + + __slots__ = () + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Receive the :class:`AppConfig<.config.app.AppConfig>` instance after `on_app_init` hooks have been called. + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.di import Provide + from litestar.plugins import InitPluginProtocol + + + def get_name() -> str: + return "world" + + + @get("/my-path") + def my_route_handler(name: str) -> dict[str, str]: + return {"hello": name} + + + class MyPlugin(InitPluginProtocol): + def on_app_init(self, app_config: AppConfig) -> AppConfig: + app_config.dependencies["name"] = Provide(get_name) + app_config.route_handlers.append(my_route_handler) + return app_config + + + app = Litestar(plugins=[MyPlugin()]) + + Args: + app_config: The :class:`AppConfig <litestar.config.app.AppConfig>` instance. + + Returns: + The app config object. + """ + return app_config # pragma: no cover + + +class ReceiveRoutePlugin: + """Receive routes as they are added to the application.""" + + __slots__ = () + + def receive_route(self, route: BaseRoute) -> None: + """Receive routes as they are registered on an application.""" + + +@runtime_checkable +class CLIPluginProtocol(Protocol): + """Plugin protocol to extend the CLI.""" + + __slots__ = () + + def on_cli_init(self, cli: Group) -> None: + """Called when the CLI is initialized. + + This can be used to extend or override existing commands. + + Args: + cli: The root :class:`click.Group` of the Litestar CLI + + Examples: + .. code-block:: python + + from litestar import Litestar + from litestar.plugins import CLIPluginProtocol + from click import Group + + + class CLIPlugin(CLIPluginProtocol): + def on_cli_init(self, cli: Group) -> None: + @cli.command() + def is_debug_mode(app: Litestar): + print(app.debug) + + + app = Litestar(plugins=[CLIPlugin()]) + """ + + +class CLIPlugin(CLIPluginProtocol): + """Plugin protocol to extend the CLI Server Lifespan.""" + + __slots__ = () + + def on_cli_init(self, cli: Group) -> None: + return super().on_cli_init(cli) + + @contextmanager + def server_lifespan(self, app: Litestar) -> Iterator[None]: + yield + + +@runtime_checkable +class SerializationPluginProtocol(Protocol): + """Protocol used to define a serialization plugin for DTOs.""" + + __slots__ = () + + def supports_type(self, field_definition: FieldDefinition) -> bool: + """Given a value of indeterminate type, determine if this value is supported by the plugin. + + Args: + field_definition: A parsed type. + + Returns: + Whether the type is supported by the plugin. + """ + raise NotImplementedError() + + def create_dto_for_type(self, field_definition: FieldDefinition) -> type[AbstractDTO]: + """Given a parsed type, create a DTO class. + + Args: + field_definition: A parsed type. + + Returns: + A DTO class. + """ + raise NotImplementedError() + + +class DIPlugin(abc.ABC): + """Extend dependency injection""" + + @abc.abstractmethod + def has_typed_init(self, type_: Any) -> bool: + """Return ``True`` if ``type_`` has type information available for its + :func:`__init__` method that cannot be extracted from this method's type + annotations (e.g. a Pydantic BaseModel subclass), and + :meth:`DIPlugin.get_typed_init` supports extraction of these annotations. + """ + ... + + @abc.abstractmethod + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + r"""Return signature and type information about the ``type_``\ s :func:`__init__` + method. + """ + ... + + +@runtime_checkable +class OpenAPISchemaPluginProtocol(Protocol): + """Plugin protocol to extend the support of OpenAPI schema generation for non-library types.""" + + __slots__ = () + + @staticmethod + def is_plugin_supported_type(value: Any) -> bool: + """Given a value of indeterminate type, determine if this value is supported by the plugin. + + Args: + value: An arbitrary value. + + Returns: + A typeguard dictating whether the value is supported by the plugin. + """ + raise NotImplementedError() + + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: + """Given a type annotation, transform it into an OpenAPI schema class. + + Args: + field_definition: An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance. + schema_creator: An instance of the openapi SchemaCreator. + + Returns: + An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance. + """ + raise NotImplementedError() + + +class OpenAPISchemaPlugin(OpenAPISchemaPluginProtocol): + """Plugin to extend the support of OpenAPI schema generation for non-library types.""" + + @staticmethod + def is_plugin_supported_type(value: Any) -> bool: + """Given a value of indeterminate type, determine if this value is supported by the plugin. + + This is called by the default implementation of :meth:`is_plugin_supported_field` for + backwards compatibility. User's should prefer to override that method instead. + + Args: + value: An arbitrary value. + + Returns: + A bool indicating whether the value is supported by the plugin. + """ + raise NotImplementedError( + "One of either is_plugin_supported_type or is_plugin_supported_field should be defined. " + "The default implementation of is_plugin_supported_field calls is_plugin_supported_type " + "for backwards compatibility. Users should prefer to override is_plugin_supported_field " + "as it receives a 'FieldDefinition' instance which is more useful than a raw type." + ) + + def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: + """Given a :class:`FieldDefinition <litestar.typing.FieldDefinition>` that represents an indeterminate type, + determine if this value is supported by the plugin + + Args: + field_definition: A parsed type. + + Returns: + Whether the type is supported by the plugin. + """ + return self.is_plugin_supported_type(field_definition.annotation) + + @staticmethod + def is_undefined_sentinel(value: Any) -> bool: + """Return ``True`` if ``value`` should be treated as an undefined field""" + return False + + @staticmethod + def is_constrained_field(field_definition: FieldDefinition) -> bool: + """Return ``True`` if the field should be treated as constrained. If returning + ``True``, constraints should be defined in the field's extras + """ + return False + + +PluginProtocol = Union[ + CLIPlugin, + CLIPluginProtocol, + InitPluginProtocol, + OpenAPISchemaPlugin, + OpenAPISchemaPluginProtocol, + ReceiveRoutePlugin, + SerializationPluginProtocol, + DIPlugin, +] + +PluginT = TypeVar("PluginT", bound=PluginProtocol) + + +class PluginRegistry: + __slots__ = { + "init": "Plugins that implement the InitPluginProtocol", + "openapi": "Plugins that implement the OpenAPISchemaPluginProtocol", + "receive_route": "ReceiveRoutePlugin instances", + "serialization": "Plugins that implement the SerializationPluginProtocol", + "cli": "Plugins that implement the CLIPluginProtocol", + "di": "DIPlugin instances", + "_plugins_by_type": None, + "_plugins": None, + "_get_plugins_of_type": None, + } + + def __init__(self, plugins: list[PluginProtocol]) -> None: + self._plugins_by_type = {type(p): p for p in plugins} + self._plugins = frozenset(plugins) + self.init = tuple(p for p in plugins if isinstance(p, InitPluginProtocol)) + self.openapi = tuple(p for p in plugins if isinstance(p, OpenAPISchemaPluginProtocol)) + self.receive_route = tuple(p for p in plugins if isinstance(p, ReceiveRoutePlugin)) + self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPluginProtocol)) + self.cli = tuple(p for p in plugins if isinstance(p, CLIPluginProtocol)) + self.di = tuple(p for p in plugins if isinstance(p, DIPlugin)) + + def get(self, type_: type[PluginT] | str) -> PluginT: + """Return the registered plugin of ``type_``. + + This should be used with subclasses of the plugin protocols. + """ + if isinstance(type_, str): + for plugin in self._plugins: + _name = plugin.__class__.__name__ + _module = plugin.__class__.__module__ + _qualname = ( + f"{_module}.{plugin.__class__.__qualname__}" + if _module is not None and _module != "__builtin__" + else plugin.__class__.__qualname__ + ) + if type_ in {_name, _qualname}: + return cast(PluginT, plugin) + raise KeyError(f"No plugin of type {type_!r} registered") + try: + return cast(PluginT, self._plugins_by_type[type_]) # type: ignore[index] + except KeyError as e: + raise KeyError(f"No plugin of type {type_.__name__!r} registered") from e + + def __iter__(self) -> Iterator[PluginProtocol]: + return iter(self._plugins) + + def __contains__(self, item: PluginProtocol) -> bool: + return item in self._plugins diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/core.py b/venv/lib/python3.11/site-packages/litestar/plugins/core.py new file mode 100644 index 0000000..d25d6d6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/core.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import inspect +from inspect import Signature +from typing import Any + +import msgspec + +from litestar.plugins import DIPlugin + +__all__ = ("MsgspecDIPlugin",) + + +class MsgspecDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return type(type_) is type(msgspec.Struct) # noqa: E721 + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + parameters = [] + type_hints = {} + for field_info in msgspec.structs.fields(type_): + type_hints[field_info.name] = field_info.type + parameters.append( + inspect.Parameter( + name=field_info.name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.type, + default=field_info.default, + ) + ) + return inspect.Signature(parameters), type_hints diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/sqlalchemy.py b/venv/lib/python3.11/site-packages/litestar/plugins/sqlalchemy.py new file mode 100644 index 0000000..d65d712 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/sqlalchemy.py @@ -0,0 +1,54 @@ +from advanced_alchemy import filters, types +from advanced_alchemy.base import ( + AuditColumns, + BigIntAuditBase, + BigIntBase, + BigIntPrimaryKey, + CommonTableAttributes, + UUIDAuditBase, + UUIDBase, + UUIDPrimaryKey, + orm_registry, +) +from advanced_alchemy.extensions.litestar import ( + AlembicAsyncConfig, + AlembicCommands, + AlembicSyncConfig, + AsyncSessionConfig, + EngineConfig, + SQLAlchemyAsyncConfig, + SQLAlchemyDTO, + SQLAlchemyDTOConfig, + SQLAlchemyInitPlugin, + SQLAlchemyPlugin, + SQLAlchemySerializationPlugin, + SQLAlchemySyncConfig, + SyncSessionConfig, +) + +__all__ = ( + "orm_registry", + "filters", + "types", + "AuditColumns", + "BigIntAuditBase", + "BigIntBase", + "UUIDAuditBase", + "UUIDPrimaryKey", + "CommonTableAttributes", + "UUIDBase", + "BigIntPrimaryKey", + "AlembicCommands", + "AlembicAsyncConfig", + "AlembicSyncConfig", + "AsyncSessionConfig", + "SyncSessionConfig", + "SQLAlchemyDTO", + "SQLAlchemyDTOConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemyInitPlugin", + "SQLAlchemyPlugin", + "SQLAlchemySerializationPlugin", + "SQLAlchemySyncConfig", + "EngineConfig", +) diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/structlog.py b/venv/lib/python3.11/site-packages/litestar/plugins/structlog.py new file mode 100644 index 0000000..fafa3dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/plugins/structlog.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.cli._utils import console +from litestar.logging.config import StructLoggingConfig +from litestar.middleware.logging import LoggingMiddlewareConfig +from litestar.plugins import InitPluginProtocol + +if TYPE_CHECKING: + from litestar.config.app import AppConfig + + +@dataclass +class StructlogConfig: + structlog_logging_config: StructLoggingConfig = field(default_factory=StructLoggingConfig) + """Structlog Logging configuration for Litestar. See ``litestar.logging.config.StructLoggingConfig``` for details.""" + middleware_logging_config: LoggingMiddlewareConfig = field(default_factory=LoggingMiddlewareConfig) + """Middleware logging config.""" + enable_middleware_logging: bool = True + """Enable request logging.""" + + +class StructlogPlugin(InitPluginProtocol): + """Structlog Plugin.""" + + __slots__ = ("_config",) + + def __init__(self, config: StructlogConfig | None = None) -> None: + if config is None: + config = StructlogConfig() + self._config = config + super().__init__() + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Structlog Plugin + + Args: + app_config: The :class:`AppConfig <litestar.config.app.AppConfig>` instance. + + Returns: + The app config object. + """ + if app_config.logging_config is not None and isinstance(app_config.logging_config, StructLoggingConfig): + console.print( + "[red dim]* Found pre-configured `StructLoggingConfig` on the `app` instance. Skipping configuration.[/]", + ) + else: + app_config.logging_config = self._config.structlog_logging_config + app_config.logging_config.configure() + if self._config.structlog_logging_config.standard_lib_logging_config is not None: # pragma: no cover + self._config.structlog_logging_config.standard_lib_logging_config.configure() # pragma: no cover + if self._config.enable_middleware_logging: + app_config.middleware.append(self._config.middleware_logging_config.middleware) + return app_config # pragma: no cover diff --git a/venv/lib/python3.11/site-packages/litestar/py.typed b/venv/lib/python3.11/site-packages/litestar/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/py.typed diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__init__.py b/venv/lib/python3.11/site-packages/litestar/repository/__init__.py new file mode 100644 index 0000000..62e3f83 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__init__.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from .abc import AbstractAsyncRepository, AbstractSyncRepository +from .exceptions import ConflictError, NotFoundError, RepositoryError +from .filters import FilterTypes + +__all__ = ( + "AbstractAsyncRepository", + "AbstractSyncRepository", + "ConflictError", + "FilterTypes", + "NotFoundError", + "RepositoryError", +) diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..42bf9be --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/_exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/_exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0f48627 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/_exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/_filters.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/_filters.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2f16cb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/_filters.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3c9f01e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/filters.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/filters.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..87d90db --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/filters.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/handlers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/handlers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..abe227b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/__pycache__/handlers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/_exceptions.py b/venv/lib/python3.11/site-packages/litestar/repository/_exceptions.py new file mode 100644 index 0000000..1c2b7be --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/_exceptions.py @@ -0,0 +1,15 @@ +from __future__ import annotations # pragma: no cover + +__all__ = ("ConflictError", "NotFoundError", "RepositoryError") # pragma: no cover + + +class RepositoryError(Exception): # pragma: no cover + """Base repository exception type.""" + + +class ConflictError(RepositoryError): # pragma: no cover + """Data integrity error.""" + + +class NotFoundError(RepositoryError): # pragma: no cover + """An identity does not exist.""" diff --git a/venv/lib/python3.11/site-packages/litestar/repository/_filters.py b/venv/lib/python3.11/site-packages/litestar/repository/_filters.py new file mode 100644 index 0000000..f6b787e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/_filters.py @@ -0,0 +1,117 @@ +"""Collection filter datastructures.""" + +from __future__ import annotations + +from collections import abc # noqa: TCH003 +from dataclasses import dataclass +from datetime import datetime # noqa: TCH003 +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +T = TypeVar("T") + +__all__ = ( + "BeforeAfter", + "CollectionFilter", + "FilterTypes", + "LimitOffset", + "OrderBy", + "SearchFilter", + "NotInCollectionFilter", + "OnBeforeAfter", + "NotInSearchFilter", +) + + +FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter" +"""Aggregate type alias of the types supported for collection filtering.""" + + +@dataclass +class BeforeAfter: + """Data required to filter a query on a ``datetime`` column.""" + + field_name: str + """Name of the model attribute to filter on.""" + before: datetime | None + """Filter results where field earlier than this.""" + after: datetime | None + """Filter results where field later than this.""" + + +@dataclass +class OnBeforeAfter: + """Data required to filter a query on a ``datetime`` column.""" + + field_name: str + """Name of the model attribute to filter on.""" + on_or_before: datetime | None + """Filter results where field is on or earlier than this.""" + on_or_after: datetime | None + """Filter results where field on or later than this.""" + + +@dataclass +class CollectionFilter(Generic[T]): + """Data required to construct a ``WHERE ... IN (...)`` clause.""" + + field_name: str + """Name of the model attribute to filter on.""" + values: abc.Collection[T] + """Values for ``IN`` clause.""" + + +@dataclass +class NotInCollectionFilter(Generic[T]): + """Data required to construct a ``WHERE ... NOT IN (...)`` clause.""" + + field_name: str + """Name of the model attribute to filter on.""" + values: abc.Collection[T] + """Values for ``NOT IN`` clause.""" + + +@dataclass +class LimitOffset: + """Data required to add limit/offset filtering to a query.""" + + limit: int + """Value for ``LIMIT`` clause of query.""" + offset: int + """Value for ``OFFSET`` clause of query.""" + + +@dataclass +class OrderBy: + """Data required to construct a ``ORDER BY ...`` clause.""" + + field_name: str + """Name of the model attribute to sort on.""" + sort_order: Literal["asc", "desc"] = "asc" + """Sort ascending or descending""" + + +@dataclass +class SearchFilter: + """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" + + field_name: str + """Name of the model attribute to sort on.""" + value: str + """Values for ``LIKE`` clause.""" + ignore_case: bool | None = False + """Should the search be case insensitive.""" + + +@dataclass +class NotInSearchFilter: + """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause.""" + + field_name: str + """Name of the model attribute to search on.""" + value: str + """Values for ``NOT LIKE`` clause.""" + ignore_case: bool | None = False + """Should the search be case insensitive.""" diff --git a/venv/lib/python3.11/site-packages/litestar/repository/abc/__init__.py b/venv/lib/python3.11/site-packages/litestar/repository/abc/__init__.py new file mode 100644 index 0000000..def6bb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/abc/__init__.py @@ -0,0 +1,7 @@ +from ._async import AbstractAsyncRepository +from ._sync import AbstractSyncRepository + +__all__ = ( + "AbstractAsyncRepository", + "AbstractSyncRepository", +) diff --git a/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..710a74b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/_async.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/_async.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..24217cc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/_async.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/_sync.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/_sync.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7a69407 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/abc/__pycache__/_sync.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/abc/_async.py b/venv/lib/python3.11/site-packages/litestar/repository/abc/_async.py new file mode 100644 index 0000000..85ca139 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/abc/_async.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from litestar.repository.exceptions import NotFoundError + +if TYPE_CHECKING: + from litestar.repository.filters import FilterTypes + +T = TypeVar("T") +CollectionT = TypeVar("CollectionT") + + +class AbstractAsyncRepository(Generic[T], metaclass=ABCMeta): + """Interface for persistent data interaction.""" + + model_type: type[T] + """Type of object represented by the repository.""" + id_attribute: Any = "id" + """Name of the primary identifying attribute on :attr:`model_type`.""" + + def __init__(self, **kwargs: Any) -> None: + """Repository constructors accept arbitrary kwargs.""" + super().__init__(**kwargs) + + @abstractmethod + async def add(self, data: T) -> T: + """Add ``data`` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ + + @abstractmethod + async def add_many(self, data: list[T]) -> list[T]: + """Add multiple ``data`` to the collection. + + Args: + data: Instances to be added to the collection. + + Returns: + The added instances. + """ + + @abstractmethod + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Get the count of records returned by a query. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The count of instances + """ + + @abstractmethod + async def delete(self, item_id: Any) -> T: + """Delete instance identified by ``item_id``. + + Args: + item_id: Identifier of instance to be deleted. + + Returns: + The deleted instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + + @abstractmethod + async def delete_many(self, item_ids: list[Any]) -> list[T]: + """Delete multiple instances identified by list of IDs ``item_ids``. + + Args: + item_ids: list of Identifiers to be deleted. + + Returns: + The deleted instances. + """ + + @abstractmethod + async def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool: + """Return true if the object specified by ``kwargs`` exists. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Identifier of the instance to be retrieved. + + Returns: + True if the instance was found. False if not found. + + """ + + @abstractmethod + async def get(self, item_id: Any, **kwargs: Any) -> T: + """Get instance identified by ``item_id``. + + Args: + item_id: Identifier of the instance to be retrieved. + **kwargs: Additional arguments + + Returns: + The retrieved instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + + @abstractmethod + async def get_one(self, **kwargs: Any) -> T: + """Get an instance specified by the ``kwargs`` filters if it exists. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance. + + Raises: + NotFoundError: If no instance found identified by ``kwargs``. + """ + + @abstractmethod + async def get_or_create(self, **kwargs: Any) -> tuple[T, bool]: + """Get an instance specified by the ``kwargs`` filters if it exists or create it. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + A tuple that includes the retrieved or created instance, and a boolean on whether the record was created or not + """ + + @abstractmethod + async def get_one_or_none(self, **kwargs: Any) -> T | None: + """Get an instance if it exists or None. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance or None. + """ + + @abstractmethod + async def update(self, data: T) -> T: + """Update instance with the attribute values present on ``data``. + + Args: + data: An instance that should have a value for :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` that exists in the + collection. + + Returns: + The updated instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + async def update_many(self, data: list[T]) -> list[T]: + """Update multiple instances with the attribute values present on instances in ``data``. + + Args: + data: A list of instance that should have a value for :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` that exists in the + collection. + + Returns: + a list of the updated instances. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + async def upsert(self, data: T) -> T: + """Update or create instance. + + Updates instance with the attribute values present on ``data``, or creates a new instance if + one doesn't exist. + + Args: + data: Instance to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on ``data`` named as value of + :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`. + + Returns: + The updated or created instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + async def upsert_many(self, data: list[T]) -> list[T]: + """Update or create multiple instances. + + Update instances with the attribute values present on ``data``, or create a new instance if + one doesn't exist. + + Args: + data: Instances to update or created. Identifier used to determine if an + existing instance exists is the value of an attribute on ``data`` named as value of + :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`. + + Returns: + The updated or created instances. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + async def list_and_count(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[T], int]: + """List records with total count. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + a tuple containing The list of instances, after filtering applied, and a count of records returned by query, ignoring pagination. + """ + + @abstractmethod + async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]: + """Get a list of instances, optionally filtered. + + Args: + *filters: filters for specific filtering operations + **kwargs: Instance attribute value filters. + + Returns: + The list of instances, after filtering applied + """ + + @abstractmethod + def filter_collection_by_kwargs(self, collection: CollectionT, /, **kwargs: Any) -> CollectionT: + """Filter the collection by kwargs. + + Has ``AND`` semantics where multiple kwargs name/value pairs are provided. + + Args: + collection: the objects to be filtered + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named ``key`` has value equal to ``value``. + + + Returns: + The filtered objects + + Raises: + RepositoryError: if a named attribute doesn't exist on :attr:`model_type <AbstractAsyncRepository.model_type>`. + """ + + @staticmethod + def check_not_found(item_or_none: T | None) -> T: + """Raise :class:`NotFoundError` if ``item_or_none`` is ``None``. + + Args: + item_or_none: Item (:class:`T <T>`) to be tested for existence. + + Returns: + The item, if it exists. + """ + if item_or_none is None: + raise NotFoundError("No item found when one was expected") + return item_or_none + + @classmethod + def get_id_attribute_value(cls, item: T | type[T], id_attribute: str | None = None) -> Any: + """Get value of attribute named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` on ``item``. + + Args: + item: Anything that should have an attribute named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` value. + id_attribute: Allows customization of the unique identifier to use for model fetching. + Defaults to `None`, but can reference any surrogate or candidate key for the table. + + Returns: + The value of attribute on ``item`` named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`. + """ + return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute) + + @classmethod + def set_id_attribute_value(cls, item_id: Any, item: T, id_attribute: str | None = None) -> T: + """Return the ``item`` after the ID is set to the appropriate attribute. + + Args: + item_id: Value of ID to be set on instance + item: Anything that should have an attribute named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` value. + id_attribute: Allows customization of the unique identifier to use for model fetching. + Defaults to `None`, but can reference any surrogate or candidate key for the table. + + Returns: + Item with ``item_id`` set to :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` + """ + setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id) + return item diff --git a/venv/lib/python3.11/site-packages/litestar/repository/abc/_sync.py b/venv/lib/python3.11/site-packages/litestar/repository/abc/_sync.py new file mode 100644 index 0000000..d667fc2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/abc/_sync.py @@ -0,0 +1,305 @@ +# Do not edit this file directly. It has been autogenerated from +# litestar/repository/abc/_async.py +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from litestar.repository.exceptions import NotFoundError + +if TYPE_CHECKING: + from litestar.repository.filters import FilterTypes + +T = TypeVar("T") +CollectionT = TypeVar("CollectionT") + + +class AbstractSyncRepository(Generic[T], metaclass=ABCMeta): + """Interface for persistent data interaction.""" + + model_type: type[T] + """Type of object represented by the repository.""" + id_attribute: Any = "id" + """Name of the primary identifying attribute on :attr:`model_type`.""" + + def __init__(self, **kwargs: Any) -> None: + """Repository constructors accept arbitrary kwargs.""" + super().__init__(**kwargs) + + @abstractmethod + def add(self, data: T) -> T: + """Add ``data`` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ + + @abstractmethod + def add_many(self, data: list[T]) -> list[T]: + """Add multiple ``data`` to the collection. + + Args: + data: Instances to be added to the collection. + + Returns: + The added instances. + """ + + @abstractmethod + def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Get the count of records returned by a query. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The count of instances + """ + + @abstractmethod + def delete(self, item_id: Any) -> T: + """Delete instance identified by ``item_id``. + + Args: + item_id: Identifier of instance to be deleted. + + Returns: + The deleted instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + + @abstractmethod + def delete_many(self, item_ids: list[Any]) -> list[T]: + """Delete multiple instances identified by list of IDs ``item_ids``. + + Args: + item_ids: list of Identifiers to be deleted. + + Returns: + The deleted instances. + """ + + @abstractmethod + def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool: + """Return true if the object specified by ``kwargs`` exists. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Identifier of the instance to be retrieved. + + Returns: + True if the instance was found. False if not found. + + """ + + @abstractmethod + def get(self, item_id: Any, **kwargs: Any) -> T: + """Get instance identified by ``item_id``. + + Args: + item_id: Identifier of the instance to be retrieved. + **kwargs: Additional arguments + + Returns: + The retrieved instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + + @abstractmethod + def get_one(self, **kwargs: Any) -> T: + """Get an instance specified by the ``kwargs`` filters if it exists. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance. + + Raises: + NotFoundError: If no instance found identified by ``kwargs``. + """ + + @abstractmethod + def get_or_create(self, **kwargs: Any) -> tuple[T, bool]: + """Get an instance specified by the ``kwargs`` filters if it exists or create it. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + A tuple that includes the retrieved or created instance, and a boolean on whether the record was created or not + """ + + @abstractmethod + def get_one_or_none(self, **kwargs: Any) -> T | None: + """Get an instance if it exists or None. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance or None. + """ + + @abstractmethod + def update(self, data: T) -> T: + """Update instance with the attribute values present on ``data``. + + Args: + data: An instance that should have a value for :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` that exists in the + collection. + + Returns: + The updated instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + def update_many(self, data: list[T]) -> list[T]: + """Update multiple instances with the attribute values present on instances in ``data``. + + Args: + data: A list of instance that should have a value for :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` that exists in the + collection. + + Returns: + a list of the updated instances. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + def upsert(self, data: T) -> T: + """Update or create instance. + + Updates instance with the attribute values present on ``data``, or creates a new instance if + one doesn't exist. + + Args: + data: Instance to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on ``data`` named as value of + :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`. + + Returns: + The updated or created instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + def upsert_many(self, data: list[T]) -> list[T]: + """Update or create multiple instances. + + Update instances with the attribute values present on ``data``, or create a new instance if + one doesn't exist. + + Args: + data: Instances to update or created. Identifier used to determine if an + existing instance exists is the value of an attribute on ``data`` named as value of + :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`. + + Returns: + The updated or created instances. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + + @abstractmethod + def list_and_count(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[T], int]: + """List records with total count. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + a tuple containing The list of instances, after filtering applied, and a count of records returned by query, ignoring pagination. + """ + + @abstractmethod + def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]: + """Get a list of instances, optionally filtered. + + Args: + *filters: filters for specific filtering operations + **kwargs: Instance attribute value filters. + + Returns: + The list of instances, after filtering applied + """ + + @abstractmethod + def filter_collection_by_kwargs(self, collection: CollectionT, /, **kwargs: Any) -> CollectionT: + """Filter the collection by kwargs. + + Has ``AND`` semantics where multiple kwargs name/value pairs are provided. + + Args: + collection: the objects to be filtered + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named ``key`` has value equal to ``value``. + + + Returns: + The filtered objects + + Raises: + RepositoryError: if a named attribute doesn't exist on :attr:`model_type <AbstractAsyncRepository.model_type>`. + """ + + @staticmethod + def check_not_found(item_or_none: T | None) -> T: + """Raise :class:`NotFoundError` if ``item_or_none`` is ``None``. + + Args: + item_or_none: Item (:class:`T <T>`) to be tested for existence. + + Returns: + The item, if it exists. + """ + if item_or_none is None: + raise NotFoundError("No item found when one was expected") + return item_or_none + + @classmethod + def get_id_attribute_value(cls, item: T | type[T], id_attribute: str | None = None) -> Any: + """Get value of attribute named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` on ``item``. + + Args: + item: Anything that should have an attribute named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` value. + id_attribute: Allows customization of the unique identifier to use for model fetching. + Defaults to `None`, but can reference any surrogate or candidate key for the table. + + Returns: + The value of attribute on ``item`` named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`. + """ + return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute) + + @classmethod + def set_id_attribute_value(cls, item_id: Any, item: T, id_attribute: str | None = None) -> T: + """Return the ``item`` after the ID is set to the appropriate attribute. + + Args: + item_id: Value of ID to be set on instance + item: Anything that should have an attribute named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` value. + id_attribute: Allows customization of the unique identifier to use for model fetching. + Defaults to `None`, but can reference any surrogate or candidate key for the table. + + Returns: + Item with ``item_id`` set to :attr:`id_attribute <AbstractAsyncRepository.id_attribute>` + """ + setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id) + return item diff --git a/venv/lib/python3.11/site-packages/litestar/repository/exceptions.py b/venv/lib/python3.11/site-packages/litestar/repository/exceptions.py new file mode 100644 index 0000000..8dad182 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/exceptions.py @@ -0,0 +1,7 @@ +try: + from advanced_alchemy.exceptions import IntegrityError as ConflictError + from advanced_alchemy.exceptions import NotFoundError, RepositoryError +except ImportError: # pragma: no cover + from ._exceptions import ConflictError, NotFoundError, RepositoryError # type: ignore[assignment] + +__all__ = ("ConflictError", "NotFoundError", "RepositoryError") diff --git a/venv/lib/python3.11/site-packages/litestar/repository/filters.py b/venv/lib/python3.11/site-packages/litestar/repository/filters.py new file mode 100644 index 0000000..e0cce48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/filters.py @@ -0,0 +1,37 @@ +try: + from advanced_alchemy.filters import ( + BeforeAfter, + CollectionFilter, + FilterTypes, + LimitOffset, + NotInCollectionFilter, + NotInSearchFilter, + OnBeforeAfter, + OrderBy, + SearchFilter, + ) +except ImportError: + from ._filters import ( # type: ignore[assignment] + BeforeAfter, + CollectionFilter, + FilterTypes, + LimitOffset, + NotInCollectionFilter, + NotInSearchFilter, + OnBeforeAfter, + OrderBy, + SearchFilter, + ) + + +__all__ = ( + "BeforeAfter", + "CollectionFilter", + "FilterTypes", + "LimitOffset", + "OrderBy", + "SearchFilter", + "NotInCollectionFilter", + "OnBeforeAfter", + "NotInSearchFilter", +) diff --git a/venv/lib/python3.11/site-packages/litestar/repository/handlers.py b/venv/lib/python3.11/site-packages/litestar/repository/handlers.py new file mode 100644 index 0000000..0bc1434 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/handlers.py @@ -0,0 +1,37 @@ +from typing import TYPE_CHECKING + +from litestar.repository.filters import ( + BeforeAfter, + CollectionFilter, + FilterTypes, + LimitOffset, + NotInCollectionFilter, + NotInSearchFilter, + OnBeforeAfter, + OrderBy, + SearchFilter, +) + +if TYPE_CHECKING: + from litestar.config.app import AppConfig + +__all__ = ("signature_namespace_values", "on_app_init") + +signature_namespace_values = { + "BeforeAfter": BeforeAfter, + "OnBeforeAfter": OnBeforeAfter, + "CollectionFilter": CollectionFilter, + "LimitOffset": LimitOffset, + "OrderBy": OrderBy, + "SearchFilter": SearchFilter, + "NotInCollectionFilter": NotInCollectionFilter, + "NotInSearchFilter": NotInSearchFilter, + "FilterTypes": FilterTypes, +} + + +def on_app_init(app_config: "AppConfig") -> "AppConfig": + """Add custom filters for the application during signature modelling.""" + + app_config.signature_namespace.update(signature_namespace_values) + return app_config diff --git a/venv/lib/python3.11/site-packages/litestar/repository/testing/__init__.py b/venv/lib/python3.11/site-packages/litestar/repository/testing/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/testing/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/repository/testing/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/testing/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a008ff2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/testing/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/testing/__pycache__/generic_mock_repository.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/repository/testing/__pycache__/generic_mock_repository.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c871877 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/testing/__pycache__/generic_mock_repository.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/repository/testing/generic_mock_repository.py b/venv/lib/python3.11/site-packages/litestar/repository/testing/generic_mock_repository.py new file mode 100644 index 0000000..5aa094c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/repository/testing/generic_mock_repository.py @@ -0,0 +1,784 @@ +"""A repository implementation for tests. + +Uses a `dict` for storage. +""" + +from __future__ import annotations + +from datetime import datetime, timezone, tzinfo +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar +from uuid import uuid4 + +from litestar.repository import AbstractAsyncRepository, AbstractSyncRepository, FilterTypes +from litestar.repository.exceptions import ConflictError, RepositoryError + +if TYPE_CHECKING: + from collections.abc import Callable, Hashable, Iterable, MutableMapping + from typing import Any + +__all__ = ("GenericAsyncMockRepository", "GenericSyncMockRepository") + + +class HasID(Protocol): + id: Any + + +ModelT = TypeVar("ModelT", bound="HasID") +AsyncMockRepoT = TypeVar("AsyncMockRepoT", bound="GenericAsyncMockRepository") +SyncMockRepoT = TypeVar("SyncMockRepoT", bound="GenericSyncMockRepository") + + +class GenericAsyncMockRepository(AbstractAsyncRepository[ModelT], Generic[ModelT]): + """A repository implementation for tests. + + Uses a :class:`dict` for storage. + """ + + collection: MutableMapping[Hashable, ModelT] + model_type: type[ModelT] + match_fields: list[str] | str | None = None + + _model_has_created_at: bool + _model_has_updated_at: bool + + def __init__( + self, id_factory: Callable[[], Any] = uuid4, tz: tzinfo = timezone.utc, allow_ids_on_add: bool = False, **_: Any + ) -> None: + super().__init__() + self._id_factory = id_factory + self.tz = tz + self.allow_ids_on_add = allow_ids_on_add + + @classmethod + def __class_getitem__(cls: type[AsyncMockRepoT], item: type[ModelT]) -> type[AsyncMockRepoT]: + """Add collection to ``_collections`` for the type. + + Args: + item: The type that the class has been parametrized with. + """ + return type( # pyright:ignore + f"{cls.__name__}[{item.__name__}]", + (cls,), + { + "collection": {}, + "model_type": item, + "_model_has_created_at": hasattr(item, "created_at"), + "_model_has_updated_at": hasattr(item, "updated_at"), + }, + ) + + def _find_or_raise_not_found(self, item_id: Any) -> ModelT: + return self.check_not_found(self.collection.get(item_id)) + + def _find_or_none(self, item_id: Any) -> ModelT | None: + return self.collection.get(item_id) + + def _now(self) -> datetime: + return datetime.now(tz=self.tz).replace(tzinfo=None) + + def _update_audit_attributes(self, data: ModelT, now: datetime | None = None, do_created: bool = False) -> ModelT: + now = now or self._now() + if self._model_has_updated_at: + data.updated_at = now # type:ignore[attr-defined] + if do_created: + data.created_at = now # type:ignore[attr-defined] + return data + + async def add(self, data: ModelT) -> ModelT: + """Add ``data`` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ + if self.allow_ids_on_add is False and self.get_id_attribute_value(data) is not None: + raise ConflictError("`add()` received identified item.") + self._update_audit_attributes(data, do_created=True) + if self.allow_ids_on_add is False: + id_ = self._id_factory() + self.set_id_attribute_value(id_, data) + self.collection[data.id] = data + return data + + async def add_many(self, data: Iterable[ModelT]) -> list[ModelT]: + """Add multiple ``data`` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ + now = self._now() + for data_row in data: + if self.allow_ids_on_add is False and self.get_id_attribute_value(data_row) is not None: + raise ConflictError("`add()` received identified item.") + + self._update_audit_attributes(data_row, do_created=True, now=now) + if self.allow_ids_on_add is False: + id_ = self._id_factory() + self.set_id_attribute_value(id_, data_row) + self.collection[data_row.id] = data_row + return list(data) + + async def delete(self, item_id: Any) -> ModelT: + """Delete instance identified by ``item_id``. + + Args: + item_id: Identifier of instance to be deleted. + + Returns: + The deleted instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + try: + return self._find_or_raise_not_found(item_id) + finally: + del self.collection[item_id] + + async def delete_many(self, item_ids: list[Any]) -> list[ModelT]: + """Delete instances identified by list of identifiers ``item_ids``. + + Args: + item_ids: list of identifiers of instances to be deleted. + + Returns: + The deleted instances. + + """ + instances: list[ModelT] = [] + for item_id in item_ids: + obj = await self.get_one_or_none(**{self.id_attribute: item_id}) + if obj: + obj = await self.delete(obj.id) + instances.append(obj) + return instances + + async def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool: + """Return true if the object specified by ``kwargs`` exists. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Identifier of the instance to be retrieved. + + Returns: + True if the instance was found. False if not found.. + + """ + existing = await self.count(*filters, **kwargs) + return bool(existing) + + async def get(self, item_id: Any, **kwargs: Any) -> ModelT: + """Get instance identified by ``item_id``. + + Args: + item_id: Identifier of the instance to be retrieved. + **kwargs: additional arguments + + Returns: + The retrieved instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + return self._find_or_raise_not_found(item_id) + + async def get_or_create(self, match_fields: list[str] | str | None = None, **kwargs: Any) -> tuple[ModelT, bool]: + """Get instance identified by ``kwargs`` or create if it doesn't exist. + + Args: + match_fields: a list of keys to use to match the existing model. When empty, all fields are matched. + **kwargs: Identifier of the instance to be retrieved. + + Returns: + a tuple that includes the instance and whether it needed to be created. + + """ + match_fields = match_fields or self.match_fields + if isinstance(match_fields, str): + match_fields = [match_fields] + if match_fields: + match_filter = { + field_name: field_value + for field_name in match_fields + if (field_value := kwargs.get(field_name)) is not None + } + else: + match_filter = kwargs + existing = await self.get_one_or_none(**match_filter) + if existing: + for field_name, new_field_value in kwargs.items(): + field = getattr(existing, field_name, None) + if field and field != new_field_value: + setattr(existing, field_name, new_field_value) + + return existing, False + return await self.add(self.model_type(**kwargs)), True # pyright: ignore[reportGeneralTypeIssues] + + async def get_one(self, **kwargs: Any) -> ModelT: + """Get instance identified by query filters. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance or None + + Raises: + NotFoundError: If no instance found identified by ``kwargs``. + """ + data = await self.list(**kwargs) + return self.check_not_found(data[0] if data else None) + + async def get_one_or_none(self, **kwargs: Any) -> ModelT | None: + """Get instance identified by query filters or None if not found. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance or None + """ + data = await self.list(**kwargs) + return data[0] if data else None + + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Count of rows returned by query. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + Count of instances in collection, ignoring pagination. + """ + return len(await self.list(*filters, **kwargs)) + + async def update(self, data: ModelT) -> ModelT: + """Update instance with the attribute values present on ``data``. + + Args: + data: An instance that should have a value for :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>` that exists in the + collection. + + Returns: + The updated instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + item = self._find_or_raise_not_found(self.get_id_attribute_value(data)) + self._update_audit_attributes(data, do_created=False) + for key, val in model_items(data): + setattr(item, key, val) + return item + + async def update_many(self, data: list[ModelT]) -> list[ModelT]: + """Update instances with the attribute values present on ``data``. + + Args: + data: A list of instances that should have a value for :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>` + that exists in the collection. + + Returns: + The updated instances. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + items = [self._find_or_raise_not_found(self.get_id_attribute_value(row)) for row in data] + now = self._now() + for item in items: + self._update_audit_attributes(item, do_created=False, now=now) + for key, val in model_items(item): + setattr(item, key, val) + return items + + async def upsert(self, data: ModelT) -> ModelT: + """Update or create instance. + + Updates instance with the attribute values present on ``data``, or creates a new instance if + one doesn't exist. + + Args: + data: Instance to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on `data` named as value of + :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>`. + + Returns: + The updated or created instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + item_id = self.get_id_attribute_value(data) + if item_id in self.collection: + return await self.update(data) + return await self.add(data) + + async def upsert_many(self, data: list[ModelT]) -> list[ModelT]: + """Update or create multiple instance. + + Update instance with the attribute values present on ``data``, or create a new instance if + one doesn't exist. + + Args: + data: List of instances to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on `data` named as value of + :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>`. + + Returns: + The updated or created instances. + """ + data_to_update = [row for row in data if self._find_or_none(self.get_id_attribute_value(row)) is not None] + data_to_add = [row for row in data if self._find_or_none(self.get_id_attribute_value(row)) is None] + + updated_items = await self.update_many(data_to_update) + added_items = await self.add_many(data_to_add) + return updated_items + added_items + + async def list_and_count( + self, + *filters: FilterTypes, + **kwargs: Any, + ) -> tuple[list[ModelT], int]: + """Get a list of instances, optionally filtered with a total row count. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + List of instances, and count of records returned by query, ignoring pagination. + """ + return await self.list(*filters, **kwargs), await self.count(*filters, **kwargs) + + async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: + """Get a list of instances, optionally filtered. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The list of instances, after filtering applied. + """ + return list(self.filter_collection_by_kwargs(self.collection, **kwargs).values()) + + def filter_collection_by_kwargs( # type:ignore[override] + self, collection: MutableMapping[Hashable, ModelT], /, **kwargs: Any + ) -> MutableMapping[Hashable, ModelT]: + """Filter the collection by kwargs. + + Args: + collection: set of objects to filter + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named ``key`` has value equal to ``value``. + """ + new_collection: dict[Hashable, ModelT] = {} + for item in self.collection.values(): + try: + if all(getattr(item, name) == value for name, value in kwargs.items()): + new_collection[item.id] = item + except AttributeError as orig: + raise RepositoryError from orig + return new_collection + + @classmethod + def seed_collection(cls, instances: Iterable[ModelT]) -> None: + """Seed the collection for repository type. + + Args: + instances: the instances to be added to the collection. + """ + for instance in instances: + cls.collection[cls.get_id_attribute_value(instance)] = instance + + @classmethod + def clear_collection(cls) -> None: + """Empty the collection for repository type.""" + cls.collection = {} + + +class GenericSyncMockRepository(AbstractSyncRepository[ModelT], Generic[ModelT]): + """A repository implementation for tests. + + Uses a :class:`dict` for storage. + """ + + collection: MutableMapping[Hashable, ModelT] + model_type: type[ModelT] + match_fields: list[str] | str | None = None + + _model_has_created_at: bool + _model_has_updated_at: bool + + def __init__( + self, + id_factory: Callable[[], Any] = uuid4, + tz: tzinfo = timezone.utc, + allow_ids_on_add: bool = False, + **_: Any, + ) -> None: + super().__init__() + self._id_factory = id_factory + self.tz = tz + self.allow_ids_on_add = allow_ids_on_add + + @classmethod + def __class_getitem__(cls: type[SyncMockRepoT], item: type[ModelT]) -> type[SyncMockRepoT]: + """Add collection to ``_collections`` for the type. + + Args: + item: The type that the class has been parametrized with. + """ + return type( # pyright:ignore + f"{cls.__name__}[{item.__name__}]", + (cls,), + { + "collection": {}, + "model_type": item, + "_model_has_created_at": hasattr(item, "created_at"), + "_model_has_updated_at": hasattr(item, "updated_at"), + }, + ) + + def _find_or_raise_not_found(self, item_id: Any) -> ModelT: + return self.check_not_found(self.collection.get(item_id)) + + def _find_or_none(self, item_id: Any) -> ModelT | None: + return self.collection.get(item_id) + + def _now(self) -> datetime: + return datetime.now(tz=self.tz).replace(tzinfo=None) + + def _update_audit_attributes(self, data: ModelT, now: datetime | None = None, do_created: bool = False) -> ModelT: + now = now or self._now() + if self._model_has_updated_at: + data.updated_at = now # type:ignore[attr-defined] + if do_created: + data.created_at = now # type:ignore[attr-defined] + return data + + def add(self, data: ModelT) -> ModelT: + """Add ``data`` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ + if self.allow_ids_on_add is False and self.get_id_attribute_value(data) is not None: + raise ConflictError("`add()` received identified item.") + self._update_audit_attributes(data, do_created=True) + if self.allow_ids_on_add is False: + id_ = self._id_factory() + self.set_id_attribute_value(id_, data) + self.collection[data.id] = data + return data + + def add_many(self, data: Iterable[ModelT]) -> list[ModelT]: + """Add multiple ``data`` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ + now = self._now() + for data_row in data: + if self.allow_ids_on_add is False and self.get_id_attribute_value(data_row) is not None: + raise ConflictError("`add()` received identified item.") + + self._update_audit_attributes(data_row, do_created=True, now=now) + if self.allow_ids_on_add is False: + id_ = self._id_factory() + self.set_id_attribute_value(id_, data_row) + self.collection[data_row.id] = data_row + return list(data) + + def delete(self, item_id: Any) -> ModelT: + """Delete instance identified by ``item_id``. + + Args: + item_id: Identifier of instance to be deleted. + + Returns: + The deleted instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + try: + return self._find_or_raise_not_found(item_id) + finally: + del self.collection[item_id] + + def delete_many(self, item_ids: list[Any]) -> list[ModelT]: + """Delete instances identified by list of identifiers ``item_ids``. + + Args: + item_ids: list of identifiers of instances to be deleted. + + Returns: + The deleted instances. + + """ + instances: list[ModelT] = [] + for item_id in item_ids: + if obj := self.get_one_or_none(**{self.id_attribute: item_id}): + obj = self.delete(obj.id) + instances.append(obj) + return instances + + def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool: + """Return true if the object specified by ``kwargs`` exists. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Identifier of the instance to be retrieved. + + Returns: + True if the instance was found. False if not found.. + + """ + existing = self.count(*filters, **kwargs) + return bool(existing) + + def get(self, item_id: Any, **kwargs: Any) -> ModelT: + """Get instance identified by ``item_id``. + + Args: + item_id: Identifier of the instance to be retrieved. + **kwargs: additional arguments + + Returns: + The retrieved instance. + + Raises: + NotFoundError: If no instance found identified by ``item_id``. + """ + return self._find_or_raise_not_found(item_id) + + def get_or_create(self, match_fields: list[str] | str | None = None, **kwargs: Any) -> tuple[ModelT, bool]: + """Get instance identified by ``kwargs`` or create if it doesn't exist. + + Args: + match_fields: a list of keys to use to match the existing model. When empty, all fields are matched. + **kwargs: Identifier of the instance to be retrieved. + + Returns: + a tuple that includes the instance and whether it needed to be created. + + """ + match_fields = match_fields or self.match_fields + if isinstance(match_fields, str): + match_fields = [match_fields] + if match_fields: + match_filter = { + field_name: field_value + for field_name in match_fields + if (field_value := kwargs.get(field_name)) is not None + } + else: + match_filter = kwargs + if existing := self.get_one_or_none(**match_filter): + for field_name, new_field_value in kwargs.items(): + field = getattr(existing, field_name, None) + if field and field != new_field_value: + setattr(existing, field_name, new_field_value) + + return existing, False + return self.add(self.model_type(**kwargs)), True # pyright: ignore[reportGeneralTypeIssues] + + def get_one(self, **kwargs: Any) -> ModelT: + """Get instance identified by query filters. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance or None + + Raises: + NotFoundError: If no instance found identified by ``kwargs``. + """ + data = self.list(**kwargs) + return self.check_not_found(data[0] if data else None) + + def get_one_or_none(self, **kwargs: Any) -> ModelT | None: + """Get instance identified by query filters or None if not found. + + Args: + **kwargs: Instance attribute value filters. + + Returns: + The retrieved instance or None + """ + data = self.list(**kwargs) + return data[0] if data else None + + def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Count of rows returned by query. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + Count of instances in collection, ignoring pagination. + """ + return len(self.list(*filters, **kwargs)) + + def update(self, data: ModelT) -> ModelT: + """Update instance with the attribute values present on ``data``. + + Args: + data: An instance that should have a value for :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>` that exists in the + collection. + + Returns: + The updated instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + item = self._find_or_raise_not_found(self.get_id_attribute_value(data)) + self._update_audit_attributes(data, do_created=False) + for key, val in model_items(data): + setattr(item, key, val) + return item + + def update_many(self, data: list[ModelT]) -> list[ModelT]: + """Update instances with the attribute values present on ``data``. + + Args: + data: A list of instances that should have a value for :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>` + that exists in the collection. + + Returns: + The updated instances. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + items = [self._find_or_raise_not_found(self.get_id_attribute_value(row)) for row in data] + now = self._now() + for item in items: + self._update_audit_attributes(item, do_created=False, now=now) + for key, val in model_items(item): + setattr(item, key, val) + return items + + def upsert(self, data: ModelT) -> ModelT: + """Update or create instance. + + Updates instance with the attribute values present on ``data``, or creates a new instance if + one doesn't exist. + + Args: + data: Instance to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on `data` named as value of + :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>`. + + Returns: + The updated or created instance. + + Raises: + NotFoundError: If no instance found with same identifier as ``data``. + """ + item_id = self.get_id_attribute_value(data) + return self.update(data) if item_id in self.collection else self.add(data) + + def upsert_many(self, data: list[ModelT]) -> list[ModelT]: + """Update or create multiple instance. + + Update instance with the attribute values present on ``data``, or create a new instance if + one doesn't exist. + + Args: + data: List of instances to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on `data` named as value of + :attr:`id_attribute <AsyncGenericMockRepository.id_attribute>`. + + Returns: + The updated or created instances. + """ + data_to_update = [row for row in data if self._find_or_none(self.get_id_attribute_value(row)) is not None] + data_to_add = [row for row in data if self._find_or_none(self.get_id_attribute_value(row)) is None] + + updated_items = self.update_many(data_to_update) + added_items = self.add_many(data_to_add) + return updated_items + added_items + + def list_and_count( + self, + *filters: FilterTypes, + **kwargs: Any, + ) -> tuple[list[ModelT], int]: + """Get a list of instances, optionally filtered with a total row count. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + List of instances, and count of records returned by query, ignoring pagination. + """ + return self.list(*filters, **kwargs), self.count(*filters, **kwargs) + + def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: + """Get a list of instances, optionally filtered. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The list of instances, after filtering applied. + """ + return list(self.filter_collection_by_kwargs(self.collection, **kwargs).values()) + + def filter_collection_by_kwargs( # type:ignore[override] + self, collection: MutableMapping[Hashable, ModelT], /, **kwargs: Any + ) -> MutableMapping[Hashable, ModelT]: + """Filter the collection by kwargs. + + Args: + collection: set of objects to filter + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named ``key`` has value equal to ``value``. + """ + new_collection: dict[Hashable, ModelT] = {} + for item in self.collection.values(): + try: + if all(getattr(item, name) == value for name, value in kwargs.items()): + new_collection[item.id] = item + except AttributeError as orig: + raise RepositoryError from orig + return new_collection + + @classmethod + def seed_collection(cls, instances: Iterable[ModelT]) -> None: + """Seed the collection for repository type. + + Args: + instances: the instances to be added to the collection. + """ + for instance in instances: + cls.collection[cls.get_id_attribute_value(instance)] = instance + + @classmethod + def clear_collection(cls) -> None: + """Empty the collection for repository type.""" + cls.collection = {} + + +def model_items(model: Any) -> list[tuple[str, Any]]: + return [(k, v) for k, v in model.__dict__.items() if not k.startswith("_")] diff --git a/venv/lib/python3.11/site-packages/litestar/response/__init__.py b/venv/lib/python3.11/site-packages/litestar/response/__init__.py new file mode 100644 index 0000000..c655758 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__init__.py @@ -0,0 +1,16 @@ +from .base import Response +from .file import File +from .redirect import Redirect +from .sse import ServerSentEvent, ServerSentEventMessage +from .streaming import Stream +from .template import Template + +__all__ = ( + "File", + "Redirect", + "Response", + "ServerSentEvent", + "ServerSentEventMessage", + "Stream", + "Template", +) diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a806282 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..885cbe8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/file.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/file.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..23c09f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/file.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/redirect.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/redirect.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d776095 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/redirect.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/sse.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/sse.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2a03883 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/sse.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/streaming.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/streaming.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..350cf9e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/streaming.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/__pycache__/template.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/template.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..91dde51 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/__pycache__/template.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/response/base.py b/venv/lib/python3.11/site-packages/litestar/response/base.py new file mode 100644 index 0000000..67eec09 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/base.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import itertools +import re +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Iterable, Literal, Mapping, TypeVar, overload + +from litestar.datastructures.cookie import Cookie +from litestar.datastructures.headers import ETag, MutableScopeHeaders +from litestar.enums import MediaType, OpenAPIMediaType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.serialization import default_serializer, encode_json, encode_msgpack, get_serializer +from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED +from litestar.types.empty import Empty +from litestar.utils.deprecation import deprecated, warn_deprecation +from litestar.utils.helpers import get_enum_string_value + +if TYPE_CHECKING: + from typing import Optional + + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.connection import Request + from litestar.types import ( + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + Receive, + ResponseCookies, + ResponseHeaders, + Scope, + Send, + Serializer, + TypeEncodersMap, + ) + +__all__ = ("ASGIResponse", "Response") + +T = TypeVar("T") + +MEDIA_TYPE_APPLICATION_JSON_PATTERN = re.compile(r"^application/(?:.+\+)?json") + + +class ASGIResponse: + """A low-level ASGI response class.""" + + __slots__ = ( + "background", + "body", + "content_length", + "encoding", + "is_head_response", + "status_code", + "_encoded_cookies", + "headers", + ) + + _should_set_content_length: ClassVar[bool] = True + """A flag to indicate whether the content-length header should be set by default or not.""" + + def __init__( + self, + *, + background: BackgroundTask | BackgroundTasks | None = None, + body: bytes | str = b"", + content_length: int | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + encoding: str = "utf-8", + headers: dict[str, Any] | Iterable[tuple[str, str]] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + ) -> None: + """A low-level ASGI response class. + + Args: + background: A background task or a list of background tasks to be executed after the response is sent. + body: encoded content to send in the response body. + content_length: The response content length. + cookies: The response cookies. + encoded_headers: The response headers. + encoding: The response encoding. + headers: The response headers. + is_head_response: A boolean indicating if the response is a HEAD response. + media_type: The response media type. + status_code: The response status code. + """ + body = body.encode() if isinstance(body, str) else body + status_code = status_code or HTTP_200_OK + self.headers = MutableScopeHeaders() + + if encoded_headers is not None: + warn_deprecation("3.0", kind="parameter", deprecated_name="encoded_headers", alternative="headers") + for header_name, header_value in encoded_headers: + self.headers.add(header_name.decode("latin-1"), header_value.decode("latin-1")) + + if headers is not None: + for k, v in headers.items() if isinstance(headers, dict) else headers: + self.headers.add(k, v) # pyright: ignore + + media_type = get_enum_string_value(media_type or MediaType.JSON) + + status_allows_body = ( + status_code not in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} and status_code >= HTTP_200_OK + ) + + if content_length is None: + content_length = len(body) + + if not status_allows_body or is_head_response: + if body and body != b"null": + raise ImproperlyConfiguredException( + "response content is not supported for HEAD responses and responses with a status code " + "that does not allow content (304, 204, < 200)" + ) + body = b"" + else: + self.headers.setdefault( + "content-type", (f"{media_type}; charset={encoding}" if media_type.startswith("text/") else media_type) + ) + + if self._should_set_content_length: + self.headers.setdefault("content-length", str(content_length)) + + self.background = background + self.body = body + self.content_length = content_length + self._encoded_cookies = tuple( + cookie.to_encoded_header() for cookie in (cookies or ()) if not cookie.documentation_only + ) + self.encoding = encoding + self.is_head_response = is_head_response + self.status_code = status_code + + @property + @deprecated("3.0", kind="property", alternative="encode_headers()") + def encoded_headers(self) -> list[tuple[bytes, bytes]]: + return self.encode_headers() + + def encode_headers(self) -> list[tuple[bytes, bytes]]: + return [*self.headers.headers, *self._encoded_cookies] + + async def after_response(self) -> None: + """Execute after the response is sent. + + Returns: + None + """ + if self.background is not None: + await self.background() + + async def start_response(self, send: Send) -> None: + """Emit the start event of the response. This event includes the headers and status codes. + + Args: + send: The ASGI send function. + + Returns: + None + """ + event: HTTPResponseStartEvent = { + "type": "http.response.start", + "status": self.status_code, + "headers": self.encode_headers(), + } + await send(event) + + async def send_body(self, send: Send, receive: Receive) -> None: + """Emit the response body. + + Args: + send: The ASGI send function. + receive: The ASGI receive function. + + Notes: + - Response subclasses should customize this method if there is a need to customize sending data. + + Returns: + None + """ + event: HTTPResponseBodyEvent = {"type": "http.response.body", "body": self.body, "more_body": False} + await send(event) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable of the ``Response``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + await self.start_response(send=send) + + if self.is_head_response: + event: HTTPResponseBodyEvent = {"type": "http.response.body", "body": b"", "more_body": False} + await send(event) + else: + await self.send_body(send=send, receive=receive) + + await self.after_response() + + +class Response(Generic[T]): + """Base Litestar HTTP response class, used as the basis for all other response classes.""" + + __slots__ = ( + "background", + "content", + "cookies", + "encoding", + "headers", + "media_type", + "status_code", + "response_type_encoders", + ) + + content: T + type_encoders: Optional[TypeEncodersMap] = None # noqa: UP007 + + def __init__( + self, + content: T, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + headers: ResponseHeaders | None = None, + media_type: MediaType | OpenAPIMediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> None: + """Initialize the response. + + Args: + content: A value for the response body that will be rendered into bytes string. + status_code: An HTTP status code. + media_type: A value for the response ``Content-Type`` header. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: The encoding to be used for the response headers. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + """ + self.content = content + self.background = background + self.cookies: list[Cookie] = ( + [Cookie(key=key, value=value) for key, value in cookies.items()] + if isinstance(cookies, Mapping) + else list(cookies or []) + ) + self.encoding = encoding + self.headers: dict[str, Any] = ( + dict(headers) if isinstance(headers, Mapping) else {h.name: h.value for h in headers or {}} + ) + self.media_type = media_type + self.status_code = status_code + self.response_type_encoders = {**(self.type_encoders or {}), **(type_encoders or {})} + + @overload + def set_cookie(self, /, cookie: Cookie) -> None: ... + + @overload + def set_cookie( + self, + key: str, + value: str | None = None, + max_age: int | None = None, + expires: int | None = None, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: Literal["lax", "strict", "none"] = "lax", + ) -> None: ... + + def set_cookie( # type: ignore[misc] + self, + key: str | Cookie, + value: str | None = None, + max_age: int | None = None, + expires: int | None = None, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: Literal["lax", "strict", "none"] = "lax", + ) -> None: + """Set a cookie on the response. If passed a :class:`Cookie <.datastructures.Cookie>` instance, keyword + arguments will be ignored. + + Args: + key: Key for the cookie or a :class:`Cookie <.datastructures.Cookie>` instance. + value: Value for the cookie, if none given defaults to empty string. + max_age: Maximal age of the cookie before its invalidated. + expires: Seconds from now until the cookie expires. + path: Path fragment that must exist in the request url for the cookie to be valid. Defaults to ``/``. + domain: Domain for which the cookie is valid. + secure: Https is required for the cookie. + httponly: Forbids javascript to access the cookie via ``document.cookie``. + samesite: Controls whether a cookie is sent with cross-site requests. Defaults to ``lax``. + + Returns: + None. + """ + if not isinstance(key, Cookie): + key = Cookie( + domain=domain, + expires=expires, + httponly=httponly, + key=key, + max_age=max_age, + path=path, + samesite=samesite, + secure=secure, + value=value, + ) + self.cookies.append(key) + + def set_header(self, key: str, value: Any) -> None: + """Set a header on the response. + + Args: + key: Header key. + value: Header value. + + Returns: + None. + """ + self.headers[key] = value + + def set_etag(self, etag: str | ETag) -> None: + """Set an etag header. + + Args: + etag: An etag value. + + Returns: + None + """ + self.headers["etag"] = etag.to_header() if isinstance(etag, ETag) else etag + + def delete_cookie( + self, + key: str, + path: str = "/", + domain: str | None = None, + ) -> None: + """Delete a cookie. + + Args: + key: Key of the cookie. + path: Path of the cookie. + domain: Domain of the cookie. + + Returns: + None. + """ + cookie = Cookie(key=key, path=path, domain=domain, expires=0, max_age=0) + self.cookies = [c for c in self.cookies if c != cookie] + self.cookies.append(cookie) + + def render(self, content: Any, media_type: str, enc_hook: Serializer = default_serializer) -> bytes: + """Handle the rendering of content into a bytes string. + + Returns: + An encoded bytes string + """ + if isinstance(content, bytes): + return content + + if content is Empty: + raise RuntimeError("The `Empty` sentinel cannot be used as response content") + + try: + if media_type.startswith("text/") and not content: + return b"" + + if isinstance(content, str): + return content.encode(self.encoding) + + if media_type == MediaType.MESSAGEPACK: + return encode_msgpack(content, enc_hook) + + if MEDIA_TYPE_APPLICATION_JSON_PATTERN.match( + media_type, + ): + return encode_json(content, enc_hook) + + raise ImproperlyConfiguredException(f"unsupported media_type {media_type} for content {content!r}") + except (AttributeError, ValueError, TypeError) as e: + raise ImproperlyConfiguredException("Unable to serialize response content") from e + + def to_asgi_response( + self, + app: Litestar | None, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIResponse: + """Create an ASGIResponse from a Response instance. + + Args: + app: The :class:`Litestar <.app.Litestar>` application instance. + background: Background task(s) to be executed after the response is sent. + cookies: A list of cookies to be set on the response. + encoded_headers: A list of already encoded headers. + headers: Additional headers to be merged with the response headers. Response headers take precedence. + is_head_response: Whether the response is a HEAD response. + media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored. + request: The :class:`Request <.connection.Request>` instance. + status_code: Status code for the response. If ``status_code`` is already set on the response, this is + type_encoders: A dictionary of type encoders to use for encoding the response content. + + Returns: + An ASGIResponse instance. + """ + + if app is not None: + warn_deprecation( + version="2.1", + deprecated_name="app", + kind="parameter", + removal_in="3.0.0", + alternative="request.app", + ) + + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + + if type_encoders: + type_encoders = {**type_encoders, **(self.response_type_encoders or {})} + else: + type_encoders = self.response_type_encoders + + media_type = get_enum_string_value(self.media_type or media_type or MediaType.JSON) + + return ASGIResponse( + background=self.background or background, + body=self.render(self.content, media_type, get_serializer(type_encoders)), + cookies=cookies, + encoded_headers=encoded_headers, + encoding=self.encoding, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + status_code=self.status_code or status_code, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/response/file.py b/venv/lib/python3.11/site-packages/litestar/response/file.py new file mode 100644 index 0000000..1fc6f86 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/file.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import itertools +from email.utils import formatdate +from inspect import iscoroutine +from mimetypes import encodings_map, guess_type +from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Iterable, Literal, cast +from urllib.parse import quote +from zlib import adler32 + +from litestar.constants import ONE_MEGABYTE +from litestar.exceptions import ImproperlyConfiguredException +from litestar.file_system import BaseLocalFileSystem, FileSystemAdapter +from litestar.response.base import Response +from litestar.response.streaming import ASGIStreamingResponse +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.helpers import get_enum_string_value + +if TYPE_CHECKING: + from os import PathLike + from os import stat_result as stat_result_type + + from anyio import Path + + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.connection import Request + from litestar.datastructures.cookie import Cookie + from litestar.datastructures.headers import ETag + from litestar.enums import MediaType + from litestar.types import ( + HTTPResponseBodyEvent, + PathType, + Receive, + ResponseCookies, + ResponseHeaders, + Send, + TypeEncodersMap, + ) + from litestar.types.file_types import FileInfo, FileSystemProtocol + +__all__ = ( + "ASGIFileResponse", + "File", + "async_file_iterator", + "create_etag_for_file", +) + +# brotli not supported in 'mimetypes.encodings_map' until py 3.9. +encodings_map[".br"] = "br" + + +async def async_file_iterator( + file_path: PathType, chunk_size: int, adapter: FileSystemAdapter +) -> AsyncGenerator[bytes, None]: + """Return an async that asynchronously reads a file and yields its chunks. + + Args: + file_path: A path to a file. + chunk_size: The chunk file to use. + adapter: File system adapter class. + adapter: File system adapter class. + + Returns: + An async generator. + """ + async with await adapter.open(file_path) as file: + while chunk := await file.read(chunk_size): + yield chunk + + +def create_etag_for_file(path: PathType, modified_time: float, file_size: int) -> str: + """Create an etag. + + Notes: + - Function is derived from flask. + + Returns: + An etag. + """ + check = adler32(str(path).encode("utf-8")) & 0xFFFFFFFF + return f'"{modified_time}-{file_size}-{check}"' + + +class ASGIFileResponse(ASGIStreamingResponse): + """A low-level ASGI response, streaming a file as response body.""" + + def __init__( + self, + *, + background: BackgroundTask | BackgroundTasks | None = None, + body: bytes | str = b"", + chunk_size: int = ONE_MEGABYTE, + content_disposition_type: Literal["attachment", "inline"] = "attachment", + content_length: int | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + encoding: str = "utf-8", + etag: ETag | None = None, + file_info: FileInfo | Coroutine[None, None, FileInfo] | None = None, + file_path: str | PathLike | Path, + file_system: FileSystemProtocol | None = None, + filename: str = "", + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + stat_result: stat_result_type | None = None, + status_code: int | None = None, + ) -> None: + """A low-level ASGI response, streaming a file as response body. + + Args: + background: A background task or a list of background tasks to be executed after the response is sent. + body: encoded content to send in the response body. + chunk_size: The chunk size to use. + content_disposition_type: The type of the ``Content-Disposition``. Either ``inline`` or ``attachment``. + content_length: The response content length. + cookies: The response cookies. + encoded_headers: A list of encoded headers. + encoding: The response encoding. + etag: An etag. + file_info: A file info. + file_path: A path to a file. + file_system: A file system adapter. + filename: The name of the file. + headers: A dictionary of headers. + headers: The response headers. + is_head_response: A boolean indicating if the response is a HEAD response. + media_type: The media type of the file. + stat_result: A stat result. + status_code: The response status code. + """ + headers = headers or {} + if not media_type: + mimetype, content_encoding = guess_type(filename) if filename else (None, None) + media_type = mimetype or "application/octet-stream" + if content_encoding is not None: + headers.update({"content-encoding": content_encoding}) + + self.adapter = FileSystemAdapter(file_system or BaseLocalFileSystem()) + + super().__init__( + iterator=async_file_iterator(file_path=file_path, chunk_size=chunk_size, adapter=self.adapter), + headers=headers, + media_type=media_type, + cookies=cookies, + background=background, + status_code=status_code, + body=body, + content_length=content_length, + encoding=encoding, + is_head_response=is_head_response, + encoded_headers=encoded_headers, + ) + + quoted_filename = quote(filename) + is_utf8 = quoted_filename == filename + if is_utf8: + content_disposition = f'{content_disposition_type}; filename="{filename}"' + else: + content_disposition = f"{content_disposition_type}; filename*=utf-8''{quoted_filename}" + + self.headers.setdefault("content-disposition", content_disposition) + + self.chunk_size = chunk_size + self.etag = etag + self.file_path = file_path + + if file_info: + self.file_info: FileInfo | Coroutine[Any, Any, FileInfo] = file_info + elif stat_result: + self.file_info = self.adapter.parse_stat_result(result=stat_result, path=file_path) + else: + self.file_info = self.adapter.info(self.file_path) + + async def send_body(self, send: Send, receive: Receive) -> None: + """Emit a stream of events correlating with the response body. + + Args: + send: The ASGI send function. + receive: The ASGI receive function. + + Returns: + None + """ + if self.chunk_size < self.content_length: + await super().send_body(send=send, receive=receive) + return + + async with await self.adapter.open(self.file_path) as file: + body_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": await file.read(), + "more_body": False, + } + await send(body_event) + + async def start_response(self, send: Send) -> None: + """Emit the start event of the response. This event includes the headers and status codes. + + Args: + send: The ASGI send function. + + Returns: + None + """ + try: + fs_info = self.file_info = cast( + "FileInfo", (await self.file_info if iscoroutine(self.file_info) else self.file_info) + ) + except FileNotFoundError as e: + raise ImproperlyConfiguredException(f"{self.file_path} does not exist") from e + + if fs_info["type"] != "file": + raise ImproperlyConfiguredException(f"{self.file_path} is not a file") + + self.content_length = fs_info["size"] + + self.headers.setdefault("content-length", str(self.content_length)) + self.headers.setdefault("last-modified", formatdate(fs_info["mtime"], usegmt=True)) + + if self.etag: + self.headers.setdefault("etag", self.etag.to_header()) + else: + self.headers.setdefault( + "etag", + create_etag_for_file(path=self.file_path, modified_time=fs_info["mtime"], file_size=fs_info["size"]), + ) + + await super().start_response(send=send) + + +class File(Response): + """A response, streaming a file as response body.""" + + __slots__ = ( + "chunk_size", + "content_disposition_type", + "etag", + "file_path", + "file_system", + "filename", + "file_info", + "stat_result", + ) + + def __init__( + self, + path: str | PathLike | Path, + *, + background: BackgroundTask | BackgroundTasks | None = None, + chunk_size: int = ONE_MEGABYTE, + content_disposition_type: Literal["attachment", "inline"] = "attachment", + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + etag: ETag | None = None, + file_info: FileInfo | Coroutine[Any, Any, FileInfo] | None = None, + file_system: FileSystemProtocol | None = None, + filename: str | None = None, + headers: ResponseHeaders | None = None, + media_type: Literal[MediaType.TEXT] | str | None = None, + stat_result: stat_result_type | None = None, + status_code: int | None = None, + ) -> None: + """Initialize ``File`` + + Notes: + - This class extends the :class:`Stream <.response.Stream>` class. + + Args: + path: A file path in one of the supported formats. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to None. + chunk_size: The chunk sizes to use when streaming the file. Defaults to 1MB. + content_disposition_type: The type of the ``Content-Disposition``. Either ``inline`` or ``attachment``. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: The encoding to be used for the response headers. + etag: An optional :class:`ETag <.datastructures.ETag>` instance. If not provided, an etag will be + generated. + file_info: The output of calling :meth:`file_system.info <types.FileSystemProtocol.info>`, equivalent to + providing an :class:`os.stat_result`. + file_system: An implementation of the :class:`FileSystemProtocol <.types.FileSystemProtocol>`. If provided + it will be used to load the file. + filename: An optional filename to set in the header. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + media_type: A value for the response ``Content-Type`` header. If not provided, the value will be either + derived from the filename if provided and supported by the stdlib, or will default to + ``application/octet-stream``. + stat_result: An optional result of calling :func:os.stat:. If not provided, this will be done by the + response constructor. + status_code: An HTTP status code. + """ + + if file_system is not None and not ( + callable(getattr(file_system, "info", None)) and callable(getattr(file_system, "open", None)) + ): + raise ImproperlyConfiguredException("file_system must adhere to the FileSystemProtocol type") + + self.chunk_size = chunk_size + self.content_disposition_type = content_disposition_type + self.etag = etag + self.file_info = file_info + self.file_path = path + self.file_system = file_system + self.filename = filename or "" + self.stat_result = stat_result + + super().__init__( + content=None, + status_code=status_code, + media_type=media_type, + background=background, + headers=headers, + cookies=cookies, + encoding=encoding, + ) + + def to_asgi_response( + self, + app: Litestar | None, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + cookies: Iterable[Cookie] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIFileResponse: + """Create an :class:`ASGIFileResponse <litestar.response.file.ASGIFileResponse>` instance. + + Args: + app: The :class:`Litestar <.app.Litestar>` application instance. + background: Background task(s) to be executed after the response is sent. + cookies: A list of cookies to be set on the response. + encoded_headers: A list of already encoded headers. + headers: Additional headers to be merged with the response headers. Response headers take precedence. + is_head_response: Whether the response is a HEAD response. + media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored. + request: The :class:`Request <.connection.Request>` instance. + status_code: Status code for the response. If ``status_code`` is already set on the response, this is + type_encoders: A dictionary of type encoders to use for encoding the response content. + + Returns: + A low-level ASGI file response. + """ + if app is not None: + warn_deprecation( + version="2.1", + deprecated_name="app", + kind="parameter", + removal_in="3.0.0", + alternative="request.app", + ) + + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + + media_type = self.media_type or media_type + if media_type is not None: + media_type = get_enum_string_value(media_type) + + return ASGIFileResponse( + background=self.background or background, + body=b"", + chunk_size=self.chunk_size, + content_disposition_type=self.content_disposition_type, # pyright: ignore + content_length=0, + cookies=cookies, + encoded_headers=encoded_headers, + encoding=self.encoding, + etag=self.etag, + file_info=self.file_info, + file_path=self.file_path, + file_system=self.file_system, + filename=self.filename, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + stat_result=self.stat_result, + status_code=self.status_code or status_code, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/response/redirect.py b/venv/lib/python3.11/site-packages/litestar/response/redirect.py new file mode 100644 index 0000000..6a07076 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/redirect.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING, Any, Iterable, Literal + +from litestar.constants import REDIRECT_ALLOWED_MEDIA_TYPES, REDIRECT_STATUS_CODES +from litestar.enums import MediaType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.response.base import ASGIResponse, Response +from litestar.status_codes import HTTP_302_FOUND +from litestar.utils import url_quote +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.helpers import get_enum_string_value + +if TYPE_CHECKING: + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.connection import Request + from litestar.datastructures import Cookie + from litestar.types import ResponseCookies, ResponseHeaders, TypeEncodersMap + +__all__ = ( + "ASGIRedirectResponse", + "Redirect", +) + + +RedirectStatusType = Literal[301, 302, 303, 307, 308] +"""Acceptable status codes for redirect responses.""" + + +class ASGIRedirectResponse(ASGIResponse): + """A low-level ASGI redirect response class.""" + + def __init__( + self, + path: str | bytes, + media_type: str | None = None, + status_code: RedirectStatusType | None = None, + headers: dict[str, Any] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + body: bytes | str = b"", + content_length: int | None = None, + cookies: Iterable[Cookie] | None = None, + encoding: str = "utf-8", + is_head_response: bool = False, + ) -> None: + headers = {**(headers or {}), "location": url_quote(path)} + media_type = media_type or MediaType.TEXT + status_code = status_code or HTTP_302_FOUND + + if status_code not in REDIRECT_STATUS_CODES: + raise ImproperlyConfiguredException( + f"{status_code} is not a valid for this response. " + f"Redirect responses should have one of " + f"the following status codes: {', '.join([str(s) for s in REDIRECT_STATUS_CODES])}" + ) + + if media_type not in REDIRECT_ALLOWED_MEDIA_TYPES: + raise ImproperlyConfiguredException( + f"{media_type} media type is not supported yet. " + f"Media type should be one of " + f"the following values: {', '.join([str(s) for s in REDIRECT_ALLOWED_MEDIA_TYPES])}" + ) + + super().__init__( + status_code=status_code, + headers=headers, + media_type=media_type, + background=background, + is_head_response=is_head_response, + encoding=encoding, + cookies=cookies, + content_length=content_length, + body=body, + encoded_headers=encoded_headers, + ) + + +class Redirect(Response[Any]): + """A redirect response.""" + + __slots__ = ("url",) + + def __init__( + self, + path: str, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + headers: ResponseHeaders | None = None, + media_type: str | MediaType | None = None, + status_code: RedirectStatusType | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> None: + """Initialize the response. + + Args: + path: A path to redirect to. + background: A background task or tasks to be run after the response is sent. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: The encoding to be used for the response headers. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + media_type: A value for the response ``Content-Type`` header. + status_code: An HTTP status code. The status code should be one of 301, 302, 303, 307 or 308, + otherwise an exception will be raised. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + + Raises: + ImproperlyConfiguredException: Either if status code is not a redirect status code or media type is not + supported. + """ + self.url = path + if status_code is None: + status_code = HTTP_302_FOUND + super().__init__( + background=background, + content=b"", + cookies=cookies, + encoding=encoding, + headers=headers, + media_type=media_type, + status_code=status_code, + type_encoders=type_encoders, + ) + + def to_asgi_response( + self, + app: Litestar | None, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIResponse: + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + media_type = get_enum_string_value(self.media_type or media_type or MediaType.TEXT) + + if app is not None: + warn_deprecation( + version="2.1", + deprecated_name="app", + kind="parameter", + removal_in="3.0.0", + alternative="request.app", + ) + + return ASGIRedirectResponse( + path=self.url, + background=self.background or background, + body=b"", + content_length=None, + cookies=cookies, + encoded_headers=encoded_headers, + encoding=self.encoding, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + status_code=self.status_code or status_code, # type:ignore[arg-type] + ) diff --git a/venv/lib/python3.11/site-packages/litestar/response/sse.py b/venv/lib/python3.11/site-packages/litestar/response/sse.py new file mode 100644 index 0000000..48a9192 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/sse.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import io +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator + +from litestar.concurrency import sync_to_thread +from litestar.exceptions import ImproperlyConfiguredException +from litestar.response.streaming import Stream +from litestar.utils import AsyncIteratorWrapper + +if TYPE_CHECKING: + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType + +_LINE_BREAK_RE = re.compile(r"\r\n|\r|\n") +DEFAULT_SEPARATOR = "\r\n" + + +class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]): + __slots__ = ("content_async_iterator", "event_id", "event_type", "retry_duration", "comment_message") + + content_async_iterator: AsyncIterable[SSEData] + + def __init__( + self, + content: str | bytes | StreamType[SSEData], + event_type: str | None = None, + event_id: int | str | None = None, + retry_duration: int | None = None, + comment_message: str | None = None, + ) -> None: + self.comment_message = comment_message + self.event_id = event_id + self.event_type = event_type + self.retry_duration = retry_duration + chunks: list[bytes] = [] + if comment_message is not None: + chunks.extend([f": {chunk}\r\n".encode() for chunk in _LINE_BREAK_RE.split(comment_message)]) + + if event_id is not None: + chunks.append(f"id: {event_id}\r\n".encode()) + + if event_type is not None: + chunks.append(f"event: {event_type}\r\n".encode()) + + if retry_duration is not None: + chunks.append(f"retry: {retry_duration}\r\n".encode()) + + super().__init__(iterator=chunks) + + if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content): + content = content() # type: ignore[unreachable] + + if isinstance(content, (str, bytes)): + self.content_async_iterator = AsyncIteratorWrapper([content]) + elif isinstance(content, (Iterable, Iterator)): + self.content_async_iterator = AsyncIteratorWrapper(content) + elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)): + self.content_async_iterator = content + else: + raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent") + + def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage | Any, sep: str) -> bytes: + if isinstance(data, ServerSentEventMessage): + return data.encode() + if isinstance(data, dict): + data["sep"] = sep + return ServerSentEventMessage(**data).encode() + + return ServerSentEventMessage( + data=data, id=self.event_id, event=self.event_type, retry=self.retry_duration, sep=sep + ).encode() + + def _call_next(self) -> bytes: + try: + return next(self.iterator) + except StopIteration as e: + raise ValueError from e + + async def _async_generator(self) -> AsyncGenerator[bytes, None]: + while True: + try: + yield await sync_to_thread(self._call_next) + except ValueError: + async for value in self.content_async_iterator: + yield self.ensure_bytes(value, DEFAULT_SEPARATOR) + break + + +@dataclass +class ServerSentEventMessage: + data: str | int | bytes | None = "" + event: str | None = None + id: int | str | None = None + retry: int | None = None + comment: str | None = None + sep: str = DEFAULT_SEPARATOR + + def encode(self) -> bytes: + buffer = io.StringIO() + if self.comment is not None: + for chunk in _LINE_BREAK_RE.split(str(self.comment)): + buffer.write(f": {chunk}") + buffer.write(self.sep) + + if self.id is not None: + buffer.write(_LINE_BREAK_RE.sub("", f"id: {self.id}")) + buffer.write(self.sep) + + if self.event is not None: + buffer.write(_LINE_BREAK_RE.sub("", f"event: {self.event}")) + buffer.write(self.sep) + + if self.data is not None: + data = self.data + for chunk in _LINE_BREAK_RE.split(data.decode() if isinstance(data, bytes) else str(data)): + buffer.write(f"data: {chunk}") + buffer.write(self.sep) + + if self.retry is not None: + buffer.write(f"retry: {self.retry}") + buffer.write(self.sep) + + buffer.write(self.sep) + return buffer.getvalue().encode("utf-8") + + +class ServerSentEvent(Stream): + def __init__( + self, + content: str | bytes | StreamType[SSEData], + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + headers: ResponseHeaders | None = None, + event_type: str | None = None, + event_id: int | str | None = None, + retry_duration: int | None = None, + comment_message: str | None = None, + status_code: int | None = None, + ) -> None: + """Initialize the response. + + Args: + content: Bytes, string or a sync or async iterator or iterable. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to None. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: The encoding to be used for the response headers. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + status_code: The response status code. Defaults to 200. + event_type: The type of the SSE event. If given, the browser will sent the event to any 'event-listener' + declared for it (e.g. via 'addEventListener' in JS). + event_id: The event ID. This sets the event source's 'last event id'. + retry_duration: Retry duration in milliseconds. + comment_message: A comment message. This value is ignored by clients and is used mostly for pinging. + """ + super().__init__( + content=_ServerSentEventIterator( + content=content, + event_type=event_type, + event_id=event_id, + retry_duration=retry_duration, + comment_message=comment_message, + ), + media_type="text/event-stream", + background=background, + cookies=cookies, + encoding=encoding, + headers=headers, + status_code=status_code, + ) + self.headers.setdefault("Cache-Control", "no-cache") + self.headers["Connection"] = "keep-alive" + self.headers["X-Accel-Buffering"] = "no" diff --git a/venv/lib/python3.11/site-packages/litestar/response/streaming.py b/venv/lib/python3.11/site-packages/litestar/response/streaming.py new file mode 100644 index 0000000..fc76522 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/streaming.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import itertools +from functools import partial +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator, Union + +from anyio import CancelScope, create_task_group + +from litestar.enums import MediaType +from litestar.response.base import ASGIResponse, Response +from litestar.types.helper_types import StreamType +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.helpers import get_enum_string_value +from litestar.utils.sync import AsyncIteratorWrapper + +if TYPE_CHECKING: + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.connection import Request + from litestar.datastructures.cookie import Cookie + from litestar.enums import OpenAPIMediaType + from litestar.types import HTTPResponseBodyEvent, Receive, ResponseCookies, ResponseHeaders, Send, TypeEncodersMap + +__all__ = ( + "ASGIStreamingResponse", + "Stream", +) + + +class ASGIStreamingResponse(ASGIResponse): + """A streaming response.""" + + __slots__ = ("iterator",) + + _should_set_content_length = False + + def __init__( + self, + *, + iterator: StreamType, + background: BackgroundTask | BackgroundTasks | None = None, + body: bytes | str = b"", + content_length: int | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + encoding: str = "utf-8", + headers: dict[str, Any] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + ) -> None: + """A low-level ASGI streaming response. + + Args: + background: A background task or a list of background tasks to be executed after the response is sent. + body: encoded content to send in the response body. + content_length: The response content length. + cookies: The response cookies. + encoded_headers: The response headers. + encoding: The response encoding. + headers: The response headers. + is_head_response: A boolean indicating if the response is a HEAD response. + iterator: An async iterator or iterable. + media_type: The response media type. + status_code: The response status code. + """ + super().__init__( + background=background, + body=body, + content_length=content_length, + cookies=cookies, + encoding=encoding, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + status_code=status_code, + encoded_headers=encoded_headers, + ) + self.iterator: AsyncIterable[str | bytes] | AsyncGenerator[str | bytes, None] = ( + iterator if isinstance(iterator, (AsyncIterable, AsyncIterator)) else AsyncIteratorWrapper(iterator) + ) + + async def _listen_for_disconnect(self, cancel_scope: CancelScope, receive: Receive) -> None: + """Listen for a cancellation message, and if received - call cancel on the cancel scope. + + Args: + cancel_scope: A task group cancel scope instance. + receive: The ASGI receive function. + + Returns: + None + """ + if not cancel_scope.cancel_called: + message = await receive() + if message["type"] == "http.disconnect": + # despite the IDE warning, this is not a coroutine because anyio 3+ changed this. + # therefore make sure not to await this. + cancel_scope.cancel() + else: + await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive) + + async def _stream(self, send: Send) -> None: + """Send the chunks from the iterator as a stream of ASGI 'http.response.body' events. + + Args: + send: The ASGI Send function. + + Returns: + None + """ + async for chunk in self.iterator: + stream_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding), + "more_body": True, + } + await send(stream_event) + terminus_event: HTTPResponseBodyEvent = {"type": "http.response.body", "body": b"", "more_body": False} + await send(terminus_event) + + async def send_body(self, send: Send, receive: Receive) -> None: + """Emit a stream of events correlating with the response body. + + Args: + send: The ASGI send function. + receive: The ASGI receive function. + + Returns: + None + """ + + async with create_task_group() as task_group: + task_group.start_soon(partial(self._stream, send)) + await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive) + + +class Stream(Response[StreamType[Union[str, bytes]]]): + """An HTTP response that streams the response data as a series of ASGI ``http.response.body`` events.""" + + __slots__ = ("iterator",) + + def __init__( + self, + content: StreamType[str | bytes] | Callable[[], StreamType[str | bytes]], + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + headers: ResponseHeaders | None = None, + media_type: MediaType | OpenAPIMediaType | str | None = None, + status_code: int | None = None, + ) -> None: + """Initialize the response. + + Args: + content: A sync or async iterator or iterable. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to None. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: The encoding to be used for the response headers. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + media_type: A value for the response ``Content-Type`` header. + status_code: An HTTP status code. + """ + super().__init__( + background=background, + content=b"", # type: ignore[arg-type] + cookies=cookies, + encoding=encoding, + headers=headers, + media_type=media_type, + status_code=status_code, + ) + self.iterator = content + + def to_asgi_response( + self, + app: Litestar | None, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIResponse: + """Create an ASGIStreamingResponse from a StremaingResponse instance. + + Args: + app: The :class:`Litestar <.app.Litestar>` application instance. + background: Background task(s) to be executed after the response is sent. + cookies: A list of cookies to be set on the response. + encoded_headers: A list of already encoded headers. + headers: Additional headers to be merged with the response headers. Response headers take precedence. + is_head_response: Whether the response is a HEAD response. + media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored. + request: The :class:`Request <.connection.Request>` instance. + status_code: Status code for the response. If ``status_code`` is already set on the response, this is + type_encoders: A dictionary of type encoders to use for encoding the response content. + + Returns: + An ASGIStreamingResponse instance. + """ + if app is not None: + warn_deprecation( + version="2.1", + deprecated_name="app", + kind="parameter", + removal_in="3.0.0", + alternative="request.app", + ) + + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + + media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON) + + iterator = self.iterator + if not isinstance(iterator, (Iterable, Iterator, AsyncIterable, AsyncIterator)) and callable(iterator): + iterator = iterator() + + return ASGIStreamingResponse( + background=self.background or background, + body=b"", + content_length=0, + cookies=cookies, + encoded_headers=encoded_headers, + encoding=self.encoding, + headers=headers, + is_head_response=is_head_response, + iterator=iterator, + media_type=media_type, + status_code=self.status_code or status_code, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/response/template.py b/venv/lib/python3.11/site-packages/litestar/response/template.py new file mode 100644 index 0000000..6499aae --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/response/template.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import itertools +from mimetypes import guess_type +from pathlib import PurePath +from typing import TYPE_CHECKING, Any, Iterable, cast + +from litestar.enums import MediaType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.response.base import ASGIResponse, Response +from litestar.status_codes import HTTP_200_OK +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.app import Litestar + from litestar.background_tasks import BackgroundTask, BackgroundTasks + from litestar.connection import Request + from litestar.datastructures import Cookie + from litestar.types import ResponseCookies, TypeEncodersMap + +__all__ = ("Template",) + + +class Template(Response[bytes]): + """Template-based response, rendering a given template into a bytes string.""" + + __slots__ = ( + "template_name", + "template_str", + "context", + ) + + def __init__( + self, + template_name: str | None = None, + *, + template_str: str | None = None, + background: BackgroundTask | BackgroundTasks | None = None, + context: dict[str, Any] | None = None, + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + headers: dict[str, Any] | None = None, + media_type: MediaType | str | None = None, + status_code: int = HTTP_200_OK, + ) -> None: + """Handle the rendering of a given template into a bytes string. + + Args: + template_name: Path-like name for the template to be rendered, e.g. ``index.html``. + template_str: A string representing the template, e.g. ``tmpl = "Hello <strong>World</strong>"``. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to ``None``. + context: A dictionary of key/value pairs to be passed to the temple engine's render method. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: Content encoding + headers: A string keyed dictionary of response headers. Header keys are insensitive. + media_type: A string or member of the :class:`MediaType <.enums.MediaType>` enum. If not set, try to infer + the media type based on the template name. If this fails, fall back to ``text/plain``. + status_code: A value for the response HTTP status code. + """ + if not (template_name or template_str): + raise ValueError("Either template_name or template_str must be provided.") + + if template_name and template_str: + raise ValueError("Either template_name or template_str must be provided, not both.") + + super().__init__( + background=background, + content=b"", + cookies=cookies, + encoding=encoding, + headers=headers, + media_type=media_type, + status_code=status_code, + ) + self.context = context or {} + self.template_name = template_name + self.template_str = template_str + + def create_template_context(self, request: Request) -> dict[str, Any]: + """Create a context object for the template. + + Args: + request: A :class:`Request <.connection.Request>` instance. + + Returns: + A dictionary holding the template context + """ + csrf_token = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "") + return { + **self.context, + "request": request, + "csrf_input": f'<input type="hidden" name="_csrf_token" value="{csrf_token}" />', + } + + def to_asgi_response( + self, + app: Litestar | None, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIResponse: + if app is not None: + warn_deprecation( + version="2.1", + deprecated_name="app", + kind="parameter", + removal_in="3.0.0", + alternative="request.app", + ) + + if not (template_engine := request.app.template_engine): + raise ImproperlyConfiguredException("Template engine is not configured") + + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + + media_type = self.media_type or media_type + if not media_type: + if self.template_name: + suffixes = PurePath(self.template_name).suffixes + for suffix in suffixes: + if _type := guess_type(f"name{suffix}")[0]: + media_type = _type + break + else: + media_type = MediaType.TEXT + else: + media_type = MediaType.HTML + + context = self.create_template_context(request) + + if self.template_str is not None: + body = template_engine.render_string(self.template_str, context) + else: + # cast to str b/c we know that either template_name cannot be None if template_str is None + template = template_engine.get_template(cast("str", self.template_name)) + body = template.render(**context).encode(self.encoding) + + return ASGIResponse( + background=self.background or background, + body=body, + content_length=None, + cookies=cookies, + encoded_headers=encoded_headers, + encoding=self.encoding, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + status_code=self.status_code or status_code, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/router.py b/venv/lib/python3.11/site-packages/litestar/router.py new file mode 100644 index 0000000..85346d8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/router.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from collections import defaultdict +from copy import copy, deepcopy +from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast + +from litestar._layers.utils import narrow_response_cookies, narrow_response_headers +from litestar.controller import Controller +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers.asgi_handlers import ASGIRouteHandler +from litestar.handlers.http_handlers import HTTPRouteHandler +from litestar.handlers.websocket_handlers import WebsocketListener, WebsocketRouteHandler +from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute +from litestar.types.empty import Empty +from litestar.utils import find_index, is_class_and_subclass, join_paths, normalize_path, unique +from litestar.utils.signature import add_types_to_signature_namespace +from litestar.utils.sync import ensure_async_callable + +__all__ = ("Router",) + + +if TYPE_CHECKING: + from litestar.connection import Request, WebSocket + from litestar.datastructures import CacheControlHeader, ETag + from litestar.dto import AbstractDTO + from litestar.openapi.spec import SecurityRequirement + from litestar.response import Response + from litestar.routes import BaseRoute + from litestar.types import ( + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeRequestHookHandler, + ControllerRouterHandler, + ExceptionHandlersMap, + Guard, + Middleware, + ParametersMap, + ResponseCookies, + RouteHandlerMapItem, + RouteHandlerType, + TypeEncodersMap, + ) + from litestar.types.composite_types import Dependencies, ResponseHeaders, TypeDecodersSequence + from litestar.types.empty import EmptyType + + +class Router: + """The Litestar Router class. + + A Router instance is used to group controller, routers and route handler functions under a shared path fragment + """ + + __slots__ = ( + "after_request", + "after_response", + "before_request", + "cache_control", + "dependencies", + "dto", + "etag", + "exception_handlers", + "guards", + "include_in_schema", + "middleware", + "opt", + "owner", + "parameters", + "path", + "registered_route_handler_ids", + "request_class", + "response_class", + "response_cookies", + "response_headers", + "return_dto", + "routes", + "security", + "signature_namespace", + "tags", + "type_decoders", + "type_encoders", + "websocket_class", + ) + + def __init__( + self, + path: str, + *, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache_control: CacheControlHeader | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + middleware: Sequence[Middleware] | None = None, + opt: Mapping[str, Any] | None = None, + parameters: ParametersMap | None = None, + request_class: type[Request] | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + route_handlers: Sequence[ControllerRouterHandler], + security: Sequence[SecurityRequirement] | None = None, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + ) -> None: + """Initialize a ``Router``. + + Args: + after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed + to any route handler. If this function returns a value, the request will not reach the route handler, + and instead this value will be used. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + before_request: A sync or async function called immediately before calling the route handler. Receives + the :class:`litestar.connection.Request` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <.datastructures.CacheControlHeader>` to add to route handlers of + this router. Can be overridden by route handlers. + dependencies: A string keyed mapping of dependency :class:`Provide <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :data:`Guard <.types.Guard>` callables. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + middleware: A sequence of :data:`Middleware <.types.Middleware>`. + opt: A string keyed mapping of arbitrary values that can be accessed in :data:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :data:`ASGI Scope <.types.Scope>`. + parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions available to all application + paths. + path: A path fragment that is prefixed to all route handlers, controllers and other routers associated + with the router instance. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as the default for + all route handlers, controllers and other routers associated with the router instance. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the default for + all route handlers, controllers and other routers associated with the router instance. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + instances. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + route_handlers: A required sequence of route handlers, which can include instances of + :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any + function decorated by the route handler decorators. + security: A sequence of dicts that will be added to the schema of all route handlers in the application. + See :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` + for details. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + tags: A sequence of string tags that will be appended to the schema of all route handlers under the + application. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default for + all route handlers, controllers and other routers associated with the router instance. + """ + + self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore + self.after_response = ensure_async_callable(after_response) if after_response else None + self.before_request = ensure_async_callable(before_request) if before_request else None + self.cache_control = cache_control + self.dto = dto + self.etag = etag + self.dependencies = dict(dependencies or {}) + self.exception_handlers = dict(exception_handlers or {}) + self.guards = list(guards or []) + self.include_in_schema = include_in_schema + self.middleware = list(middleware or []) + self.opt = dict(opt or {}) + self.owner: Router | None = None + self.parameters = dict(parameters or {}) + self.path = normalize_path(path) + self.request_class = request_class + self.response_class = response_class + self.response_cookies = narrow_response_cookies(response_cookies) + self.response_headers = narrow_response_headers(response_headers) + self.return_dto = return_dto + self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = [] + self.security = list(security or []) + self.signature_namespace = add_types_to_signature_namespace( + signature_types or [], dict(signature_namespace or {}) + ) + self.tags = list(tags or []) + self.registered_route_handler_ids: set[int] = set() + self.type_encoders = dict(type_encoders) if type_encoders is not None else None + self.type_decoders = list(type_decoders) if type_decoders is not None else None + self.websocket_class = websocket_class + + for route_handler in route_handlers or []: + self.register(value=route_handler) + + def register(self, value: ControllerRouterHandler) -> list[BaseRoute]: + """Register a Controller, Route instance or RouteHandler on the router. + + Args: + value: a subclass or instance of Controller, an instance of :class:`Router` or a function/method that has + been decorated by any of the routing decorators, e.g. :class:`get <.handlers.get>`, + :class:`post <.handlers.post>`. + + Returns: + Collection of handlers added to the router. + """ + validated_value = self._validate_registration_value(value) + + routes: list[BaseRoute] = [] + + for route_path, handlers_map in self.get_route_handler_map(value=validated_value).items(): + path = join_paths([self.path, route_path]) + if http_handlers := unique( + [handler for handler in handlers_map.values() if isinstance(handler, HTTPRouteHandler)] + ): + if existing_handlers := unique( + [ + handler + for handler in self.route_handler_method_map.get(path, {}).values() + if isinstance(handler, HTTPRouteHandler) + ] + ): + http_handlers.extend(existing_handlers) + existing_route_index = find_index(self.routes, lambda x: x.path == path) # noqa: B023 + + if existing_route_index == -1: # pragma: no cover + raise ImproperlyConfiguredException("unable to find_index existing route index") + + route: WebSocketRoute | ASGIRoute | HTTPRoute = HTTPRoute( + path=path, + route_handlers=http_handlers, + ) + self.routes[existing_route_index] = route + else: + route = HTTPRoute(path=path, route_handlers=http_handlers) + self.routes.append(route) + + routes.append(route) + + if websocket_handler := handlers_map.get("websocket"): + route = WebSocketRoute(path=path, route_handler=cast("WebsocketRouteHandler", websocket_handler)) + self.routes.append(route) + routes.append(route) + + if asgi_handler := handlers_map.get("asgi"): + route = ASGIRoute(path=path, route_handler=cast("ASGIRouteHandler", asgi_handler)) + self.routes.append(route) + routes.append(route) + + return routes + + @property + def route_handler_method_map(self) -> dict[str, RouteHandlerMapItem]: + """Map route paths to :class:`RouteHandlerMapItem <litestar.types.internal_typ es.RouteHandlerMapItem>` + + Returns: + A dictionary mapping paths to route handlers + """ + route_map: dict[str, RouteHandlerMapItem] = defaultdict(dict) + for route in self.routes: + if isinstance(route, HTTPRoute): + for route_handler in route.route_handlers: + for method in route_handler.http_methods: + route_map[route.path][method] = route_handler + else: + route_map[route.path]["websocket" if isinstance(route, WebSocketRoute) else "asgi"] = ( + route.route_handler + ) + + return route_map + + @classmethod + def get_route_handler_map( + cls, + value: Controller | RouteHandlerType | Router, + ) -> dict[str, RouteHandlerMapItem]: + """Map route handlers to HTTP methods.""" + if isinstance(value, Router): + return value.route_handler_method_map + + if isinstance(value, (HTTPRouteHandler, ASGIRouteHandler, WebsocketRouteHandler)): + copied_value = copy(value) + if isinstance(value, HTTPRouteHandler): + return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths} + + return { + path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value} + for path in value.paths + } + + handlers_map: defaultdict[str, RouteHandlerMapItem] = defaultdict(dict) + for route_handler in value.get_route_handlers(): + for handler_path in route_handler.paths: + path = join_paths([value.path, handler_path]) if handler_path else value.path + if isinstance(route_handler, HTTPRouteHandler): + for http_method in route_handler.http_methods: + handlers_map[path][http_method] = route_handler + else: + handlers_map[path]["websocket" if isinstance(route_handler, WebsocketRouteHandler) else "asgi"] = ( + cast("WebsocketRouteHandler | ASGIRouteHandler", route_handler) + ) + + return handlers_map + + def _validate_registration_value(self, value: ControllerRouterHandler) -> Controller | RouteHandlerType | Router: + """Ensure values passed to the register method are supported.""" + if is_class_and_subclass(value, Controller): + return value(owner=self) + + # this narrows down to an ABC, but we assume a non-abstract subclass of the ABC superclass + if is_class_and_subclass(value, WebsocketListener): + return value(owner=self).to_handler() # pyright: ignore + + if isinstance(value, Router): + if value is self: + raise ImproperlyConfiguredException("Cannot register a router on itself") + + router_copy = deepcopy(value) + router_copy.owner = self + return router_copy + + if isinstance(value, (ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler)): + value.owner = self + return value + + raise ImproperlyConfiguredException( + "Unsupported value passed to `Router.register`. " + "If you passed in a function or method, " + "make sure to decorate it first with one of the routing decorators" + ) diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__init__.py b/venv/lib/python3.11/site-packages/litestar/routes/__init__.py new file mode 100644 index 0000000..c8b5d3d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/__init__.py @@ -0,0 +1,6 @@ +from .asgi import ASGIRoute +from .base import BaseRoute +from .http import HTTPRoute +from .websocket import WebSocketRoute + +__all__ = ("BaseRoute", "ASGIRoute", "WebSocketRoute", "HTTPRoute") diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ae8dee1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b3330b1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..98b66b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a430b70 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..79da4a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/routes/asgi.py b/venv/lib/python3.11/site-packages/litestar/routes/asgi.py new file mode 100644 index 0000000..a8564d0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/asgi.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.connection import ASGIConnection +from litestar.enums import ScopeType +from litestar.routes.base import BaseRoute + +if TYPE_CHECKING: + from litestar.handlers.asgi_handlers import ASGIRouteHandler + from litestar.types import Receive, Scope, Send + + +class ASGIRoute(BaseRoute): + """An ASGI route, handling a single ``ASGIRouteHandler``""" + + __slots__ = ("route_handler",) + + def __init__( + self, + *, + path: str, + route_handler: ASGIRouteHandler, + ) -> None: + """Initialize the route. + + Args: + path: The path for the route. + route_handler: An instance of :class:`~.handlers.ASGIRouteHandler`. + """ + self.route_handler = route_handler + super().__init__( + path=path, + scope_type=ScopeType.ASGI, + handler_names=[route_handler.handler_name], + ) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI app that authorizes the connection and then awaits the handler function. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + if self.route_handler.resolve_guards(): + connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any](scope=scope, receive=receive) + await self.route_handler.authorize_connection(connection=connection) + + await self.route_handler.fn(scope=scope, receive=receive, send=send) diff --git a/venv/lib/python3.11/site-packages/litestar/routes/base.py b/venv/lib/python3.11/site-packages/litestar/routes/base.py new file mode 100644 index 0000000..b9baab4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/base.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable +from uuid import UUID + +import msgspec + +from litestar._kwargs import KwargsModel +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types.internal_types import PathParameterDefinition +from litestar.utils import join_paths, normalize_path + +if TYPE_CHECKING: + from litestar.enums import ScopeType + from litestar.handlers.base import BaseRouteHandler + from litestar.types import Method, Receive, Scope, Send + + +def _parse_datetime(value: str) -> datetime: + return msgspec.convert(value, datetime) + + +def _parse_date(value: str) -> date: + return msgspec.convert(value, date) + + +def _parse_time(value: str) -> time: + return msgspec.convert(value, time) + + +def _parse_timedelta(value: str) -> timedelta: + try: + return msgspec.convert(value, timedelta) + except msgspec.ValidationError: + return timedelta(seconds=int(float(value))) + + +param_match_regex = re.compile(r"{(.*?)}") + +param_type_map = { + "str": str, + "int": int, + "float": float, + "uuid": UUID, + "decimal": Decimal, + "date": date, + "datetime": datetime, + "time": time, + "timedelta": timedelta, + "path": Path, +} + + +parsers_map: dict[Any, Callable[[Any], Any]] = { + float: float, + int: int, + Decimal: Decimal, + UUID: UUID, + date: _parse_date, + datetime: _parse_datetime, + time: _parse_time, + timedelta: _parse_timedelta, +} + + +class BaseRoute(ABC): + """Base Route class used by Litestar. + + It's an abstract class meant to be extended. + """ + + __slots__ = ( + "app", + "handler_names", + "methods", + "path", + "path_format", + "path_parameters", + "path_components", + "scope_type", + ) + + def __init__( + self, + *, + handler_names: list[str], + path: str, + scope_type: ScopeType, + methods: list[Method] | None = None, + ) -> None: + """Initialize the route. + + Args: + handler_names: Names of the associated handler functions + path: Base path of the route + scope_type: Type of the ASGI scope + methods: Supported methods + """ + self.path, self.path_format, self.path_components = self._parse_path(path) + self.path_parameters: tuple[PathParameterDefinition, ...] = tuple( + component for component in self.path_components if isinstance(component, PathParameterDefinition) + ) + self.handler_names = handler_names + self.scope_type = scope_type + self.methods = set(methods or []) + + @abstractmethod + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI App of the route. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + raise NotImplementedError("Route subclasses must implement handle which serves as the ASGI app entry point") + + def create_handler_kwargs_model(self, route_handler: BaseRouteHandler) -> KwargsModel: + """Create a `KwargsModel` for a given route handler.""" + + path_parameters = set() + for param in self.path_parameters: + if param.name in path_parameters: + raise ImproperlyConfiguredException(f"Duplicate parameter '{param.name}' detected in '{self.path}'.") + path_parameters.add(param.name) + + return KwargsModel.create_for_signature_model( + signature_model=route_handler.signature_model, + parsed_signature=route_handler.parsed_fn_signature, + dependencies=route_handler.resolve_dependencies(), + path_parameters=path_parameters, + layered_parameters=route_handler.resolve_layered_parameters(), + ) + + @staticmethod + def _validate_path_parameter(param: str, path: str) -> None: + """Validate that a path parameter adheres to the required format and datatypes. + + Raises: + ImproperlyConfiguredException: If the parameter has an invalid format. + """ + if len(param.split(":")) != 2: + raise ImproperlyConfiguredException( + f"Path parameters should be declared with a type using the following pattern: '{{parameter_name:type}}', e.g. '/my-path/{{my_param:int}}' in path: '{path}'" + ) + param_name, param_type = (p.strip() for p in param.split(":")) + if not param_name: + raise ImproperlyConfiguredException("Path parameter names should be of length greater than zero") + if param_type not in param_type_map: + raise ImproperlyConfiguredException( + f"Path parameters should be declared with an allowed type, i.e. one of {', '.join(param_type_map.keys())} in path: '{path}'" + ) + + @classmethod + def _parse_path(cls, path: str) -> tuple[str, str, list[str | PathParameterDefinition]]: + """Normalize and parse a path. + + Splits the path into a list of components, parsing any that are path parameters. Also builds the OpenAPI + compatible path, which does not include the type of the path parameters. + + Returns: + A 3-tuple of the normalized path, the OpenAPI formatted path, and the list of parsed components. + """ + path = normalize_path(path) + + parsed_components: list[str | PathParameterDefinition] = [] + path_format_components = [] + + components = [component for component in path.split("/") if component] + for component in components: + if param_match := param_match_regex.fullmatch(component): + param = param_match.group(1) + cls._validate_path_parameter(param, path) + param_name, param_type = (p.strip() for p in param.split(":")) + type_class = param_type_map[param_type] + parser = parsers_map[type_class] if type_class not in {str, Path} else None + parsed_components.append( + PathParameterDefinition(name=param_name, type=type_class, full=param, parser=parser) + ) + path_format_components.append("{" + param_name + "}") + else: + parsed_components.append(component) + path_format_components.append(component) + + path_format = join_paths(path_format_components) + + return path, path_format, parsed_components diff --git a/venv/lib/python3.11/site-packages/litestar/routes/http.py b/venv/lib/python3.11/site-packages/litestar/routes/http.py new file mode 100644 index 0000000..b1f70cb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/http.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from itertools import chain +from typing import TYPE_CHECKING, Any, cast + +from msgspec.msgpack import decode as _decode_msgpack_plain + +from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS +from litestar.datastructures.headers import Headers +from litestar.datastructures.upload_file import UploadFile +from litestar.enums import HttpMethod, MediaType, ScopeType +from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException +from litestar.handlers.http_handlers import HTTPRouteHandler +from litestar.response import Response +from litestar.routes.base import BaseRoute +from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST +from litestar.types.empty import Empty +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar._kwargs import KwargsModel + from litestar._kwargs.cleanup import DependencyCleanupGroup + from litestar.connection import Request + from litestar.types import ASGIApp, HTTPScope, Method, Receive, Scope, Send + + +class HTTPRoute(BaseRoute): + """An HTTP route, capable of handling multiple ``HTTPRouteHandler``\\ s.""" # noqa: D301 + + __slots__ = ( + "route_handler_map", + "route_handlers", + ) + + def __init__( + self, + *, + path: str, + route_handlers: list[HTTPRouteHandler], + ) -> None: + """Initialize ``HTTPRoute``. + + Args: + path: The path for the route. + route_handlers: A list of :class:`~.handlers.HTTPRouteHandler`. + """ + methods = list(chain.from_iterable([route_handler.http_methods for route_handler in route_handlers])) + if "OPTIONS" not in methods: + methods.append("OPTIONS") + options_handler = self.create_options_handler(path) + options_handler.owner = route_handlers[0].owner + route_handlers.append(options_handler) + + self.route_handlers = route_handlers + self.route_handler_map: dict[Method, tuple[HTTPRouteHandler, KwargsModel]] = {} + + super().__init__( + methods=methods, + path=path, + scope_type=ScopeType.HTTP, + handler_names=[route_handler.handler_name for route_handler in self.route_handlers], + ) + + async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: # type: ignore[override] + """ASGI app that creates a Request from the passed in args, determines which handler function to call and then + handles the call. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + route_handler, parameter_model = self.route_handler_map[scope["method"]] + request: Request[Any, Any, Any] = route_handler.resolve_request_class()(scope=scope, receive=receive, send=send) + + if route_handler.resolve_guards(): + await route_handler.authorize_connection(connection=request) + + response = await self._get_response_for_request( + scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model + ) + + await response(scope, receive, send) + + if after_response_handler := route_handler.resolve_after_response(): + await after_response_handler(request) + + if form_data := scope.get("_form", {}): + await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data)) + + def create_handler_map(self) -> None: + """Parse the ``router_handlers`` of this route and return a mapping of + http- methods and route handlers. + """ + for route_handler in self.route_handlers: + kwargs_model = self.create_handler_kwargs_model(route_handler=route_handler) + for http_method in route_handler.http_methods: + if self.route_handler_map.get(http_method): + raise ImproperlyConfiguredException( + f"Handler already registered for path {self.path!r} and http method {http_method}" + ) + self.route_handler_map[http_method] = (route_handler, kwargs_model) + + async def _get_response_for_request( + self, + scope: Scope, + request: Request[Any, Any, Any], + route_handler: HTTPRouteHandler, + parameter_model: KwargsModel, + ) -> ASGIApp: + """Return a response for the request. + + If caching is enabled and a response exist in the cache, the cached response will be returned. + If caching is enabled and a response does not exist in the cache, the newly created + response will be cached. + + Args: + scope: The Request's scope + request: The Request instance + route_handler: The HTTPRouteHandler instance + parameter_model: The Handler's KwargsModel + + Returns: + An instance of Response or a compatible ASGIApp or a subclass of it + """ + if route_handler.cache and ( + response := await self._get_cached_response(request=request, route_handler=route_handler) + ): + return response + + return await self._call_handler_function( + scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler + ) + + async def _call_handler_function( + self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler + ) -> ASGIApp: + """Call the before request handlers, retrieve any data required for the route handler, and call the route + handler's ``to_response`` method. + + This is wrapped in a try except block - and if an exception is raised, + it tries to pass it to an appropriate exception handler - if defined. + """ + response_data: Any = None + cleanup_group: DependencyCleanupGroup | None = None + + if before_request_handler := route_handler.resolve_before_request(): + response_data = await before_request_handler(request) + + if not response_data: + response_data, cleanup_group = await self._get_response_data( + route_handler=route_handler, parameter_model=parameter_model, request=request + ) + + response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request) + + if cleanup_group: + await cleanup_group.cleanup() + + return response + + @staticmethod + async def _get_response_data( + route_handler: HTTPRouteHandler, parameter_model: KwargsModel, request: Request + ) -> tuple[Any, DependencyCleanupGroup | None]: + """Determine what kwargs are required for the given route handler's ``fn`` and calls it.""" + parsed_kwargs: dict[str, Any] = {} + cleanup_group: DependencyCleanupGroup | None = None + + if parameter_model.has_kwargs and route_handler.signature_model: + kwargs = parameter_model.to_kwargs(connection=request) + + if "data" in kwargs: + try: + data = await kwargs["data"] + except SerializationException as e: + raise ClientException(str(e)) from e + + if data is Empty: + del kwargs["data"] + else: + kwargs["data"] = data + + if "body" in kwargs: + kwargs["body"] = await kwargs["body"] + + if parameter_model.dependency_batches: + cleanup_group = await parameter_model.resolve_dependencies(request, kwargs) + + parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs( + connection=request, **kwargs + ) + + if cleanup_group: + async with cleanup_group: + data = ( + route_handler.fn(**parsed_kwargs) + if route_handler.has_sync_callable + else await route_handler.fn(**parsed_kwargs) + ) + elif route_handler.has_sync_callable: + data = route_handler.fn(**parsed_kwargs) + else: + data = await route_handler.fn(**parsed_kwargs) + + return data, cleanup_group + + @staticmethod + async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler) -> ASGIApp | None: + """Retrieve and un-pickle the cached response, if existing. + + Args: + request: The :class:`Request <litestar.connection.Request>` instance + route_handler: The :class:`~.handlers.HTTPRouteHandler` instance + + Returns: + A cached response instance, if existing. + """ + + cache_config = request.app.response_cache_config + cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request) + store = cache_config.get_store_from_app(request.app) + + if not (cached_response_data := await store.get(key=cache_key)): + return None + + # we use the regular msgspec.msgpack.decode here since we don't need any of + # the added decoders + messages = _decode_msgpack_plain(cached_response_data) + + async def cached_response(scope: Scope, receive: Receive, send: Send) -> None: + ScopeState.from_scope(scope).is_cached = True + for message in messages: + await send(message) + + return cached_response + + def create_options_handler(self, path: str) -> HTTPRouteHandler: + """Args: + path: The route path + + Returns: + An HTTP route handler for OPTIONS requests. + """ + + def options_handler(scope: Scope) -> Response: + """Handler function for OPTIONS requests. + + Args: + scope: The ASGI Scope. + + Returns: + Response + """ + cors_config = scope["app"].cors_config + request_headers = Headers.from_scope(scope=scope) + origin = request_headers.get("origin") + + if cors_config and origin: + pre_flight_method = request_headers.get("Access-Control-Request-Method") + failures = [] + + if not cors_config.is_allow_all_methods and ( + pre_flight_method and pre_flight_method not in cors_config.allow_methods + ): + failures.append("method") + + response_headers = cors_config.preflight_headers.copy() + + if not cors_config.is_origin_allowed(origin): + failures.append("Origin") + elif response_headers.get("Access-Control-Allow-Origin") != "*": + response_headers["Access-Control-Allow-Origin"] = origin + + pre_flight_requested_headers = [ + header.strip() + for header in request_headers.get("Access-Control-Request-Headers", "").split(",") + if header.strip() + ] + + if pre_flight_requested_headers: + if cors_config.is_allow_all_headers: + response_headers["Access-Control-Allow-Headers"] = ", ".join( + sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore + ) + elif any( + header.lower() not in cors_config.allow_headers for header in pre_flight_requested_headers + ): + failures.append("headers") + + return ( + Response( + content=f"Disallowed CORS {', '.join(failures)}", + status_code=HTTP_400_BAD_REQUEST, + media_type=MediaType.TEXT, + ) + if failures + else Response( + content=None, + status_code=HTTP_204_NO_CONTENT, + media_type=MediaType.TEXT, + headers=response_headers, + ) + ) + + return Response( + content=None, + status_code=HTTP_204_NO_CONTENT, + headers={"Allow": ", ".join(sorted(self.methods))}, # pyright: ignore + media_type=MediaType.TEXT, + ) + + return HTTPRouteHandler( + path=path, + http_method=[HttpMethod.OPTIONS], + include_in_schema=False, + sync_to_thread=False, + )(options_handler) + + @staticmethod + async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None: + for v in form_data.values(): + if isinstance(v, UploadFile) and not v.file.closed: + await v.close() diff --git a/venv/lib/python3.11/site-packages/litestar/routes/websocket.py b/venv/lib/python3.11/site-packages/litestar/routes/websocket.py new file mode 100644 index 0000000..ebf4959 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/routes/websocket.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.enums import ScopeType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.routes.base import BaseRoute + +if TYPE_CHECKING: + from litestar._kwargs import KwargsModel + from litestar._kwargs.cleanup import DependencyCleanupGroup + from litestar.connection import WebSocket + from litestar.handlers.websocket_handlers import WebsocketRouteHandler + from litestar.types import Receive, Send, WebSocketScope + + +class WebSocketRoute(BaseRoute): + """A websocket route, handling a single ``WebsocketRouteHandler``""" + + __slots__ = ( + "route_handler", + "handler_parameter_model", + ) + + def __init__( + self, + *, + path: str, + route_handler: WebsocketRouteHandler, + ) -> None: + """Initialize the route. + + Args: + path: The path for the route. + route_handler: An instance of :class:`~.handlers.WebsocketRouteHandler`. + """ + self.route_handler = route_handler + self.handler_parameter_model: KwargsModel | None = None + + super().__init__( + path=path, + scope_type=ScopeType.WEBSOCKET, + handler_names=[route_handler.handler_name], + ) + + async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> None: # type: ignore[override] + """ASGI app that creates a WebSocket from the passed in args, and then awaits the handler function. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + if not self.handler_parameter_model: # pragma: no cover + raise ImproperlyConfiguredException("handler parameter model not defined") + + websocket: WebSocket[Any, Any, Any] = self.route_handler.resolve_websocket_class()( + scope=scope, receive=receive, send=send + ) + + if self.route_handler.resolve_guards(): + await self.route_handler.authorize_connection(connection=websocket) + + parsed_kwargs: dict[str, Any] = {} + cleanup_group: DependencyCleanupGroup | None = None + + if self.handler_parameter_model.has_kwargs and self.route_handler.signature_model: + parsed_kwargs = self.handler_parameter_model.to_kwargs(connection=websocket) + + if self.handler_parameter_model.dependency_batches: + cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs) + + parsed_kwargs = self.route_handler.signature_model.parse_values_from_connection_kwargs( + connection=websocket, **parsed_kwargs + ) + + if cleanup_group: + async with cleanup_group: + await self.route_handler.fn(**parsed_kwargs) + await cleanup_group.cleanup() + else: + await self.route_handler.fn(**parsed_kwargs) diff --git a/venv/lib/python3.11/site-packages/litestar/security/__init__.py b/venv/lib/python3.11/site-packages/litestar/security/__init__.py new file mode 100644 index 0000000..d864d43 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/__init__.py @@ -0,0 +1,3 @@ +from litestar.security.base import AbstractSecurityConfig + +__all__ = ("AbstractSecurityConfig",) diff --git a/venv/lib/python3.11/site-packages/litestar/security/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..419b39e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6133974 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/base.py b/venv/lib/python3.11/site-packages/litestar/security/base.py new file mode 100644 index 0000000..fbe7913 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/base.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from copy import copy +from dataclasses import field +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Sequence, TypeVar, cast + +from litestar import Response +from litestar.utils.sync import ensure_async_callable + +if TYPE_CHECKING: + from litestar.config.app import AppConfig + from litestar.connection import ASGIConnection + from litestar.di import Provide + from litestar.enums import MediaType, OpenAPIMediaType + from litestar.middleware.authentication import AbstractAuthenticationMiddleware + from litestar.middleware.base import DefineMiddleware + from litestar.openapi.spec import Components, SecurityRequirement + from litestar.types import ( + ControllerRouterHandler, + Guard, + Method, + ResponseCookies, + Scopes, + SyncOrAsyncUnion, + TypeEncodersMap, + ) + +__all__ = ("AbstractSecurityConfig",) + +UserType = TypeVar("UserType") +AuthType = TypeVar("AuthType") + + +class AbstractSecurityConfig(ABC, Generic[UserType, AuthType]): + """A base class for Security Configs - this class can be used on the application level + or be manually configured on the router / controller level to provide auth. + """ + + authentication_middleware_class: type[AbstractAuthenticationMiddleware] + """The authentication middleware class to use. + + Must inherit from + :class:`AbstractAuthenticationMiddleware <litestar.middleware.authentication.AbstractAuthenticationMiddleware>` + """ + guards: Iterable[Guard] | None = None + """An iterable of guards to call for requests, providing authorization functionalities.""" + exclude: str | list[str] | None = None + """A pattern or list of patterns to skip in the authentication middleware.""" + exclude_opt_key: str = "exclude_from_auth" + """An identifier to use on routes to disable authentication and authorization checks for a particular route.""" + exclude_http_methods: Sequence[Method] | None = field( + default_factory=lambda: cast("Sequence[Method]", ["OPTIONS", "HEAD"]) + ) + """A sequence of http methods that do not require authentication. Defaults to ['OPTIONS', 'HEAD']""" + scopes: Scopes | None = None + """ASGI scopes processed by the authentication middleware, if ``None``, both ``http`` and ``websocket`` will be + processed.""" + route_handlers: Iterable[ControllerRouterHandler] | None = None + """An optional iterable of route handlers to register.""" + dependencies: dict[str, Provide] | None = None + """An optional dictionary of dependency providers.""" + retrieve_user_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[Any | None]] + """Callable that receives the ``auth`` value from the authentication middleware and returns a ``user`` value. + + Notes: + - User and Auth can be any arbitrary values specified by the security backend. + - The User and Auth values will be set by the middleware as ``scope["user"]`` and ``scope["auth"]`` respectively. + Once provided, they can access via the ``connection.user`` and ``connection.auth`` properties. + - The callable can be sync or async. If it is sync, it will be wrapped to support async. + + """ + type_encoders: TypeEncodersMap | None = None + """A mapping of types to callables that transform them into types supported for serialization.""" + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Handle app init by injecting middleware, guards etc. into the app. This method can be used only on the app + level. + + Args: + app_config: An instance of :class:`AppConfig <.config.app.AppConfig>` + + Returns: + The :class:`AppConfig <.config.app.AppConfig>`. + """ + app_config.middleware.insert(0, self.middleware) + + if app_config.openapi_config: + app_config.openapi_config = copy(app_config.openapi_config) + if isinstance(app_config.openapi_config.components, list): + app_config.openapi_config.components.append(self.openapi_components) + elif app_config.openapi_config.components: + app_config.openapi_config.components = [self.openapi_components, app_config.openapi_config.components] + else: + app_config.openapi_config.components = [self.openapi_components] + + if isinstance(app_config.openapi_config.security, list): + app_config.openapi_config.security.append(self.security_requirement) + else: + app_config.openapi_config.security = [self.security_requirement] + + if self.guards: + app_config.guards.extend(self.guards) + + if self.dependencies: + app_config.dependencies.update(self.dependencies) + + if self.route_handlers: + app_config.route_handlers.extend(self.route_handlers) + + if self.type_encoders is None: + self.type_encoders = app_config.type_encoders + + return app_config + + def create_response( + self, + content: Any | None, + status_code: int, + media_type: MediaType | OpenAPIMediaType | str, + headers: dict[str, Any] | None = None, + cookies: ResponseCookies | None = None, + ) -> Response[Any]: + """Create a response object. + + Handles setting the type encoders mapping on the response. + + Args: + content: A value for the response body that will be rendered into bytes string. + status_code: An HTTP status code. + media_type: A value for the response 'Content-Type' header. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of :class:`Cookie <litestar.datastructures.Cookie>` instances to be set under + the response 'Set-Cookie' header. + + Returns: + A response object. + """ + return Response( + content=content, + status_code=status_code, + media_type=media_type, + headers=headers, + cookies=cookies, + type_encoders=self.type_encoders, + ) + + def __post_init__(self) -> None: + self.retrieve_user_handler = ensure_async_callable(self.retrieve_user_handler) + + @property + @abstractmethod + def openapi_components(self) -> Components: + """Create OpenAPI documentation for the JWT auth schema used. + + Returns: + An :class:`Components <litestar.openapi.spec.components.Components>` instance. + """ + raise NotImplementedError + + @property + @abstractmethod + def security_requirement(self) -> SecurityRequirement: + """Return OpenAPI 3.1. + + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for the auth + backend. + + Returns: + An OpenAPI 3.1 :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` dictionary. + """ + raise NotImplementedError + + @property + @abstractmethod + def middleware(self) -> DefineMiddleware: + """Create an instance of the config's ``authentication_middleware_class`` attribute and any required kwargs, + wrapping it in Litestar's ``DefineMiddleware``. + + Returns: + An instance of :class:`DefineMiddleware <litestar.middleware.base.DefineMiddleware>`. + """ + raise NotImplementedError diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/__init__.py b/venv/lib/python3.11/site-packages/litestar/security/jwt/__init__.py new file mode 100644 index 0000000..4fd88de --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/__init__.py @@ -0,0 +1,23 @@ +from litestar.security.jwt.auth import ( + BaseJWTAuth, + JWTAuth, + JWTCookieAuth, + OAuth2Login, + OAuth2PasswordBearerAuth, +) +from litestar.security.jwt.middleware import ( + JWTAuthenticationMiddleware, + JWTCookieAuthenticationMiddleware, +) +from litestar.security.jwt.token import Token + +__all__ = ( + "BaseJWTAuth", + "JWTAuth", + "JWTAuthenticationMiddleware", + "JWTCookieAuth", + "JWTCookieAuthenticationMiddleware", + "OAuth2Login", + "OAuth2PasswordBearerAuth", + "Token", +) diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f04d57f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/auth.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/auth.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..cec42c0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/auth.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8d5603e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/token.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/token.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b4f8c45 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/__pycache__/token.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/auth.py b/venv/lib/python3.11/site-packages/litestar/security/jwt/auth.py new file mode 100644 index 0000000..2a0f094 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/auth.py @@ -0,0 +1,691 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, Sequence, TypeVar, cast + +from litestar.datastructures import Cookie +from litestar.enums import MediaType +from litestar.middleware import DefineMiddleware +from litestar.openapi.spec import Components, OAuthFlow, OAuthFlows, SecurityRequirement, SecurityScheme +from litestar.security.base import AbstractSecurityConfig +from litestar.security.jwt.middleware import JWTAuthenticationMiddleware, JWTCookieAuthenticationMiddleware +from litestar.security.jwt.token import Token +from litestar.status_codes import HTTP_201_CREATED +from litestar.types import ControllerRouterHandler, Empty, Guard, Method, Scopes, SyncOrAsyncUnion, TypeEncodersMap + +__all__ = ("BaseJWTAuth", "JWTAuth", "JWTCookieAuth", "OAuth2Login", "OAuth2PasswordBearerAuth") + + +if TYPE_CHECKING: + from litestar import Response + from litestar.connection import ASGIConnection + from litestar.di import Provide + + +UserType = TypeVar("UserType") + + +class BaseJWTAuth(Generic[UserType], AbstractSecurityConfig[UserType, Token]): + """Base class for JWT Auth backends""" + + token_secret: str + """Key with which to generate the token hash. + + Notes: + - This value should be kept as a secret and the standard practice is to inject it into the environment. + """ + retrieve_user_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[Any | None]] + """Callable that receives the ``auth`` value from the authentication middleware and returns a ``user`` value. + + Notes: + - User and Auth can be any arbitrary values specified by the security backend. + - The User and Auth values will be set by the middleware as ``scope["user"]`` and ``scope["auth"]`` respectively. + Once provided, they can access via the ``connection.user`` and ``connection.auth`` properties. + - The callable can be sync or async. If it is sync, it will be wrapped to support async. + + """ + algorithm: str + """Algorithm to use for JWT hashing.""" + auth_header: str + """Request header key from which to retrieve the token. + + E.g. ``Authorization`` or ``X-Api-Key``. + """ + default_token_expiration: timedelta + """The default value for token expiration.""" + openapi_security_scheme_name: str + """The value to use for the OpenAPI security scheme and security requirements.""" + description: str + """Description for the OpenAPI security scheme.""" + authentication_middleware_class: type[JWTAuthenticationMiddleware] # pyright: ignore + """The authentication middleware class to use. + + Must inherit from :class:`JWTAuthenticationMiddleware` + """ + + @property + def openapi_components(self) -> Components: + """Create OpenAPI documentation for the JWT auth schema used. + + Returns: + An :class:`Components <litestar.openapi.spec.components.Components>` instance. + """ + return Components( + security_schemes={ + self.openapi_security_scheme_name: SecurityScheme( + type="http", + scheme="Bearer", + name=self.auth_header, + bearer_format="JWT", + description=self.description, + ) + } + ) + + @property + def security_requirement(self) -> SecurityRequirement: + """Return OpenAPI 3.1. + + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` + + Returns: + An OpenAPI 3.1 + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` + dictionary. + """ + return {self.openapi_security_scheme_name: []} + + @property + def middleware(self) -> DefineMiddleware: + """Create :class:`JWTAuthenticationMiddleware` wrapped in + :class:`DefineMiddleware <.middleware.base.DefineMiddleware>`. + + Returns: + An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>`. + """ + return DefineMiddleware( + self.authentication_middleware_class, + algorithm=self.algorithm, + auth_header=self.auth_header, + exclude=self.exclude, + exclude_opt_key=self.exclude_opt_key, + exclude_http_methods=self.exclude_http_methods, + retrieve_user_handler=self.retrieve_user_handler, + scopes=self.scopes, + token_secret=self.token_secret, + ) + + def login( + self, + identifier: str, + *, + response_body: Any = Empty, + response_media_type: str | MediaType = MediaType.JSON, + response_status_code: int = HTTP_201_CREATED, + token_expiration: timedelta | None = None, + token_issuer: str | None = None, + token_audience: str | None = None, + token_unique_jwt_id: str | None = None, + token_extras: dict[str, Any] | None = None, + send_token_as_response_body: bool = False, + ) -> Response[Any]: + """Create a response with a JWT header. + + Args: + identifier: Unique identifier of the token subject. Usually this is a user ID or equivalent kind of value. + response_body: An optional response body to send. + response_media_type: An optional ``Content-Type``. Defaults to ``application/json``. + response_status_code: An optional status code for the response. Defaults to ``201``. + token_expiration: An optional timedelta for the token expiration. + token_issuer: An optional value of the token ``iss`` field. + token_audience: An optional value for the token ``aud`` field. + token_unique_jwt_id: An optional value for the token ``jti`` field. + token_extras: An optional dictionary to include in the token ``extras`` field. + send_token_as_response_body: If ``True`` the response will be a dict including the token: ``{ "token": <token> }`` + will be returned as the response body. Note: if a response body is passed this setting will be ignored. + + Returns: + A :class:`Response <.response.Response>` instance. + """ + encoded_token = self.create_token( + identifier=identifier, + token_expiration=token_expiration, + token_issuer=token_issuer, + token_audience=token_audience, + token_unique_jwt_id=token_unique_jwt_id, + token_extras=token_extras, + ) + + if response_body is not Empty: + body = response_body + elif send_token_as_response_body: + body = {"token": encoded_token} + else: + body = None + + return self.create_response( + content=body, + headers={self.auth_header: self.format_auth_header(encoded_token)}, + media_type=response_media_type, + status_code=response_status_code, + ) + + def create_token( + self, + identifier: str, + token_expiration: timedelta | None = None, + token_issuer: str | None = None, + token_audience: str | None = None, + token_unique_jwt_id: str | None = None, + token_extras: dict | None = None, + ) -> str: + """Create a Token instance from the passed in parameters, persists and returns it. + + Args: + identifier: Unique identifier of the token subject. Usually this is a user ID or equivalent kind of value. + token_expiration: An optional timedelta for the token expiration. + token_issuer: An optional value of the token ``iss`` field. + token_audience: An optional value for the token ``aud`` field. + token_unique_jwt_id: An optional value for the token ``jti`` field. + token_extras: An optional dictionary to include in the token ``extras`` field. + + Returns: + The created token. + """ + token = Token( + sub=identifier, + exp=(datetime.now(timezone.utc) + (token_expiration or self.default_token_expiration)), + iss=token_issuer, + aud=token_audience, + jti=token_unique_jwt_id, + extras=token_extras or {}, + ) + return token.encode(secret=self.token_secret, algorithm=self.algorithm) + + def format_auth_header(self, encoded_token: str) -> str: + """Format a token according to the specified OpenAPI scheme. + + Args: + encoded_token: An encoded JWT token + + Returns: + The encoded token formatted for the HTTP headers + """ + security = self.openapi_components.security_schemes.get(self.openapi_security_scheme_name, None) # type: ignore[union-attr] + return f"{security.scheme} {encoded_token}" if isinstance(security, SecurityScheme) else encoded_token + + +@dataclass +class JWTAuth(Generic[UserType], BaseJWTAuth[UserType]): + """JWT Authentication Configuration. + + This class is the main entry point to the library, and it includes methods to create the middleware, provide login + functionality, and create OpenAPI documentation. + """ + + token_secret: str + """Key with which to generate the token hash. + + Notes: + - This value should be kept as a secret and the standard practice is to inject it into the environment. + """ + retrieve_user_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[Any | None]] + """Callable that receives the ``auth`` value from the authentication middleware and returns a ``user`` value. + + Notes: + - User and Auth can be any arbitrary values specified by the security backend. + - The User and Auth values will be set by the middleware as ``scope["user"]`` and ``scope["auth"]`` respectively. + Once provided, they can access via the ``connection.user`` and ``connection.auth`` properties. + - The callable can be sync or async. If it is sync, it will be wrapped to support async. + + """ + guards: Iterable[Guard] | None = field(default=None) + """An iterable of guards to call for requests, providing authorization functionalities.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the authentication middleware.""" + exclude_opt_key: str = field(default="exclude_from_auth") + """An identifier to use on routes to disable authentication and authorization checks for a particular route.""" + exclude_http_methods: Sequence[Method] | None = field( + default_factory=lambda: cast("Sequence[Method]", ["OPTIONS", "HEAD"]) + ) + """A sequence of http methods that do not require authentication. Defaults to ['OPTIONS', 'HEAD']""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the authentication middleware, if ``None``, both ``http`` and ``websocket`` will be + processed.""" + route_handlers: Iterable[ControllerRouterHandler] | None = field(default=None) + """An optional iterable of route handlers to register.""" + dependencies: dict[str, Provide] | None = field(default=None) + """An optional dictionary of dependency providers.""" + + type_encoders: TypeEncodersMap | None = field(default=None) + """A mapping of types to callables that transform them into types supported for serialization.""" + + algorithm: str = field(default="HS256") + """Algorithm to use for JWT hashing.""" + auth_header: str = field(default="Authorization") + """Request header key from which to retrieve the token. + + E.g. ``Authorization`` or ``X-Api-Key``. + """ + default_token_expiration: timedelta = field(default_factory=lambda: timedelta(days=1)) + """The default value for token expiration.""" + openapi_security_scheme_name: str = field(default="BearerToken") + """The value to use for the OpenAPI security scheme and security requirements.""" + description: str = field(default="JWT api-key authentication and authorization.") + """Description for the OpenAPI security scheme.""" + authentication_middleware_class: type[JWTAuthenticationMiddleware] = field(default=JWTAuthenticationMiddleware) + """The authentication middleware class to use. + + Must inherit from :class:`JWTAuthenticationMiddleware` + """ + + +@dataclass +class JWTCookieAuth(Generic[UserType], BaseJWTAuth[UserType]): + """JWT Cookie Authentication Configuration. + + This class is an alternate entry point to the library, and it includes all the functionality of the :class:`JWTAuth` + class and adds support for passing JWT tokens ``HttpOnly`` cookies. + """ + + token_secret: str + """Key with which to generate the token hash. + + Notes: + - This value should be kept as a secret and the standard practice is to inject it into the environment. + """ + retrieve_user_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[Any | None]] + """Callable that receives the ``auth`` value from the authentication middleware and returns a ``user`` value. + + Notes: + - User and Auth can be any arbitrary values specified by the security backend. + - The User and Auth values will be set by the middleware as ``scope["user"]`` and ``scope["auth"]`` respectively. + Once provided, they can access via the ``connection.user`` and ``connection.auth`` properties. + - The callable can be sync or async. If it is sync, it will be wrapped to support async. + + """ + guards: Iterable[Guard] | None = field(default=None) + """An iterable of guards to call for requests, providing authorization functionalities.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the authentication middleware.""" + exclude_opt_key: str = field(default="exclude_from_auth") + """An identifier to use on routes to disable authentication and authorization checks for a particular route.""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the authentication middleware, if ``None``, both ``http`` and ``websocket`` will be + processed.""" + exclude_http_methods: Sequence[Method] | None = field( + default_factory=lambda: cast("Sequence[Method]", ["OPTIONS", "HEAD"]) + ) + """A sequence of http methods that do not require authentication. Defaults to ['OPTIONS', 'HEAD']""" + route_handlers: Iterable[ControllerRouterHandler] | None = field(default=None) + """An optional iterable of route handlers to register.""" + dependencies: dict[str, Provide] | None = field(default=None) + """An optional dictionary of dependency providers.""" + + type_encoders: TypeEncodersMap | None = field(default=None) + """A mapping of types to callables that transform them into types supported for serialization.""" + + algorithm: str = field(default="HS256") + """Algorithm to use for JWT hashing.""" + auth_header: str = field(default="Authorization") + """Request header key from which to retrieve the token. + + E.g. ``Authorization`` or ``X-Api-Key``. + """ + default_token_expiration: timedelta = field(default_factory=lambda: timedelta(days=1)) + """The default value for token expiration.""" + openapi_security_scheme_name: str = field(default="BearerToken") + """The value to use for the OpenAPI security scheme and security requirements.""" + key: str = field(default="token") + """Key for the cookie.""" + path: str = field(default="/") + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``/``. + """ + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool | None = field(default=None) + """Https is required for the cookie.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. Defaults to ``lax``. """ + description: str = field(default="JWT cookie-based authentication and authorization.") + """Description for the OpenAPI security scheme.""" + authentication_middleware_class: type[JWTCookieAuthenticationMiddleware] = field( # pyright: ignore + default=JWTCookieAuthenticationMiddleware + ) + """The authentication middleware class to use. Must inherit from :class:`JWTCookieAuthenticationMiddleware` + """ + + @property + def openapi_components(self) -> Components: + """Create OpenAPI documentation for the JWT Cookie auth scheme. + + Returns: + A :class:`Components <litestar.openapi.spec.components.Components>` instance. + """ + return Components( + security_schemes={ + self.openapi_security_scheme_name: SecurityScheme( + type="http", + scheme="Bearer", + name=self.key, + security_scheme_in="cookie", + bearer_format="JWT", + description=self.description, + ) + } + ) + + @property + def middleware(self) -> DefineMiddleware: + """Create :class:`JWTCookieAuthenticationMiddleware` wrapped in + :class:`DefineMiddleware <.middleware.base.DefineMiddleware>`. + + Returns: + An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>`. + """ + return DefineMiddleware( + self.authentication_middleware_class, + algorithm=self.algorithm, + auth_cookie_key=self.key, + auth_header=self.auth_header, + exclude=self.exclude, + exclude_opt_key=self.exclude_opt_key, + exclude_http_methods=self.exclude_http_methods, + retrieve_user_handler=self.retrieve_user_handler, + scopes=self.scopes, + token_secret=self.token_secret, + ) + + def login( + self, + identifier: str, + *, + response_body: Any = Empty, + response_media_type: str | MediaType = MediaType.JSON, + response_status_code: int = HTTP_201_CREATED, + token_expiration: timedelta | None = None, + token_issuer: str | None = None, + token_audience: str | None = None, + token_unique_jwt_id: str | None = None, + token_extras: dict[str, Any] | None = None, + send_token_as_response_body: bool = False, + ) -> Response[Any]: + """Create a response with a JWT header. + + Args: + identifier: Unique identifier of the token subject. Usually this is a user ID or equivalent kind of value. + response_body: An optional response body to send. + response_media_type: An optional 'Content-Type'. Defaults to 'application/json'. + response_status_code: An optional status code for the response. Defaults to '201 Created'. + token_expiration: An optional timedelta for the token expiration. + token_issuer: An optional value of the token ``iss`` field. + token_audience: An optional value for the token ``aud`` field. + token_unique_jwt_id: An optional value for the token ``jti`` field. + token_extras: An optional dictionary to include in the token ``extras`` field. + send_token_as_response_body: If ``True`` the response will be a dict including the token: ``{ "token": <token> }`` + will be returned as the response body. Note: if a response body is passed this setting will be ignored. + + Returns: + A :class:`Response <.response.Response>` instance. + """ + + encoded_token = self.create_token( + identifier=identifier, + token_expiration=token_expiration, + token_issuer=token_issuer, + token_audience=token_audience, + token_unique_jwt_id=token_unique_jwt_id, + token_extras=token_extras, + ) + cookie = Cookie( + key=self.key, + path=self.path, + httponly=True, + value=self.format_auth_header(encoded_token), + max_age=int((token_expiration or self.default_token_expiration).total_seconds()), + secure=self.secure, + samesite=self.samesite, + domain=self.domain, + ) + + if response_body is not Empty: + body = response_body + elif send_token_as_response_body: + body = {"token": encoded_token} + else: + body = None + + return self.create_response( + content=body, + headers={self.auth_header: self.format_auth_header(encoded_token)}, + cookies=[cookie], + media_type=response_media_type, + status_code=response_status_code, + ) + + +@dataclass +class OAuth2Login: + """OAuth2 Login DTO""" + + access_token: str + """Valid JWT access token""" + token_type: str + """Type of the OAuth token used""" + refresh_token: str | None = field(default=None) + """Optional valid refresh token JWT""" + expires_in: int | None = field(default=None) + """Expiration time of the token in seconds. """ + + +@dataclass +class OAuth2PasswordBearerAuth(Generic[UserType], BaseJWTAuth[UserType]): + """OAUTH2 Schema for Password Bearer Authentication. + + This class implements an OAUTH2 authentication flow entry point to the library, and it includes all the + functionality of the :class:`JWTAuth` class and adds support for passing JWT tokens ``HttpOnly`` cookies. + + ``token_url`` is the only additional argument that is required, and it should point at your login route + """ + + token_secret: str + """Key with which to generate the token hash. + + Notes: + - This value should be kept as a secret and the standard practice is to inject it into the environment. + """ + token_url: str + """The URL for retrieving a new token.""" + retrieve_user_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[Any | None]] + """Callable that receives the ``auth`` value from the authentication middleware and returns a ``user`` value. + + Notes: + - User and Auth can be any arbitrary values specified by the security backend. + - The User and Auth values will be set by the middleware as ``scope["user"]`` and ``scope["auth"]`` respectively. + Once provided, they can access via the ``connection.user`` and ``connection.auth`` properties. + - The callable can be sync or async. If it is sync, it will be wrapped to support async. + + """ + guards: Iterable[Guard] | None = field(default=None) + """An iterable of guards to call for requests, providing authorization functionalities.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the authentication middleware.""" + exclude_opt_key: str = field(default="exclude_from_auth") + """An identifier to use on routes to disable authentication and authorization checks for a particular route.""" + exclude_http_methods: Sequence[Method] | None = field( + default_factory=lambda: cast("Sequence[Method]", ["OPTIONS", "HEAD"]) + ) + """A sequence of http methods that do not require authentication. Defaults to ['OPTIONS', 'HEAD']""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the authentication middleware, if ``None``, both ``http`` and ``websocket`` will be + processed.""" + route_handlers: Iterable[ControllerRouterHandler] | None = field(default=None) + """An optional iterable of route handlers to register.""" + dependencies: dict[str, Provide] | None = field(default=None) + """An optional dictionary of dependency providers.""" + type_encoders: TypeEncodersMap | None = field(default=None) + """A mapping of types to callables that transform them into types supported for serialization.""" + algorithm: str = field(default="HS256") + """Algorithm to use for JWT hashing.""" + auth_header: str = field(default="Authorization") + """Request header key from which to retrieve the token. + + E.g. ``Authorization`` or 'X-Api-Key'. + """ + default_token_expiration: timedelta = field(default_factory=lambda: timedelta(days=1)) + """The default value for token expiration.""" + openapi_security_scheme_name: str = field(default="BearerToken") + """The value to use for the OpenAPI security scheme and security requirements.""" + oauth_scopes: dict[str, str] | None = field(default=None) + """Oauth Scopes available for the token.""" + key: str = field(default="token") + """Key for the cookie.""" + path: str = field(default="/") + """Path fragment that must exist in the request url for the cookie to be valid. + + Defaults to ``/``. + """ + domain: str | None = field(default=None) + """Domain for which the cookie is valid.""" + secure: bool | None = field(default=None) + """Https is required for the cookie.""" + samesite: Literal["lax", "strict", "none"] = field(default="lax") + """Controls whether or not a cookie is sent with cross-site requests. Defaults to ``lax``. """ + description: str = field(default="OAUTH2 password bearer authentication and authorization.") + """Description for the OpenAPI security scheme.""" + authentication_middleware_class: type[JWTCookieAuthenticationMiddleware] = field( # pyright: ignore + default=JWTCookieAuthenticationMiddleware + ) + """The authentication middleware class to use. + + Must inherit from :class:`JWTCookieAuthenticationMiddleware` + """ + + @property + def middleware(self) -> DefineMiddleware: + """Create ``JWTCookieAuthenticationMiddleware`` wrapped in + :class:`DefineMiddleware <.middleware.base.DefineMiddleware>`. + + Returns: + An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>`. + """ + return DefineMiddleware( + self.authentication_middleware_class, + algorithm=self.algorithm, + auth_cookie_key=self.key, + auth_header=self.auth_header, + exclude=self.exclude, + exclude_opt_key=self.exclude_opt_key, + exclude_http_methods=self.exclude_http_methods, + retrieve_user_handler=self.retrieve_user_handler, + scopes=self.scopes, + token_secret=self.token_secret, + ) + + @property + def oauth_flow(self) -> OAuthFlow: + """Create an OpenAPI OAuth2 flow for the password bearer authentication scheme. + + Returns: + An :class:`OAuthFlow <litestar.openapi.spec.oauth_flow.OAuthFlow>` instance. + """ + return OAuthFlow( + token_url=self.token_url, + scopes=self.oauth_scopes, + ) + + @property + def openapi_components(self) -> Components: + """Create OpenAPI documentation for the OAUTH2 Password bearer auth scheme. + + Returns: + An :class:`Components <litestar.openapi.spec.components.Components>` instance. + """ + return Components( + security_schemes={ + self.openapi_security_scheme_name: SecurityScheme( + type="oauth2", + scheme="Bearer", + name=self.auth_header, + security_scheme_in="header", + flows=OAuthFlows(password=self.oauth_flow), # pyright: ignore[reportGeneralTypeIssues] + bearer_format="JWT", + description=self.description, + ) + } + ) + + def login( + self, + identifier: str, + *, + response_body: Any = Empty, + response_media_type: str | MediaType = MediaType.JSON, + response_status_code: int = HTTP_201_CREATED, + token_expiration: timedelta | None = None, + token_issuer: str | None = None, + token_audience: str | None = None, + token_unique_jwt_id: str | None = None, + token_extras: dict[str, Any] | None = None, + send_token_as_response_body: bool = True, + ) -> Response[Any]: + """Create a response with a JWT header. + + Args: + identifier: Unique identifier of the token subject. Usually this is a user ID or equivalent kind of value. + response_body: An optional response body to send. + response_media_type: An optional ``Content-Type``. Defaults to ``application/json``. + response_status_code: An optional status code for the response. Defaults to ``201``. + token_expiration: An optional timedelta for the token expiration. + token_issuer: An optional value of the token ``iss`` field. + token_audience: An optional value for the token ``aud`` field. + token_unique_jwt_id: An optional value for the token ``jti`` field. + token_extras: An optional dictionary to include in the token ``extras`` field. + send_token_as_response_body: If ``True`` the response will be an oAuth2 token response dict. + Note: if a response body is passed this setting will be ignored. + + Returns: + A :class:`Response <.response.Response>` instance. + """ + encoded_token = self.create_token( + identifier=identifier, + token_expiration=token_expiration, + token_issuer=token_issuer, + token_audience=token_audience, + token_unique_jwt_id=token_unique_jwt_id, + token_extras=token_extras, + ) + expires_in = int((token_expiration or self.default_token_expiration).total_seconds()) + cookie = Cookie( + key=self.key, + path=self.path, + httponly=True, + value=self.format_auth_header(encoded_token), + max_age=expires_in, + secure=self.secure, + samesite=self.samesite, + domain=self.domain, + ) + + if response_body is not Empty: + body = response_body + elif send_token_as_response_body: + token_dto = OAuth2Login( + access_token=encoded_token, + expires_in=expires_in, + token_type="bearer", # noqa: S106 + ) + body = asdict(token_dto) + else: + body = None + + return self.create_response( + content=body, + headers={self.auth_header: self.format_auth_header(encoded_token)}, + cookies=[cookie], + media_type=response_media_type, + status_code=response_status_code, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/middleware.py b/venv/lib/python3.11/site-packages/litestar/security/jwt/middleware.py new file mode 100644 index 0000000..84326da --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/middleware.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Awaitable, Callable, Sequence + +from litestar.exceptions import NotAuthorizedException +from litestar.middleware.authentication import ( + AbstractAuthenticationMiddleware, + AuthenticationResult, +) +from litestar.security.jwt.token import Token + +__all__ = ("JWTAuthenticationMiddleware", "JWTCookieAuthenticationMiddleware") + + +if TYPE_CHECKING: + from typing import Any + + from litestar.connection import ASGIConnection + from litestar.types import ASGIApp, Method, Scopes + + +class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware): + """JWT Authentication middleware. + + This class provides JWT authentication functionalities. + """ + + __slots__ = ( + "algorithm", + "auth_header", + "retrieve_user_handler", + "token_secret", + ) + + def __init__( + self, + algorithm: str, + app: ASGIApp, + auth_header: str, + exclude: str | list[str] | None, + exclude_http_methods: Sequence[Method] | None, + exclude_opt_key: str, + retrieve_user_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]], + scopes: Scopes, + token_secret: str, + ) -> None: + """Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user + from persistence using the provided function. + + Args: + algorithm: JWT hashing algorithm to use. + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + auth_header: Request header key from which to retrieve the token. E.g. ``Authorization`` or ``X-Api-Key``. + exclude: A pattern or list of patterns to skip. + exclude_opt_key: An identifier to use on routes to disable authentication for a particular route. + exclude_http_methods: A sequence of http methods that do not require authentication. + retrieve_user_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a user, + which can be any arbitrary value. + scopes: ASGI scopes processed by the authentication middleware. + token_secret: Secret for decoding the JWT token. This value should be equivalent to the secret used to + encode it. + """ + super().__init__( + app=app, + exclude=exclude, + exclude_from_auth_key=exclude_opt_key, + exclude_http_methods=exclude_http_methods, + scopes=scopes, + ) + self.algorithm = algorithm + self.auth_header = auth_header + self.retrieve_user_handler = retrieve_user_handler + self.token_secret = token_secret + + async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult: + """Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the + token from the DB. + + Args: + connection: An Litestar HTTPConnection instance. + + Returns: + AuthenticationResult + + Raises: + NotAuthorizedException: If token is invalid or user is not found. + """ + auth_header = connection.headers.get(self.auth_header) + if not auth_header: + raise NotAuthorizedException("No JWT token found in request header") + encoded_token = auth_header.partition(" ")[-1] + return await self.authenticate_token(encoded_token=encoded_token, connection=connection) + + async def authenticate_token( + self, encoded_token: str, connection: ASGIConnection[Any, Any, Any, Any] + ) -> AuthenticationResult: + """Given an encoded JWT token, parse, validate and look up sub within token. + + Args: + encoded_token: Encoded JWT token. + connection: An ASGI connection instance. + + Raises: + NotAuthorizedException: If token is invalid or user is not found. + + Returns: + AuthenticationResult + """ + token = Token.decode( + encoded_token=encoded_token, + secret=self.token_secret, + algorithm=self.algorithm, + ) + + user = await self.retrieve_user_handler(token, connection) + + if not user: + raise NotAuthorizedException() + + return AuthenticationResult(user=user, auth=token) + + +class JWTCookieAuthenticationMiddleware(JWTAuthenticationMiddleware): + """Cookie based JWT authentication middleware.""" + + __slots__ = ("auth_cookie_key",) + + def __init__( + self, + algorithm: str, + app: ASGIApp, + auth_cookie_key: str, + auth_header: str, + exclude: str | list[str] | None, + exclude_opt_key: str, + exclude_http_methods: Sequence[Method] | None, + retrieve_user_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]], + scopes: Scopes, + token_secret: str, + ) -> None: + """Check incoming requests for an encoded token in the auth header or cookie name specified, and if present + retrieves the user from persistence using the provided function. + + Args: + algorithm: JWT hashing algorithm to use. + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + auth_cookie_key: Cookie name from which to retrieve the token. E.g. ``token`` or ``accessToken``. + auth_header: Request header key from which to retrieve the token. E.g. ``Authorization`` or ``X-Api-Key``. + exclude: A pattern or list of patterns to skip. + exclude_opt_key: An identifier to use on routes to disable authentication for a particular route. + exclude_http_methods: A sequence of http methods that do not require authentication. + retrieve_user_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a user, + which can be any arbitrary value. + scopes: ASGI scopes processed by the authentication middleware. + token_secret: Secret for decoding the JWT token. This value should be equivalent to the secret used to + encode it. + """ + super().__init__( + algorithm=algorithm, + app=app, + auth_header=auth_header, + exclude=exclude, + exclude_http_methods=exclude_http_methods, + exclude_opt_key=exclude_opt_key, + retrieve_user_handler=retrieve_user_handler, + scopes=scopes, + token_secret=token_secret, + ) + self.auth_cookie_key = auth_cookie_key + + async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult: + """Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the + token from the DB. + + Args: + connection: An Litestar HTTPConnection instance. + + Raises: + NotAuthorizedException: If token is invalid or user is not found. + + Returns: + AuthenticationResult + """ + auth_header = connection.headers.get(self.auth_header) or connection.cookies.get(self.auth_cookie_key) + if not auth_header: + raise NotAuthorizedException("No JWT token found in request header or cookies") + encoded_token = auth_header.partition(" ")[-1] + return await self.authenticate_token(encoded_token=encoded_token, connection=connection) diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/token.py b/venv/lib/python3.11/site-packages/litestar/security/jwt/token.py new file mode 100644 index 0000000..279111a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/token.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import dataclasses +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from jose import JWSError, JWTError, jwt + +from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException + +if TYPE_CHECKING: + from typing_extensions import Self + + +__all__ = ("Token",) + + +def _normalize_datetime(value: datetime) -> datetime: + """Convert the given value into UTC and strip microseconds. + + Args: + value: A datetime instance + + Returns: + A datetime instance + """ + if value.tzinfo is not None: + value.astimezone(timezone.utc) + + return value.replace(microsecond=0) + + +@dataclass +class Token: + """JWT Token DTO.""" + + exp: datetime + """Expiration - datetime for token expiration.""" + sub: str + """Subject - usually a unique identifier of the user or equivalent entity.""" + iat: datetime = field(default_factory=lambda: _normalize_datetime(datetime.now(timezone.utc))) + """Issued at - should always be current now.""" + iss: str | None = field(default=None) + """Issuer - optional unique identifier for the issuer.""" + aud: str | None = field(default=None) + """Audience - intended audience.""" + jti: str | None = field(default=None) + """JWT ID - a unique identifier of the JWT between different issuers.""" + extras: dict[str, Any] = field(default_factory=dict) + """Extra fields that were found on the JWT token.""" + + def __post_init__(self) -> None: + if len(self.sub) < 1: + raise ImproperlyConfiguredException("sub must be a string with a length greater than 0") + + if isinstance(self.exp, datetime) and ( + (exp := _normalize_datetime(self.exp)).timestamp() + >= _normalize_datetime(datetime.now(timezone.utc)).timestamp() + ): + self.exp = exp + else: + raise ImproperlyConfiguredException("exp value must be a datetime in the future") + + if isinstance(self.iat, datetime) and ( + (iat := _normalize_datetime(self.iat)).timestamp() + <= _normalize_datetime(datetime.now(timezone.utc)).timestamp() + ): + self.iat = iat + else: + raise ImproperlyConfiguredException("iat must be a current or past time") + + @classmethod + def decode(cls, encoded_token: str, secret: str | dict[str, str], algorithm: str) -> Self: + """Decode a passed in token string and returns a Token instance. + + Args: + encoded_token: A base64 string containing an encoded JWT. + secret: The secret with which the JWT is encoded. It may optionally be an individual JWK or JWS set dict + algorithm: The algorithm used to encode the JWT. + + Returns: + A decoded Token instance. + + Raises: + NotAuthorizedException: If the token is invalid. + """ + try: + payload = jwt.decode(token=encoded_token, key=secret, algorithms=[algorithm], options={"verify_aud": False}) + exp = datetime.fromtimestamp(payload.pop("exp"), tz=timezone.utc) + iat = datetime.fromtimestamp(payload.pop("iat"), tz=timezone.utc) + field_names = {f.name for f in dataclasses.fields(Token)} + extra_fields = payload.keys() - field_names + extras = payload.pop("extras", {}) + for key in extra_fields: + extras[key] = payload.pop(key) + return cls(exp=exp, iat=iat, **payload, extras=extras) + except (KeyError, JWTError, ImproperlyConfiguredException) as e: + raise NotAuthorizedException("Invalid token") from e + + def encode(self, secret: str, algorithm: str) -> str: + """Encode the token instance into a string. + + Args: + secret: The secret with which the JWT is encoded. + algorithm: The algorithm used to encode the JWT. + + Returns: + An encoded token string. + + Raises: + ImproperlyConfiguredException: If encoding fails. + """ + try: + return jwt.encode( + claims={k: v for k, v in asdict(self).items() if v is not None}, key=secret, algorithm=algorithm + ) + except (JWTError, JWSError) as e: + raise ImproperlyConfiguredException("Failed to encode token") from e diff --git a/venv/lib/python3.11/site-packages/litestar/security/session_auth/__init__.py b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__init__.py new file mode 100644 index 0000000..7c83991 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__init__.py @@ -0,0 +1,4 @@ +from litestar.security.session_auth.auth import SessionAuth +from litestar.security.session_auth.middleware import SessionAuthMiddleware + +__all__ = ("SessionAuth", "SessionAuthMiddleware") diff --git a/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..95bf5c1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/auth.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/auth.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8d4aa6c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/auth.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/middleware.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/middleware.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..27e4213 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/session_auth/__pycache__/middleware.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/security/session_auth/auth.py b/venv/lib/python3.11/site-packages/litestar/security/session_auth/auth.py new file mode 100644 index 0000000..7a5c542 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/session_auth/auth.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Iterable, Sequence, cast + +from litestar.middleware.base import DefineMiddleware +from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackendT +from litestar.openapi.spec import Components, SecurityRequirement, SecurityScheme +from litestar.security.base import AbstractSecurityConfig, UserType +from litestar.security.session_auth.middleware import MiddlewareWrapper, SessionAuthMiddleware + +__all__ = ("SessionAuth",) + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.di import Provide + from litestar.types import ControllerRouterHandler, Guard, Method, Scopes, SyncOrAsyncUnion, TypeEncodersMap + + +@dataclass +class SessionAuth(Generic[UserType, BaseSessionBackendT], AbstractSecurityConfig[UserType, Dict[str, Any]]): + """Session Based Security Backend.""" + + session_backend_config: BaseBackendConfig[BaseSessionBackendT] # pyright: ignore + """A session backend config.""" + retrieve_user_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[Any | None]] + """Callable that receives the ``auth`` value from the authentication middleware and returns a ``user`` value. + + Notes: + - User and Auth can be any arbitrary values specified by the security backend. + - The User and Auth values will be set by the middleware as ``scope["user"]`` and ``scope["auth"]`` respectively. + Once provided, they can access via the ``connection.user`` and ``connection.auth`` properties. + - The callable can be sync or async. If it is sync, it will be wrapped to support async. + + """ + + authentication_middleware_class: type[SessionAuthMiddleware] = field(default=SessionAuthMiddleware) # pyright: ignore + """The authentication middleware class to use. + + Must inherit from :class:`SessionAuthMiddleware <litestar.security.session_auth.middleware.SessionAuthMiddleware>` + """ + + guards: Iterable[Guard] | None = field(default=None) + """An iterable of guards to call for requests, providing authorization functionalities.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns to skip in the authentication middleware.""" + exclude_opt_key: str = field(default="exclude_from_auth") + """An identifier to use on routes to disable authentication and authorization checks for a particular route.""" + exclude_http_methods: Sequence[Method] | None = field( + default_factory=lambda: cast("Sequence[Method]", ["OPTIONS", "HEAD"]) + ) + """A sequence of http methods that do not require authentication. Defaults to ['OPTIONS', 'HEAD']""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the authentication middleware, if ``None``, both ``http`` and ``websocket`` will be + processed.""" + route_handlers: Iterable[ControllerRouterHandler] | None = field(default=None) + """An optional iterable of route handlers to register.""" + dependencies: dict[str, Provide] | None = field(default=None) + """An optional dictionary of dependency providers.""" + + type_encoders: TypeEncodersMap | None = field(default=None) + """A mapping of types to callables that transform them into types supported for serialization.""" + + @property + def middleware(self) -> DefineMiddleware: + """Use this property to insert the config into a middleware list on one of the application layers. + + Examples: + .. code-block:: python + + from typing import Any + from os import urandom + + from litestar import Litestar, Request, get + from litestar_session import SessionAuth + + + async def retrieve_user_from_session(session: dict[str, Any]) -> Any: + # implement logic here to retrieve a ``user`` datum given the session dictionary + ... + + + session_auth_config = SessionAuth( + secret=urandom(16), retrieve_user_handler=retrieve_user_from_session + ) + + + @get("/") + def my_handler(request: Request) -> None: ... + + + app = Litestar(route_handlers=[my_handler], middleware=[session_auth_config.middleware]) + + + Returns: + An instance of DefineMiddleware including ``self`` as the config kwarg value. + """ + return DefineMiddleware(MiddlewareWrapper, config=self) + + @property + def session_backend(self) -> BaseSessionBackendT: + """Create a session backend. + + Returns: + A subclass of :class:`BaseSessionBackend <litestar.middleware.session.base.BaseSessionBackend>` + """ + return self.session_backend_config._backend_class(config=self.session_backend_config) # pyright: ignore + + @property + def openapi_components(self) -> Components: + """Create OpenAPI documentation for the Session Authentication schema used. + + Returns: + An :class:`Components <litestar.openapi.spec.components.Components>` instance. + """ + return Components( + security_schemes={ + "sessionCookie": SecurityScheme( + type="apiKey", + name=self.session_backend_config.key, + security_scheme_in="cookie", # pyright: ignore + description="Session cookie authentication.", + ) + } + ) + + @property + def security_requirement(self) -> SecurityRequirement: + """Return OpenAPI 3.1. + + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for the auth + backend. + + Returns: + An OpenAPI 3.1 :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` dictionary. + """ + return {"sessionCookie": []} diff --git a/venv/lib/python3.11/site-packages/litestar/security/session_auth/middleware.py b/venv/lib/python3.11/site-packages/litestar/security/session_auth/middleware.py new file mode 100644 index 0000000..bb3fce4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/session_auth/middleware.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Sequence + +from litestar.exceptions import NotAuthorizedException +from litestar.middleware.authentication import ( + AbstractAuthenticationMiddleware, + AuthenticationResult, +) +from litestar.middleware.exceptions import ExceptionHandlerMiddleware +from litestar.types import Empty, Method, Scopes + +__all__ = ("MiddlewareWrapper", "SessionAuthMiddleware") + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + from litestar.security.session_auth.auth import SessionAuth + from litestar.types import ASGIApp, Receive, Scope, Send + + +class MiddlewareWrapper: + """Wrapper class that serves as the middleware entry point.""" + + def __init__(self, app: ASGIApp, config: SessionAuth[Any, Any]) -> None: + """Wrap the SessionAuthMiddleware inside ExceptionHandlerMiddleware, and it wraps this inside SessionMiddleware. + This allows the auth middleware to raise exceptions and still have the response handled, while having the + session cleared. + + Args: + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + config: An instance of SessionAuth. + """ + self.app = app + self.config = config + self.has_wrapped_middleware = False + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Handle creating a middleware stack and calling it. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + if not self.has_wrapped_middleware: + litestar_app = scope["app"] + auth_middleware = self.config.authentication_middleware_class( + app=self.app, + exclude=self.config.exclude, + exclude_http_methods=self.config.exclude_http_methods, + exclude_opt_key=self.config.exclude_opt_key, + scopes=self.config.scopes, + retrieve_user_handler=self.config.retrieve_user_handler, # type: ignore[arg-type] + ) + exception_middleware = ExceptionHandlerMiddleware( + app=auth_middleware, + exception_handlers=litestar_app.exception_handlers or {}, # pyright: ignore + debug=None, + ) + self.app = self.config.session_backend_config.middleware.middleware( + app=exception_middleware, + backend=self.config.session_backend, + ) + self.has_wrapped_middleware = True + await self.app(scope, receive, send) + + +class SessionAuthMiddleware(AbstractAuthenticationMiddleware): + """Session Authentication Middleware.""" + + def __init__( + self, + app: ASGIApp, + exclude: str | list[str] | None, + exclude_http_methods: Sequence[Method] | None, + exclude_opt_key: str, + retrieve_user_handler: Callable[[dict[str, Any], ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]], + scopes: Scopes | None, + ) -> None: + """Session based authentication middleware. + + Args: + app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack. + exclude: A pattern or list of patterns to skip in the authentication middleware. + exclude_http_methods: A sequence of http methods that do not require authentication. + exclude_opt_key: An identifier to use on routes to disable authentication and authorization checks for a particular route. + scopes: ASGI scopes processed by the authentication middleware. + retrieve_user_handler: Callable that receives the ``session`` value from the authentication middleware and returns a ``user`` value. + """ + super().__init__( + app=app, + exclude=exclude, + exclude_from_auth_key=exclude_opt_key, + exclude_http_methods=exclude_http_methods, + scopes=scopes, + ) + self.retrieve_user_handler = retrieve_user_handler + + async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult: + """Authenticate an incoming connection. + + Args: + connection: An :class:`ASGIConnection <.connection.ASGIConnection>` instance. + + Raises: + NotAuthorizedException: if session data is empty or user is not found. + + Returns: + :class:`AuthenticationResult <.middleware.authentication.AuthenticationResult>` + """ + if not connection.session or connection.scope["session"] is Empty: + # the assignment of 'Empty' forces the session middleware to clear session data. + connection.scope["session"] = Empty + raise NotAuthorizedException("no session data found") + + user = await self.retrieve_user_handler(connection.session, connection) + + if not user: + connection.scope["session"] = Empty + raise NotAuthorizedException("no user correlating to session found") + + return AuthenticationResult(user=user, auth=connection.session) diff --git a/venv/lib/python3.11/site-packages/litestar/serialization/__init__.py b/venv/lib/python3.11/site-packages/litestar/serialization/__init__.py new file mode 100644 index 0000000..0a9189e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/serialization/__init__.py @@ -0,0 +1,19 @@ +from .msgspec_hooks import ( + decode_json, + decode_msgpack, + default_deserializer, + default_serializer, + encode_json, + encode_msgpack, + get_serializer, +) + +__all__ = ( + "default_deserializer", + "decode_json", + "decode_msgpack", + "default_serializer", + "encode_json", + "encode_msgpack", + "get_serializer", +) diff --git a/venv/lib/python3.11/site-packages/litestar/serialization/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/serialization/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..68fa695 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/serialization/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/serialization/__pycache__/msgspec_hooks.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/serialization/__pycache__/msgspec_hooks.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..00f6d48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/serialization/__pycache__/msgspec_hooks.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/serialization/msgspec_hooks.py b/venv/lib/python3.11/site-packages/litestar/serialization/msgspec_hooks.py new file mode 100644 index 0000000..3d7fdb8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/serialization/msgspec_hooks.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from collections import deque +from datetime import date, datetime, time +from decimal import Decimal +from functools import partial +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, +) +from pathlib import Path, PurePath +from re import Pattern +from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload +from uuid import UUID + +import msgspec + +from litestar.exceptions import SerializationException +from litestar.types import Empty, EmptyType, Serializer, TypeDecodersSequence + +if TYPE_CHECKING: + from litestar.types import TypeEncodersMap + +__all__ = ( + "decode_json", + "decode_msgpack", + "default_deserializer", + "default_serializer", + "encode_json", + "encode_msgpack", + "get_serializer", +) + +T = TypeVar("T") + +DEFAULT_TYPE_ENCODERS: TypeEncodersMap = { + Path: str, + PurePath: str, + IPv4Address: str, + IPv4Interface: str, + IPv4Network: str, + IPv6Address: str, + IPv6Interface: str, + IPv6Network: str, + datetime: lambda val: val.isoformat(), + date: lambda val: val.isoformat(), + time: lambda val: val.isoformat(), + deque: list, + Decimal: lambda val: int(val) if val.as_tuple().exponent >= 0 else float(val), + Pattern: lambda val: val.pattern, + # support subclasses of stdlib types, If no previous type matched, these will be + # the last type in the mro, so we use this to (attempt to) convert a subclass into + # its base class. # see https://github.com/jcrist/msgspec/issues/248 + # and https://github.com/litestar-org/litestar/issues/1003 + str: str, + int: int, + float: float, + set: set, + frozenset: frozenset, + bytes: bytes, +} + + +def default_serializer(value: Any, type_encoders: Mapping[Any, Callable[[Any], Any]] | None = None) -> Any: + """Transform values non-natively supported by ``msgspec`` + + Args: + value: A value to serialized + type_encoders: Mapping of types to callables to transforming types + Returns: + A serialized value + Raises: + TypeError: if value is not supported + """ + type_encoders = {**DEFAULT_TYPE_ENCODERS, **(type_encoders or {})} + + for base in value.__class__.__mro__[:-1]: + try: + encoder = type_encoders[base] + return encoder(value) + except KeyError: + continue + + raise TypeError(f"Unsupported type: {type(value)!r}") + + +def default_deserializer( + target_type: Any, value: Any, type_decoders: TypeDecodersSequence | None = None +) -> Any: # pragma: no cover + """Transform values non-natively supported by ``msgspec`` + + Args: + target_type: Encountered type + value: Value to coerce + type_decoders: Optional sequence of type decoders + + Returns: + A ``msgspec``-supported type + """ + + from litestar.datastructures.state import ImmutableState + + if isinstance(value, target_type): + return value + + if type_decoders: + for predicate, decoder in type_decoders: + if predicate(target_type): + return decoder(target_type, value) + + if issubclass(target_type, (Path, PurePath, ImmutableState, UUID)): + return target_type(value) + + raise TypeError(f"Unsupported type: {type(value)!r}") + + +_msgspec_json_encoder = msgspec.json.Encoder(enc_hook=default_serializer) +_msgspec_json_decoder = msgspec.json.Decoder(dec_hook=default_deserializer) +_msgspec_msgpack_encoder = msgspec.msgpack.Encoder(enc_hook=default_serializer) +_msgspec_msgpack_decoder = msgspec.msgpack.Decoder(dec_hook=default_deserializer) + + +def encode_json(value: Any, serializer: Callable[[Any], Any] | None = None) -> bytes: + """Encode a value into JSON. + + Args: + value: Value to encode + serializer: Optional callable to support non-natively supported types. + + Returns: + JSON as bytes + + Raises: + SerializationException: If error encoding ``obj``. + """ + try: + return msgspec.json.encode(value, enc_hook=serializer) if serializer else _msgspec_json_encoder.encode(value) + except (TypeError, msgspec.EncodeError) as msgspec_error: + raise SerializationException(str(msgspec_error)) from msgspec_error + + +@overload +def decode_json(value: str | bytes) -> Any: ... + + +@overload +def decode_json(value: str | bytes, type_decoders: TypeDecodersSequence | None) -> Any: ... + + +@overload +def decode_json(value: str | bytes, target_type: type[T]) -> T: ... + + +@overload +def decode_json(value: str | bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None) -> T: ... + + +def decode_json( # type: ignore[misc] + value: str | bytes, + target_type: type[T] | EmptyType = Empty, # pyright: ignore + type_decoders: TypeDecodersSequence | None = None, +) -> Any: + """Decode a JSON string/bytes into an object. + + Args: + value: Value to decode + target_type: An optional type to decode the data into + type_decoders: Optional sequence of type decoders + + Returns: + An object + + Raises: + SerializationException: If error decoding ``value``. + """ + try: + if target_type is Empty: + return _msgspec_json_decoder.decode(value) + return msgspec.json.decode( + value, dec_hook=partial(default_deserializer, type_decoders=type_decoders), type=target_type + ) + except msgspec.DecodeError as msgspec_error: + raise SerializationException(str(msgspec_error)) from msgspec_error + + +def encode_msgpack(value: Any, serializer: Callable[[Any], Any] | None = default_serializer) -> bytes: + """Encode a value into MessagePack. + + Args: + value: Value to encode + serializer: Optional callable to support non-natively supported types + + Returns: + MessagePack as bytes + + Raises: + SerializationException: If error encoding ``obj``. + """ + try: + if serializer is None or serializer is default_serializer: + return _msgspec_msgpack_encoder.encode(value) + return msgspec.msgpack.encode(value, enc_hook=serializer) + except (TypeError, msgspec.EncodeError) as msgspec_error: + raise SerializationException(str(msgspec_error)) from msgspec_error + + +@overload +def decode_msgpack(value: bytes) -> Any: ... + + +@overload +def decode_msgpack(value: bytes, type_decoders: TypeDecodersSequence | None) -> Any: ... + + +@overload +def decode_msgpack(value: bytes, target_type: type[T]) -> T: ... + + +@overload +def decode_msgpack(value: bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None) -> T: ... + + +def decode_msgpack( # type: ignore[misc] + value: bytes, + target_type: type[T] | EmptyType = Empty, # pyright: ignore[reportInvalidTypeVarUse] + type_decoders: TypeDecodersSequence | None = None, +) -> Any: + """Decode a MessagePack string/bytes into an object. + + Args: + value: Value to decode + target_type: An optional type to decode the data into + type_decoders: Optional sequence of type decoders + + Returns: + An object + + Raises: + SerializationException: If error decoding ``value``. + """ + try: + if target_type is Empty: + return _msgspec_msgpack_decoder.decode(value) + return msgspec.msgpack.decode( + value, dec_hook=partial(default_deserializer, type_decoders=type_decoders), type=target_type + ) + except msgspec.DecodeError as msgspec_error: + raise SerializationException(str(msgspec_error)) from msgspec_error + + +def get_serializer(type_encoders: TypeEncodersMap | None = None) -> Serializer: + """Get the serializer for the given type encoders.""" + + if type_encoders: + return partial(default_serializer, type_encoders=type_encoders) + + return default_serializer diff --git a/venv/lib/python3.11/site-packages/litestar/static_files/__init__.py b/venv/lib/python3.11/site-packages/litestar/static_files/__init__.py new file mode 100644 index 0000000..3cd4594 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/static_files/__init__.py @@ -0,0 +1,4 @@ +from litestar.static_files.base import StaticFiles +from litestar.static_files.config import StaticFilesConfig, create_static_files_router + +__all__ = ("StaticFiles", "StaticFilesConfig", "create_static_files_router") diff --git a/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1cc4497 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0ca9dae --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..fd93af6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/static_files/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/static_files/base.py b/venv/lib/python3.11/site-packages/litestar/static_files/base.py new file mode 100644 index 0000000..9827697 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/static_files/base.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from os.path import commonpath +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Sequence + +from litestar.enums import ScopeType +from litestar.exceptions import MethodNotAllowedException, NotFoundException +from litestar.file_system import FileSystemAdapter +from litestar.response.file import ASGIFileResponse +from litestar.status_codes import HTTP_404_NOT_FOUND + +__all__ = ("StaticFiles",) + + +if TYPE_CHECKING: + from litestar.types import Receive, Scope, Send + from litestar.types.composite_types import PathType + from litestar.types.file_types import FileInfo, FileSystemProtocol + + +class StaticFiles: + """ASGI App that handles file sending.""" + + __slots__ = ("is_html_mode", "directories", "adapter", "send_as_attachment", "headers") + + def __init__( + self, + is_html_mode: bool, + directories: Sequence[PathType], + file_system: FileSystemProtocol, + send_as_attachment: bool = False, + resolve_symlinks: bool = True, + headers: dict[str, str] | None = None, + ) -> None: + """Initialize the Application. + + Args: + is_html_mode: Flag dictating whether serving html. If true, the default file will be ``index.html``. + directories: A list of directories to serve files from. + file_system: The file_system spec to use for serving files. + send_as_attachment: Whether to send the file with a ``content-disposition`` header of + ``attachment`` or ``inline`` + resolve_symlinks: Resolve symlinks to the directories + headers: Headers that will be sent with every response. + """ + self.adapter = FileSystemAdapter(file_system) + self.directories = tuple(Path(p).resolve() if resolve_symlinks else Path(p) for p in directories) + self.is_html_mode = is_html_mode + self.send_as_attachment = send_as_attachment + self.headers = headers + + async def get_fs_info( + self, directories: Sequence[PathType], file_path: PathType + ) -> tuple[Path, FileInfo] | tuple[None, None]: + """Return the resolved path and a :class:`stat_result <os.stat_result>`. + + Args: + directories: A list of directory paths. + file_path: A file path to resolve + + Returns: + A tuple with an optional resolved :class:`Path <anyio.Path>` instance and an optional + :class:`stat_result <os.stat_result>`. + """ + for directory in directories: + try: + joined_path = Path(directory, file_path) + file_info = await self.adapter.info(joined_path) + if file_info and commonpath([str(directory), file_info["name"], joined_path]) == str(directory): + return joined_path, file_info + except FileNotFoundError: + continue + return None, None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: ASGI scope + receive: ASGI ``receive`` callable + send: ASGI ``send`` callable + + Returns: + None + """ + if scope["type"] != ScopeType.HTTP or scope["method"] not in {"GET", "HEAD"}: + raise MethodNotAllowedException() + + res = await self.handle(path=scope["path"], is_head_response=scope["method"] == "HEAD") + await res(scope=scope, receive=receive, send=send) + + async def handle(self, path: str, is_head_response: bool) -> ASGIFileResponse: + split_path = path.split("/") + filename = split_path[-1] + joined_path = Path(*split_path) + resolved_path, fs_info = await self.get_fs_info(directories=self.directories, file_path=joined_path) + content_disposition_type: Literal["inline", "attachment"] = ( + "attachment" if self.send_as_attachment else "inline" + ) + + if self.is_html_mode and fs_info and fs_info["type"] == "directory": + filename = "index.html" + resolved_path, fs_info = await self.get_fs_info( + directories=self.directories, + file_path=Path(resolved_path or joined_path) / filename, + ) + + if fs_info and fs_info["type"] == "file": + return ASGIFileResponse( + file_path=resolved_path or joined_path, + file_info=fs_info, + file_system=self.adapter.file_system, + filename=filename, + content_disposition_type=content_disposition_type, + is_head_response=is_head_response, + headers=self.headers, + ) + + if self.is_html_mode: + # for some reason coverage doesn't catch these two lines + filename = "404.html" # pragma: no cover + resolved_path, fs_info = await self.get_fs_info( # pragma: no cover + directories=self.directories, file_path=filename + ) + + if fs_info and fs_info["type"] == "file": + return ASGIFileResponse( + file_path=resolved_path or joined_path, + file_info=fs_info, + file_system=self.adapter.file_system, + filename=filename, + status_code=HTTP_404_NOT_FOUND, + content_disposition_type=content_disposition_type, + is_head_response=is_head_response, + headers=self.headers, + ) + + raise NotFoundException( + f"no file or directory match the path {resolved_path or joined_path} was found" + ) # pragma: no cover diff --git a/venv/lib/python3.11/site-packages/litestar/static_files/config.py b/venv/lib/python3.11/site-packages/litestar/static_files/config.py new file mode 100644 index 0000000..22b6620 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/static_files/config.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import PurePath # noqa: TCH003 +from typing import TYPE_CHECKING, Any, Sequence + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.file_system import BaseLocalFileSystem +from litestar.handlers import asgi, get, head +from litestar.response.file import ASGIFileResponse # noqa: TCH001 +from litestar.router import Router +from litestar.static_files.base import StaticFiles +from litestar.types import Empty +from litestar.utils import normalize_path, warn_deprecation + +__all__ = ("StaticFilesConfig",) + +if TYPE_CHECKING: + from litestar.datastructures import CacheControlHeader + from litestar.handlers.asgi_handlers import ASGIRouteHandler + from litestar.openapi.spec import SecurityRequirement + from litestar.types import ( + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeRequestHookHandler, + EmptyType, + ExceptionHandlersMap, + Guard, + Middleware, + PathType, + ) + + +@dataclass +class StaticFilesConfig: + """Configuration for static file service. + + To enable static files, pass an instance of this class to the :class:`Litestar <litestar.app.Litestar>` constructor using + the 'static_files_config' key. + """ + + path: str + """Path to serve static files from. + + Note that the path cannot contain path parameters. + """ + directories: list[PathType] + """A list of directories to serve files from.""" + html_mode: bool = False + """Flag dictating whether serving html. + + If true, the default file will be 'index.html'. + """ + name: str | None = None + """An optional string identifying the static files handler.""" + file_system: Any = BaseLocalFileSystem() # noqa: RUF009 + """The file_system spec to use for serving files. + + Notes: + - A file_system is a class that adheres to the + :class:`FileSystemProtocol <litestar.types.FileSystemProtocol>`. + - You can use any of the file systems exported from the + [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) library for this purpose. + """ + opt: dict[str, Any] | None = None + """A string key dictionary of arbitrary values that will be added to the static files handler.""" + guards: list[Guard] | None = None + """A list of :class:`Guard <litestar.types.Guard>` callables.""" + exception_handlers: ExceptionHandlersMap | None = None + """A dictionary that maps handler functions to status codes and/or exception types.""" + send_as_attachment: bool = False + """Whether to send the file as an attachment.""" + + def __post_init__(self) -> None: + _validate_config(path=self.path, directories=self.directories, file_system=self.file_system) + self.path = normalize_path(self.path) + warn_deprecation( + "2.6.0", + kind="class", + deprecated_name="StaticFilesConfig", + removal_in="3.0", + alternative="create_static_files_router", + info='Replace static_files_config=[StaticFilesConfig(path="/static", directories=["assets"])] with ' + 'route_handlers=[..., create_static_files_router(path="/static", directories=["assets"])]', + ) + + def to_static_files_app(self) -> ASGIRouteHandler: + """Return an ASGI app serving static files based on the config. + + Returns: + :class:`StaticFiles <litestar.static_files.StaticFiles>` + """ + static_files = StaticFiles( + is_html_mode=self.html_mode, + directories=self.directories, + file_system=self.file_system, + send_as_attachment=self.send_as_attachment, + ) + return asgi( + path=self.path, + name=self.name, + is_static=True, + opt=self.opt, + guards=self.guards, + exception_handlers=self.exception_handlers, + )(static_files) + + +def create_static_files_router( + path: str, + directories: list[PathType], + file_system: Any = None, + send_as_attachment: bool = False, + html_mode: bool = False, + name: str = "static", + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache_control: CacheControlHeader | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: list[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + middleware: Sequence[Middleware] | None = None, + opt: dict[str, Any] | None = None, + security: Sequence[SecurityRequirement] | None = None, + tags: Sequence[str] | None = None, + router_class: type[Router] = Router, + resolve_symlinks: bool = True, +) -> Router: + """Create a router with handlers to serve static files. + + Args: + path: Path to serve static files under + directories: Directories to serve static files from + file_system: A *file system* implementing + :class:`~litestar.types.FileSystemProtocol`. + `fsspec <https://filesystem-spec.readthedocs.io/en/latest/>`_ can be passed + here as well + send_as_attachment: Whether to send the file as an attachment + html_mode: When in HTML: + - Serve an ``index.html`` file from ``/`` + - Serve ``404.html`` when a file could not be found + name: Name to pass to the generated handlers + after_request: ``after_request`` handlers passed to the router + after_response: ``after_response`` handlers passed to the router + before_request: ``before_request`` handlers passed to the router + cache_control: ``cache_control`` passed to the router + exception_handlers: Exception handlers passed to the router + guards: Guards passed to the router + include_in_schema: Include the routes / router in the OpenAPI schema + middleware: Middlewares passed to the router + opt: Opts passed to the router + security: Security options passed to the router + tags: ``tags`` passed to the router + router_class: The class used to construct a router from + resolve_symlinks: Resolve symlinks of ``directories`` + """ + + if file_system is None: + file_system = BaseLocalFileSystem() + + _validate_config(path=path, directories=directories, file_system=file_system) + path = normalize_path(path) + + headers = None + if cache_control: + headers = {cache_control.HEADER_NAME: cache_control.to_header()} + + static_files = StaticFiles( + is_html_mode=html_mode, + directories=directories, + file_system=file_system, + send_as_attachment=send_as_attachment, + resolve_symlinks=resolve_symlinks, + headers=headers, + ) + + @get("{file_path:path}", name=name) + async def get_handler(file_path: PurePath) -> ASGIFileResponse: + return await static_files.handle(path=file_path.as_posix(), is_head_response=False) + + @head("/{file_path:path}", name=f"{name}/head") + async def head_handler(file_path: PurePath) -> ASGIFileResponse: + return await static_files.handle(path=file_path.as_posix(), is_head_response=True) + + handlers = [get_handler, head_handler] + + if html_mode: + + @get("/", name=f"{name}/index") + async def index_handler() -> ASGIFileResponse: + return await static_files.handle(path="/", is_head_response=False) + + handlers.append(index_handler) + + return router_class( + after_request=after_request, + after_response=after_response, + before_request=before_request, + cache_control=cache_control, + exception_handlers=exception_handlers, + guards=guards, + include_in_schema=include_in_schema, + middleware=middleware, + opt=opt, + path=path, + route_handlers=handlers, + security=security, + tags=tags, + ) + + +def _validate_config(path: str, directories: list[PathType], file_system: Any) -> None: + if not path: + raise ImproperlyConfiguredException("path must be a non-zero length string,") + + if not directories or not any(bool(d) for d in directories): + raise ImproperlyConfiguredException("directories must include at least one path.") + + if "{" in path: + raise ImproperlyConfiguredException("path parameters are not supported for static files") + + if not (callable(getattr(file_system, "info", None)) and callable(getattr(file_system, "open", None))): + raise ImproperlyConfiguredException("file_system must adhere to the FileSystemProtocol type") diff --git a/venv/lib/python3.11/site-packages/litestar/status_codes.py b/venv/lib/python3.11/site-packages/litestar/status_codes.py new file mode 100644 index 0000000..9293365 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/status_codes.py @@ -0,0 +1,321 @@ +from typing import Final + +# HTTP Status Codes + +HTTP_100_CONTINUE: Final = 100 +"""HTTP status code 'Continue'""" + +HTTP_101_SWITCHING_PROTOCOLS: Final = 101 +"""HTTP status code 'Switching Protocols'""" + +HTTP_102_PROCESSING: Final = 102 +"""HTTP status code 'Processing'""" + +HTTP_103_EARLY_HINTS: Final = 103 +"""HTTP status code 'Early Hints'""" + +HTTP_200_OK: Final = 200 +"""HTTP status code 'OK'""" + +HTTP_201_CREATED: Final = 201 +"""HTTP status code 'Created'""" + +HTTP_202_ACCEPTED: Final = 202 +"""HTTP status code 'Accepted'""" + +HTTP_203_NON_AUTHORITATIVE_INFORMATION: Final = 203 +"""HTTP status code 'Non Authoritative Information'""" + +HTTP_204_NO_CONTENT: Final = 204 +"""HTTP status code 'No Content'""" + +HTTP_205_RESET_CONTENT: Final = 205 +"""HTTP status code 'Reset Content'""" + +HTTP_206_PARTIAL_CONTENT: Final = 206 +"""HTTP status code 'Partial Content'""" + +HTTP_207_MULTI_STATUS: Final = 207 +"""HTTP status code 'Multi Status'""" + +HTTP_208_ALREADY_REPORTED: Final = 208 +"""HTTP status code 'Already Reported'""" + +HTTP_226_IM_USED: Final = 226 +"""HTTP status code 'I'm Used'""" + +HTTP_300_MULTIPLE_CHOICES: Final = 300 +"""HTTP status code 'Multiple Choices'""" + +HTTP_301_MOVED_PERMANENTLY: Final = 301 +"""HTTP status code 'Moved Permanently'""" + +HTTP_302_FOUND: Final = 302 +"""HTTP status code 'Found'""" + +HTTP_303_SEE_OTHER: Final = 303 +"""HTTP status code 'See Other'""" + +HTTP_304_NOT_MODIFIED: Final = 304 +"""HTTP status code 'Not Modified'""" + +HTTP_305_USE_PROXY: Final = 305 +"""HTTP status code 'Use Proxy'""" + +HTTP_306_RESERVED: Final = 306 +"""HTTP status code 'Reserved'""" + +HTTP_307_TEMPORARY_REDIRECT: Final = 307 +"""HTTP status code 'Temporary Redirect'""" + +HTTP_308_PERMANENT_REDIRECT: Final = 308 +"""HTTP status code 'Permanent Redirect'""" + +HTTP_400_BAD_REQUEST: Final = 400 +"""HTTP status code 'Bad Request'""" + +HTTP_401_UNAUTHORIZED: Final = 401 +"""HTTP status code 'Unauthorized'""" + +HTTP_402_PAYMENT_REQUIRED: Final = 402 +"""HTTP status code 'Payment Required'""" + +HTTP_403_FORBIDDEN: Final = 403 +"""HTTP status code 'Forbidden'""" + +HTTP_404_NOT_FOUND: Final = 404 +"""HTTP status code 'Not Found'""" + +HTTP_405_METHOD_NOT_ALLOWED: Final = 405 +"""HTTP status code 'Method Not Allowed'""" + +HTTP_406_NOT_ACCEPTABLE: Final = 406 +"""HTTP status code 'Not Acceptable'""" + +HTTP_407_PROXY_AUTHENTICATION_REQUIRED: Final = 407 +"""HTTP status code 'Proxy Authentication Required'""" + +HTTP_408_REQUEST_TIMEOUT: Final = 408 +"""HTTP status code 'Request Timeout'""" + +HTTP_409_CONFLICT: Final = 409 +"""HTTP status code 'Conflict'""" + +HTTP_410_GONE: Final = 410 +"""HTTP status code 'Gone'""" + +HTTP_411_LENGTH_REQUIRED: Final = 411 +"""HTTP status code 'Length Required'""" + +HTTP_412_PRECONDITION_FAILED: Final = 412 +"""HTTP status code 'Precondition Failed'""" + +HTTP_413_REQUEST_ENTITY_TOO_LARGE: Final = 413 +"""HTTP status code 'Request Entity Too Large'""" + +HTTP_414_REQUEST_URI_TOO_LONG: Final = 414 +"""HTTP status code 'Request URI Too Long'""" + +HTTP_415_UNSUPPORTED_MEDIA_TYPE: Final = 415 +"""HTTP status code 'Unsupported Media Type'""" + +HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE: Final = 416 +"""HTTP status code 'Requested Range Not Satisfiable'""" + +HTTP_417_EXPECTATION_FAILED: Final = 417 +"""HTTP status code 'Expectation Failed'""" + +HTTP_418_IM_A_TEAPOT: Final = 418 +"""HTTP status code 'I'm A Teapot'""" + +HTTP_421_MISDIRECTED_REQUEST: Final = 421 +"""HTTP status code 'Misdirected Request'""" + +HTTP_422_UNPROCESSABLE_ENTITY: Final = 422 +"""HTTP status code 'Unprocessable Entity'""" + +HTTP_423_LOCKED: Final = 423 +"""HTTP status code 'Locked'""" + +HTTP_424_FAILED_DEPENDENCY: Final = 424 +"""HTTP status code 'Failed Dependency'""" + +HTTP_425_TOO_EARLY: Final = 425 +"""HTTP status code 'Too Early'""" + +HTTP_426_UPGRADE_REQUIRED: Final = 426 +"""HTTP status code 'Upgrade Required'""" + +HTTP_428_PRECONDITION_REQUIRED: Final = 428 +"""HTTP status code 'Precondition Required'""" + +HTTP_429_TOO_MANY_REQUESTS: Final = 429 +"""HTTP status code 'Too Many Requests'""" + +HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE: Final = 431 +"""HTTP status code 'Request Header Fields Too Large'""" + +HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS: Final = 451 +"""HTTP status code 'Unavailable For Legal Reasons'""" + +HTTP_500_INTERNAL_SERVER_ERROR: Final = 500 +"""HTTP status code 'Internal Server Error'""" + +HTTP_501_NOT_IMPLEMENTED: Final = 501 +"""HTTP status code 'Not Implemented'""" + +HTTP_502_BAD_GATEWAY: Final = 502 +"""HTTP status code 'Bad Gateway'""" + +HTTP_503_SERVICE_UNAVAILABLE: Final = 503 +"""HTTP status code 'Service Unavailable'""" + +HTTP_504_GATEWAY_TIMEOUT: Final = 504 +"""HTTP status code 'Gateway Timeout'""" + +HTTP_505_HTTP_VERSION_NOT_SUPPORTED: Final = 505 +"""HTTP status code 'Http Version Not Supported'""" + +HTTP_506_VARIANT_ALSO_NEGOTIATES: Final = 506 +"""HTTP status code 'Variant Also Negotiates'""" + +HTTP_507_INSUFFICIENT_STORAGE: Final = 507 +"""HTTP status code 'Insufficient Storage'""" + +HTTP_508_LOOP_DETECTED: Final = 508 +"""HTTP status code 'Loop Detected'""" + +HTTP_510_NOT_EXTENDED: Final = 510 +"""HTTP status code 'Not Extended'""" + +HTTP_511_NETWORK_AUTHENTICATION_REQUIRED: Final = 511 +"""HTTP status code 'Network Authentication Required'""" + + +# Websocket Codes +WS_1000_NORMAL_CLOSURE: Final = 1000 +"""WebSocket status code 'Normal Closure'""" + +WS_1001_GOING_AWAY: Final = 1001 +"""WebSocket status code 'Going Away'""" + +WS_1002_PROTOCOL_ERROR: Final = 1002 +"""WebSocket status code 'Protocol Error'""" + +WS_1003_UNSUPPORTED_DATA: Final = 1003 +"""WebSocket status code 'Unsupported Data'""" + +WS_1005_NO_STATUS_RECEIVED: Final = 1005 +"""WebSocket status code 'No Status Received'""" + +WS_1006_ABNORMAL_CLOSURE: Final = 1006 +"""WebSocket status code 'Abnormal Closure'""" + +WS_1007_INVALID_FRAME_PAYLOAD_DATA: Final = 1007 +"""WebSocket status code 'Invalid Frame Payload Data'""" + +WS_1008_POLICY_VIOLATION: Final = 1008 +"""WebSocket status code 'Policy Violation'""" + +WS_1009_MESSAGE_TOO_BIG: Final = 1009 +"""WebSocket status code 'Message Too Big'""" + +WS_1010_MANDATORY_EXT: Final = 1010 +"""WebSocket status code 'Mandatory Ext.'""" + +WS_1011_INTERNAL_ERROR: Final = 1011 +"""WebSocket status code 'Internal Error'""" + +WS_1012_SERVICE_RESTART: Final = 1012 +"""WebSocket status code 'Service Restart'""" + +WS_1013_TRY_AGAIN_LATER: Final = 1013 +"""WebSocket status code 'Try Again Later'""" + +WS_1014_BAD_GATEWAY: Final = 1014 +"""WebSocket status code 'Bad Gateway'""" + +WS_1015_TLS_HANDSHAKE: Final = 1015 +"""WebSocket status code 'TLS Handshake'""" + + +__all__ = ( + "HTTP_100_CONTINUE", + "HTTP_101_SWITCHING_PROTOCOLS", + "HTTP_102_PROCESSING", + "HTTP_103_EARLY_HINTS", + "HTTP_200_OK", + "HTTP_201_CREATED", + "HTTP_202_ACCEPTED", + "HTTP_203_NON_AUTHORITATIVE_INFORMATION", + "HTTP_204_NO_CONTENT", + "HTTP_205_RESET_CONTENT", + "HTTP_206_PARTIAL_CONTENT", + "HTTP_207_MULTI_STATUS", + "HTTP_208_ALREADY_REPORTED", + "HTTP_226_IM_USED", + "HTTP_300_MULTIPLE_CHOICES", + "HTTP_301_MOVED_PERMANENTLY", + "HTTP_302_FOUND", + "HTTP_303_SEE_OTHER", + "HTTP_304_NOT_MODIFIED", + "HTTP_305_USE_PROXY", + "HTTP_306_RESERVED", + "HTTP_307_TEMPORARY_REDIRECT", + "HTTP_308_PERMANENT_REDIRECT", + "HTTP_400_BAD_REQUEST", + "HTTP_401_UNAUTHORIZED", + "HTTP_402_PAYMENT_REQUIRED", + "HTTP_403_FORBIDDEN", + "HTTP_404_NOT_FOUND", + "HTTP_405_METHOD_NOT_ALLOWED", + "HTTP_406_NOT_ACCEPTABLE", + "HTTP_407_PROXY_AUTHENTICATION_REQUIRED", + "HTTP_408_REQUEST_TIMEOUT", + "HTTP_409_CONFLICT", + "HTTP_410_GONE", + "HTTP_411_LENGTH_REQUIRED", + "HTTP_412_PRECONDITION_FAILED", + "HTTP_413_REQUEST_ENTITY_TOO_LARGE", + "HTTP_414_REQUEST_URI_TOO_LONG", + "HTTP_415_UNSUPPORTED_MEDIA_TYPE", + "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE", + "HTTP_417_EXPECTATION_FAILED", + "HTTP_418_IM_A_TEAPOT", + "HTTP_421_MISDIRECTED_REQUEST", + "HTTP_422_UNPROCESSABLE_ENTITY", + "HTTP_423_LOCKED", + "HTTP_424_FAILED_DEPENDENCY", + "HTTP_425_TOO_EARLY", + "HTTP_426_UPGRADE_REQUIRED", + "HTTP_428_PRECONDITION_REQUIRED", + "HTTP_429_TOO_MANY_REQUESTS", + "HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE", + "HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS", + "HTTP_500_INTERNAL_SERVER_ERROR", + "HTTP_501_NOT_IMPLEMENTED", + "HTTP_502_BAD_GATEWAY", + "HTTP_503_SERVICE_UNAVAILABLE", + "HTTP_504_GATEWAY_TIMEOUT", + "HTTP_505_HTTP_VERSION_NOT_SUPPORTED", + "HTTP_506_VARIANT_ALSO_NEGOTIATES", + "HTTP_507_INSUFFICIENT_STORAGE", + "HTTP_508_LOOP_DETECTED", + "HTTP_510_NOT_EXTENDED", + "HTTP_511_NETWORK_AUTHENTICATION_REQUIRED", + "WS_1000_NORMAL_CLOSURE", + "WS_1001_GOING_AWAY", + "WS_1002_PROTOCOL_ERROR", + "WS_1003_UNSUPPORTED_DATA", + "WS_1005_NO_STATUS_RECEIVED", + "WS_1006_ABNORMAL_CLOSURE", + "WS_1007_INVALID_FRAME_PAYLOAD_DATA", + "WS_1008_POLICY_VIOLATION", + "WS_1009_MESSAGE_TOO_BIG", + "WS_1010_MANDATORY_EXT", + "WS_1011_INTERNAL_ERROR", + "WS_1012_SERVICE_RESTART", + "WS_1013_TRY_AGAIN_LATER", + "WS_1014_BAD_GATEWAY", + "WS_1015_TLS_HANDSHAKE", +) diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__init__.py b/venv/lib/python3.11/site-packages/litestar/stores/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..78604d0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..21d2fb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/file.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/file.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b39333c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/file.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/memory.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/memory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c8e3b05 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/memory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/redis.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/redis.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..21e7e2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/redis.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/registry.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/registry.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c87c31a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/__pycache__/registry.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/stores/base.py b/venv/lib/python3.11/site-packages/litestar/stores/base.py new file mode 100644 index 0000000..34aa514 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/base.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Optional + +from msgspec import Struct +from msgspec.msgpack import decode as msgpack_decode +from msgspec.msgpack import encode as msgpack_encode + +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import Self + + +__all__ = ("Store", "NamespacedStore", "StorageObject") + + +class Store(ABC): + """Thread and process safe asynchronous key/value store.""" + + @abstractmethod + async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None: + """Set a value. + + Args: + key: Key to associate the value with + value: Value to store + expires_in: Time in seconds before the key is considered expired + + Returns: + ``None`` + """ + raise NotImplementedError + + @abstractmethod + async def get(self, key: str, renew_for: int | timedelta | None = None) -> bytes | None: + """Get a value. + + Args: + key: Key associated with the value + renew_for: If given and the value had an initial expiry time set, renew the + expiry time for ``renew_for`` seconds. If the value has not been set + with an expiry time this is a no-op + + Returns: + The value associated with ``key`` if it exists and is not expired, else + ``None`` + """ + raise NotImplementedError + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete a value. + + If no such key exists, this is a no-op. + + Args: + key: Key of the value to delete + """ + raise NotImplementedError + + @abstractmethod + async def delete_all(self) -> None: + """Delete all stored values.""" + raise NotImplementedError + + @abstractmethod + async def exists(self, key: str) -> bool: + """Check if a given ``key`` exists.""" + raise NotImplementedError + + @abstractmethod + async def expires_in(self, key: str) -> int | None: + """Get the time in seconds ``key`` expires in. If no such ``key`` exists or no + expiry time was set, return ``None``. + """ + raise NotImplementedError + + async def __aenter__(self) -> None: # noqa: B027 + pass + + async def __aexit__( # noqa: B027 + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass + + +class NamespacedStore(Store): + """A subclass of :class:`Store`, offering hierarchical namespacing. + + Bulk actions on a parent namespace should affect all child namespaces, whereas other operations on all namespaces + should be isolated. + """ + + @abstractmethod + def with_namespace(self, namespace: str) -> Self: + """Return a new instance of :class:`NamespacedStore`, which exists in a child namespace of the current namespace. + Bulk actions on the parent namespace should affect all child namespaces, whereas other operations on all + namespaces should be isolated. + """ + + +class StorageObject(Struct): + """:class:`msgspec.Struct` to store serialized data alongside with their expiry time.""" + + expires_at: Optional[datetime] # noqa: UP007 + data: bytes + + @classmethod + def new(cls, data: bytes, expires_in: int | timedelta | None) -> StorageObject: + """Construct a new :class:`StorageObject` instance.""" + if expires_in is not None and not isinstance(expires_in, timedelta): + expires_in = timedelta(seconds=expires_in) + return cls( + data=data, + expires_at=(datetime.now(tz=timezone.utc) + expires_in) if expires_in else None, + ) + + @property + def expired(self) -> bool: + """Return if the :class:`StorageObject` is expired""" + return self.expires_at is not None and datetime.now(tz=timezone.utc) >= self.expires_at + + @property + def expires_in(self) -> int: + """Return the expiry time of this ``StorageObject`` in seconds. If no expiry time + was set, return ``-1``. + """ + if self.expires_at: + return int(self.expires_at.timestamp() - datetime.now(tz=timezone.utc).timestamp()) + return -1 + + def to_bytes(self) -> bytes: + """Encode the instance to bytes""" + return msgpack_encode(self) + + @classmethod + def from_bytes(cls, raw: bytes) -> StorageObject: + """Load a previously encoded with :meth:`StorageObject.to_bytes`""" + return msgpack_decode(raw, type=cls) diff --git a/venv/lib/python3.11/site-packages/litestar/stores/file.py b/venv/lib/python3.11/site-packages/litestar/stores/file.py new file mode 100644 index 0000000..25c52eb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/file.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import os +import shutil +import unicodedata +from tempfile import mkstemp +from typing import TYPE_CHECKING + +from anyio import Path + +from litestar.concurrency import sync_to_thread + +from .base import NamespacedStore, StorageObject + +__all__ = ("FileStore",) + + +if TYPE_CHECKING: + from datetime import timedelta + from os import PathLike + + +def _safe_file_name(name: str) -> str: + name = unicodedata.normalize("NFKD", name) + return "".join(c if c.isalnum() else str(ord(c)) for c in name) + + +class FileStore(NamespacedStore): + """File based, thread and process safe, asynchronous key/value store.""" + + __slots__ = {"path": "file path"} + + def __init__(self, path: PathLike[str]) -> None: + """Initialize ``FileStorage``. + + Args: + path: Path to store data under + """ + self.path = Path(path) + + def with_namespace(self, namespace: str) -> FileStore: + """Return a new instance of :class:`FileStore`, using a sub-path of the current store's path.""" + if not namespace.isalnum(): + raise ValueError(f"Invalid namespace: {namespace!r}") + return FileStore(self.path / namespace) + + def _path_from_key(self, key: str) -> Path: + return self.path / _safe_file_name(key) + + @staticmethod + async def _load_from_path(path: Path) -> StorageObject | None: + try: + data = await path.read_bytes() + return StorageObject.from_bytes(data) + except FileNotFoundError: + return None + + def _write_sync(self, target_file: Path, storage_obj: StorageObject) -> None: + try: + tmp_file_fd, tmp_file_name = mkstemp(dir=self.path, prefix=f"{target_file.name}.tmp") + renamed = False + try: + try: + os.write(tmp_file_fd, storage_obj.to_bytes()) + finally: + os.close(tmp_file_fd) + + os.replace(tmp_file_name, target_file) # noqa: PTH105 + renamed = True + finally: + if not renamed: + os.unlink(tmp_file_name) # noqa: PTH108 + except OSError: + pass + + async def _write(self, target_file: Path, storage_obj: StorageObject) -> None: + await sync_to_thread(self._write_sync, target_file, storage_obj) + + async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None: + """Set a value. + + Args: + key: Key to associate the value with + value: Value to store + expires_in: Time in seconds before the key is considered expired + + Returns: + ``None`` + """ + + await self.path.mkdir(exist_ok=True) + path = self._path_from_key(key) + if isinstance(value, str): + value = value.encode("utf-8") + storage_obj = StorageObject.new(data=value, expires_in=expires_in) + await self._write(path, storage_obj) + + async def get(self, key: str, renew_for: int | timedelta | None = None) -> bytes | None: + """Get a value. + + Args: + key: Key associated with the value + renew_for: If given and the value had an initial expiry time set, renew the + expiry time for ``renew_for`` seconds. If the value has not been set + with an expiry time this is a no-op + + Returns: + The value associated with ``key`` if it exists and is not expired, else + ``None`` + """ + path = self._path_from_key(key) + storage_obj = await self._load_from_path(path) + + if not storage_obj: + return None + + if storage_obj.expired: + await path.unlink(missing_ok=True) + return None + + if renew_for and storage_obj.expires_at: + await self.set(key, value=storage_obj.data, expires_in=renew_for) + + return storage_obj.data + + async def delete(self, key: str) -> None: + """Delete a value. + + If no such key exists, this is a no-op. + + Args: + key: Key of the value to delete + """ + path = self._path_from_key(key) + await path.unlink(missing_ok=True) + + async def delete_all(self) -> None: + """Delete all stored values. + + Note: + This deletes and recreates :attr:`FileStore.path` + """ + + await sync_to_thread(shutil.rmtree, self.path) + await self.path.mkdir(exist_ok=True) + + async def delete_expired(self) -> None: + """Delete expired items. + + Since expired items are normally only cleared on access (i.e. when calling + :meth:`.get`), this method should be called in regular intervals + to free disk space. + """ + async for file in self.path.iterdir(): + wrapper = await self._load_from_path(file) + if wrapper and wrapper.expired: + await file.unlink(missing_ok=True) + + async def exists(self, key: str) -> bool: + """Check if a given ``key`` exists.""" + path = self._path_from_key(key) + return await path.exists() + + async def expires_in(self, key: str) -> int | None: + """Get the time in seconds ``key`` expires in. If no such ``key`` exists or no + expiry time was set, return ``None``. + """ + if storage_obj := await self._load_from_path(self._path_from_key(key)): + return storage_obj.expires_in + return None diff --git a/venv/lib/python3.11/site-packages/litestar/stores/memory.py b/venv/lib/python3.11/site-packages/litestar/stores/memory.py new file mode 100644 index 0000000..1da8931 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/memory.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import anyio +from anyio import Lock + +from .base import StorageObject, Store + +__all__ = ("MemoryStore",) + + +if TYPE_CHECKING: + from datetime import timedelta + + +class MemoryStore(Store): + """In memory, atomic, asynchronous key/value store.""" + + __slots__ = ("_store", "_lock") + + def __init__(self) -> None: + """Initialize :class:`MemoryStore`""" + self._store: dict[str, StorageObject] = {} + self._lock = Lock() + + async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None: + """Set a value. + + Args: + key: Key to associate the value with + value: Value to store + expires_in: Time in seconds before the key is considered expired + + Returns: + ``None`` + """ + if isinstance(value, str): + value = value.encode("utf-8") + async with self._lock: + self._store[key] = StorageObject.new(data=value, expires_in=expires_in) + + async def get(self, key: str, renew_for: int | timedelta | None = None) -> bytes | None: + """Get a value. + + Args: + key: Key associated with the value + renew_for: If given and the value had an initial expiry time set, renew the + expiry time for ``renew_for`` seconds. If the value has not been set + with an expiry time this is a no-op + + Returns: + The value associated with ``key`` if it exists and is not expired, else + ``None`` + """ + async with self._lock: + storage_obj = self._store.get(key) + + if not storage_obj: + return None + + if storage_obj.expired: + self._store.pop(key) + return None + + if renew_for and storage_obj.expires_at: + # don't use .set() here, so we can hold onto the lock for the whole operation + storage_obj = StorageObject.new(data=storage_obj.data, expires_in=renew_for) + self._store[key] = storage_obj + + return storage_obj.data + + async def delete(self, key: str) -> None: + """Delete a value. + + If no such key exists, this is a no-op. + + Args: + key: Key of the value to delete + """ + async with self._lock: + self._store.pop(key, None) + + async def delete_all(self) -> None: + """Delete all stored values.""" + async with self._lock: + self._store.clear() + + async def delete_expired(self) -> None: + """Delete expired items. + + Since expired items are normally only cleared on access (i.e. when calling + :meth:`.get`), this method should be called in regular intervals + to free memory. + """ + async with self._lock: + new_store = {} + for i, (key, storage_obj) in enumerate(self._store.items()): + if not storage_obj.expired: + new_store[key] = storage_obj + if i % 1000 == 0: + await anyio.sleep(0) + self._store = new_store + + async def exists(self, key: str) -> bool: + """Check if a given ``key`` exists.""" + return key in self._store + + async def expires_in(self, key: str) -> int | None: + """Get the time in seconds ``key`` expires in. If no such ``key`` exists or no + expiry time was set, return ``None``. + """ + if storage_obj := self._store.get(key): + return storage_obj.expires_in + return None diff --git a/venv/lib/python3.11/site-packages/litestar/stores/redis.py b/venv/lib/python3.11/site-packages/litestar/stores/redis.py new file mode 100644 index 0000000..6697962 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/redis.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING, cast + +from redis.asyncio import Redis +from redis.asyncio.connection import ConnectionPool + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types import Empty, EmptyType +from litestar.utils.empty import value_or_default + +from .base import NamespacedStore + +__all__ = ("RedisStore",) + +if TYPE_CHECKING: + from types import TracebackType + + +class RedisStore(NamespacedStore): + """Redis based, thread and process safe asynchronous key/value store.""" + + __slots__ = ("_redis",) + + def __init__( + self, redis: Redis, namespace: str | None | EmptyType = Empty, handle_client_shutdown: bool = False + ) -> None: + """Initialize :class:`RedisStore` + + Args: + redis: An :class:`redis.asyncio.Redis` instance + namespace: A key prefix to simulate a namespace in redis. If not given, + defaults to ``LITESTAR``. Namespacing can be explicitly disabled by passing + ``None``. This will make :meth:`.delete_all` unavailable. + handle_client_shutdown: If ``True``, handle the shutdown of the `redis` instance automatically during the store's lifespan. Should be set to `True` unless the shutdown is handled externally + """ + self._redis = redis + self.namespace: str | None = value_or_default(namespace, "LITESTAR") + self.handle_client_shutdown = handle_client_shutdown + + # script to get and renew a key in one atomic step + self._get_and_renew_script = self._redis.register_script( + b""" + local key = KEYS[1] + local renew = tonumber(ARGV[1]) + + local data = redis.call('GET', key) + local ttl = redis.call('TTL', key) + + if ttl > 0 then + redis.call('EXPIRE', key, renew) + end + + return data + """ + ) + + # script to delete all keys in the namespace + self._delete_all_script = self._redis.register_script( + b""" + local cursor = 0 + + repeat + local result = redis.call('SCAN', cursor, 'MATCH', ARGV[1]) + for _,key in ipairs(result[2]) do + redis.call('UNLINK', key) + end + cursor = tonumber(result[1]) + until cursor == 0 + """ + ) + + async def _shutdown(self) -> None: + if self.handle_client_shutdown: + await self._redis.aclose(close_connection_pool=True) # type: ignore[attr-defined] + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._shutdown() + + @classmethod + def with_client( + cls, + url: str = "redis://localhost:6379", + *, + db: int | None = None, + port: int | None = None, + username: str | None = None, + password: str | None = None, + namespace: str | None | EmptyType = Empty, + ) -> RedisStore: + """Initialize a :class:`RedisStore` instance with a new class:`redis.asyncio.Redis` instance. + + Args: + url: Redis URL to connect to + db: Redis database to use + port: Redis port to use + username: Redis username to use + password: Redis password to use + namespace: Virtual key namespace to use + """ + pool = ConnectionPool.from_url( + url=url, + db=db, + decode_responses=False, + port=port, + username=username, + password=password, + ) + return cls( + redis=Redis(connection_pool=pool), + namespace=namespace, + handle_client_shutdown=True, + ) + + def with_namespace(self, namespace: str) -> RedisStore: + """Return a new :class:`RedisStore` with a nested virtual key namespace. + The current instances namespace will serve as a prefix for the namespace, so it + can be considered the parent namespace. + """ + return type(self)( + redis=self._redis, + namespace=f"{self.namespace}_{namespace}" if self.namespace else namespace, + handle_client_shutdown=self.handle_client_shutdown, + ) + + def _make_key(self, key: str) -> str: + prefix = f"{self.namespace}:" if self.namespace else "" + return prefix + key + + async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None: + """Set a value. + + Args: + key: Key to associate the value with + value: Value to store + expires_in: Time in seconds before the key is considered expired + + Returns: + ``None`` + """ + if isinstance(value, str): + value = value.encode("utf-8") + await self._redis.set(self._make_key(key), value, ex=expires_in) + + async def get(self, key: str, renew_for: int | timedelta | None = None) -> bytes | None: + """Get a value. + + Args: + key: Key associated with the value + renew_for: If given and the value had an initial expiry time set, renew the + expiry time for ``renew_for`` seconds. If the value has not been set + with an expiry time this is a no-op. Atomicity of this step is guaranteed + by using a lua script to execute fetch and renewal. If ``renew_for`` is + not given, the script will be bypassed so no overhead will occur + + Returns: + The value associated with ``key`` if it exists and is not expired, else + ``None`` + """ + key = self._make_key(key) + if renew_for: + if isinstance(renew_for, timedelta): + renew_for = renew_for.seconds + data = await self._get_and_renew_script(keys=[key], args=[renew_for]) + return cast("bytes | None", data) + return await self._redis.get(key) + + async def delete(self, key: str) -> None: + """Delete a value. + + If no such key exists, this is a no-op. + + Args: + key: Key of the value to delete + """ + await self._redis.delete(self._make_key(key)) + + async def delete_all(self) -> None: + """Delete all stored values in the virtual key namespace. + + Raises: + ImproperlyConfiguredException: If no namespace was configured + """ + if not self.namespace: + raise ImproperlyConfiguredException("Cannot perform delete operation: No namespace configured") + + await self._delete_all_script(keys=[], args=[f"{self.namespace}*:*"]) + + async def exists(self, key: str) -> bool: + """Check if a given ``key`` exists.""" + return await self._redis.exists(self._make_key(key)) == 1 + + async def expires_in(self, key: str) -> int | None: + """Get the time in seconds ``key`` expires in. If no such ``key`` exists or no + expiry time was set, return ``None``. + """ + ttl = await self._redis.ttl(self._make_key(key)) + return None if ttl == -2 else ttl diff --git a/venv/lib/python3.11/site-packages/litestar/stores/registry.py b/venv/lib/python3.11/site-packages/litestar/stores/registry.py new file mode 100644 index 0000000..11a08c2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/stores/registry.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .base import Store + + +from .memory import MemoryStore + +__all__ = ("StoreRegistry",) + + +def default_default_factory(name: str) -> Store: + return MemoryStore() + + +class StoreRegistry: + """Registry for :class:`Store <.base.Store>` instances.""" + + __slots__ = ("_stores", "_default_factory") + + def __init__( + self, stores: dict[str, Store] | None = None, default_factory: Callable[[str], Store] = default_default_factory + ) -> None: + """Initialize ``StoreRegistry``. + + Args: + stores: A dictionary mapping store names to stores, used to initialize the registry + default_factory: A callable used by :meth:`StoreRegistry.get` to provide a store, if the requested name hasn't + been registered yet. This callable receives the requested name and should return a + :class:`Store <.base.Store>` instance. + """ + self._stores = stores or {} + self._default_factory = default_factory + + def register(self, name: str, store: Store, allow_override: bool = False) -> None: + """Register a new :class:`Store <.base.Store>`. + + Args: + name: Name to register the store under + store: The store to register + allow_override: Whether to allow overriding an existing store of the same name + + Raises: + ValueError: If a store is already registered under this name and ``override`` is not ``True`` + """ + if not allow_override and name in self._stores: + raise ValueError(f"Store with the name {name!r} already exists") + self._stores[name] = store + + def get(self, name: str) -> Store: + """Get a store registered under ``name``. If no such store is registered, create a store using the default + factory with ``name`` and register the returned store under ``name``. + + Args: + name: Name of the store + + Returns: + A :class:`Store <.base.Store>` + """ + if not self._stores.get(name): + self._stores[name] = self._default_factory(name) + return self._stores[name] diff --git a/venv/lib/python3.11/site-packages/litestar/template/__init__.py b/venv/lib/python3.11/site-packages/litestar/template/__init__.py new file mode 100644 index 0000000..5989cea --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/template/__init__.py @@ -0,0 +1,4 @@ +from litestar.template.base import TemplateEngineProtocol, TemplateProtocol +from litestar.template.config import TemplateConfig + +__all__ = ("TemplateEngineProtocol", "TemplateProtocol", "TemplateConfig") diff --git a/venv/lib/python3.11/site-packages/litestar/template/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/template/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7270a06 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/template/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/template/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/template/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f69e72d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/template/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/template/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/template/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..536a922 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/template/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/template/base.py b/venv/lib/python3.11/site-packages/litestar/template/base.py new file mode 100644 index 0000000..3474717 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/template/base.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol, TypedDict, TypeVar, cast, runtime_checkable + +from typing_extensions import Concatenate, ParamSpec, TypeAlias + +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from pathlib import Path + + from litestar.connection import Request + +__all__ = ( + "TemplateCallableType", + "TemplateEngineProtocol", + "TemplateProtocol", + "csrf_token", + "url_for", + "url_for_static_asset", +) + + +def _get_request_from_context(context: Mapping[str, Any]) -> Request: + """Get the request from the template context. + + Args: + context: The template context. + + Returns: + The request object. + """ + return cast("Request", context["request"]) + + +def url_for(context: Mapping[str, Any], /, route_name: str, **path_parameters: Any) -> str: + """Wrap :func:`route_reverse <litestar.app.route_reverse>` to be used in templates. + + Args: + context: The template context. + route_name: The name of the route handler. + **path_parameters: Actual values for path parameters in the route. + + Raises: + NoRouteMatchFoundException: If ``route_name`` does not exist, path parameters are missing in **path_parameters + or have wrong type. + + Returns: + A fully formatted url path. + """ + return _get_request_from_context(context).app.route_reverse(route_name, **path_parameters) + + +def csrf_token(context: Mapping[str, Any], /) -> str: + """Set a CSRF token on the template. + + Notes: + - to use this function make sure to pass an instance of :ref:`CSRFConfig <litestar.config.csrf_config.CSRFConfig>` to + the :class:`Litestar <litestar.app.Litestar>` constructor. + + Args: + context: The template context. + + Returns: + A CSRF token if the app level ``csrf_config`` is set, otherwise an empty string. + """ + scope = _get_request_from_context(context).scope + return value_or_default(ScopeState.from_scope(scope).csrf_token, "") + + +def url_for_static_asset(context: Mapping[str, Any], /, name: str, file_path: str) -> str: + """Wrap :meth:`url_for_static_asset <litestar.app.url_for_static_asset>` to be used in templates. + + Args: + context: The template context object. + name: A static handler unique name. + file_path: a string containing path to an asset. + + Raises: + NoRouteMatchFoundException: If static files handler with ``name`` does not exist. + + Returns: + A url path to the asset. + """ + return _get_request_from_context(context).app.url_for_static_asset(name, file_path) + + +class TemplateProtocol(Protocol): + """Protocol Defining a ``Template``. + + Template is a class that has a render method which renders the template into a string. + """ + + def render(self, *args: Any, **kwargs: Any) -> str: + """Return the rendered template as a string. + + Args: + *args: Positional arguments passed to the TemplateEngine + **kwargs: A string keyed mapping of values passed to the TemplateEngine + + Returns: + The rendered template string + """ + raise NotImplementedError + + +P = ParamSpec("P") +R = TypeVar("R") +ContextType = TypeVar("ContextType") +ContextType_co = TypeVar("ContextType_co", covariant=True) +TemplateType_co = TypeVar("TemplateType_co", bound=TemplateProtocol, covariant=True) +TemplateCallableType: TypeAlias = Callable[Concatenate[ContextType, P], R] + + +@runtime_checkable +class TemplateEngineProtocol(Protocol[TemplateType_co, ContextType_co]): + """Protocol for template engines.""" + + def __init__(self, directory: Path | list[Path] | None, engine_instance: Any | None) -> None: + """Initialize the template engine with a directory. + + Args: + directory: Direct path or list of directory paths from which to serve templates, if provided the + implementation has to create the engine instance. + engine_instance: A template engine object, if provided the implementation has to use it. + """ + + def get_template(self, template_name: str) -> TemplateType_co: + """Retrieve a template by matching its name (dotted path) with files in the directory or directories provided. + + Args: + template_name: A dotted path + + Returns: + Template instance + + Raises: + TemplateNotFoundException: if no template is found. + """ + raise NotImplementedError + + def render_string(self, template_string: str, context: Mapping[str, Any]) -> str: + """Render a template from a string with the given context. + + Args: + template_string: The template string to render. + context: A dictionary of variables to pass to the template. + + Returns: + The rendered template as a string. + """ + raise NotImplementedError + + def register_template_callable( + self, key: str, template_callable: TemplateCallableType[ContextType_co, P, R] + ) -> None: + """Register a callable on the template engine. + + Args: + key: The callable key, i.e. the value to use inside the template to call the callable. + template_callable: A callable to register. + + Returns: + None + """ + + +class _TemplateContext(TypedDict): + """Dictionary representing a template context.""" + + request: Request[Any, Any, Any] + csrf_input: str + + +def __getattr__(name: str) -> Any: + if name == "TemplateContext": + warn_deprecation( + "2.3.0", + "TemplateContext", + "import", + removal_in="3.0.0", + alternative="Mapping", + ) + return _TemplateContext + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/template/config.py b/venv/lib/python3.11/site-packages/litestar/template/config.py new file mode 100644 index 0000000..d2aa87c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/template/config.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import cached_property +from inspect import isclass +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, cast + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.template import TemplateEngineProtocol + +__all__ = ("TemplateConfig",) + +if TYPE_CHECKING: + from litestar.types import PathType + +EngineType = TypeVar("EngineType", bound=TemplateEngineProtocol) + + +@dataclass +class TemplateConfig(Generic[EngineType]): + """Configuration for Templating. + + To enable templating, pass an instance of this class to the :class:`Litestar <litestar.app.Litestar>` constructor using the + 'template_config' key. + """ + + engine: type[EngineType] | EngineType | None = field(default=None) + """A template engine adhering to the :class:`TemplateEngineProtocol <litestar.template.base.TemplateEngineProtocol>`.""" + directory: PathType | list[PathType] | None = field(default=None) + """A directory or list of directories from which to serve templates.""" + engine_callback: Callable[[EngineType], None] | None = field(default=None) + """A callback function that allows modifying the instantiated templating protocol.""" + instance: EngineType | None = field(default=None) + """An instance of the templating protocol.""" + + def __post_init__(self) -> None: + """Ensure that directory is set if engine is a class.""" + if isclass(self.engine) and not self.directory: + raise ImproperlyConfiguredException("directory is a required kwarg when passing a template engine class") + """Ensure that directory is not set if instance is.""" + if self.instance is not None and self.directory is not None: + raise ImproperlyConfiguredException("directory cannot be set if instance is") + + def to_engine(self) -> EngineType: + """Instantiate the template engine.""" + template_engine = cast( + "EngineType", + self.engine(directory=self.directory, engine_instance=None) if isclass(self.engine) else self.engine, + ) + if callable(self.engine_callback): + self.engine_callback(template_engine) + return template_engine + + @cached_property + def engine_instance(self) -> EngineType: + """Return the template engine instance.""" + return self.to_engine() if self.instance is None else self.instance diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__init__.py b/venv/lib/python3.11/site-packages/litestar/testing/__init__.py new file mode 100644 index 0000000..55af446 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__init__.py @@ -0,0 +1,16 @@ +from litestar.testing.client.async_client import AsyncTestClient +from litestar.testing.client.base import BaseTestClient +from litestar.testing.client.sync_client import TestClient +from litestar.testing.helpers import create_async_test_client, create_test_client +from litestar.testing.request_factory import RequestFactory +from litestar.testing.websocket_test_session import WebSocketTestSession + +__all__ = ( + "AsyncTestClient", + "BaseTestClient", + "create_async_test_client", + "create_test_client", + "RequestFactory", + "TestClient", + "WebSocketTestSession", +) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..77c7908 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/helpers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/helpers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a85995f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/helpers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/life_span_handler.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/life_span_handler.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b3ea9ef --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/life_span_handler.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/request_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/request_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9a66826 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/request_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/transport.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/transport.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..78a1aa6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/transport.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/websocket_test_session.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/websocket_test_session.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6b00590 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/websocket_test_session.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py b/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py new file mode 100644 index 0000000..5d03a7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py @@ -0,0 +1,36 @@ +"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/testclient.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from .async_client import AsyncTestClient +from .base import BaseTestClient +from .sync_client import TestClient + +__all__ = ("TestClient", "AsyncTestClient", "BaseTestClient") diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..18ad148 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1ccc805 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..87d5de7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..29f0576 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py new file mode 100644 index 0000000..cf66f12 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar + +from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + + +T = TypeVar("T", bound=ASGIApp) + + +class AsyncTestClient(AsyncClient, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: AsyncExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """An Async client implementation providing a context manager for testing applications asynchronously. + + Args: + app: The instance of :class:`Litestar <litestar.app.Litestar>` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: 'anyio' options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + AsyncClient.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore [arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + async def __aenter__(self) -> Self: + async with AsyncExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.exit_stack.aclose() + + async def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def get( # type: ignore [override] + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + async with AsyncTestClient(app=app, session_config=session_config) as client: + await client.post("/test") + assert await client.get_session_data() == {"foo": "bar"} + + """ + return await super()._get_session_data() + + async def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + async with AsyncTestClient(app=app, session_config=session_config) as client: + await client.set_session_data({"foo": "bar"}) + assert await client.get("/test").json() == {"foo": "bar"} + + """ + return await super()._set_session_data(data) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/base.py b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py new file mode 100644 index 0000000..3c25be1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from contextlib import contextmanager +from http.cookiejar import CookieJar +from typing import TYPE_CHECKING, Any, Generator, Generic, Mapping, TypeVar, cast +from warnings import warn + +from anyio.from_thread import BlockingPortal, start_blocking_portal +from httpx import Cookies, Request, Response + +from litestar import Litestar +from litestar.connection import ASGIConnection +from litestar.datastructures import MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import ( + ImproperlyConfiguredException, +) +from litestar.types import AnyIOBackend, ASGIApp, HTTPResponseStartEvent +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from httpx._types import CookieTypes + + from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend + from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send + +T = TypeVar("T", bound=ASGIApp) + + +def fake_http_send_message(headers: MutableScopeHeaders) -> HTTPResponseStartEvent: + headers.setdefault("content-type", "application/text") + return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.headers) + + +def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnection[Any, Any, Any, Any]: + scope: HTTPScope = { + "type": ScopeType.HTTP, + "path": "/", + "raw_path": b"/", + "root_path": "", + "scheme": "http", + "query_string": b"", + "client": ("testclient", 50000), + "server": ("testserver", 80), + "headers": [], + "method": "GET", + "http_version": "1.1", + "extensions": {"http.response.template": {}}, + "app": app, # type: ignore[typeddict-item] + "state": {}, + "path_params": {}, + "route_handler": None, # type: ignore[typeddict-item] + "asgi": {"version": "3.0", "spec_version": "2.1"}, + "auth": None, + "session": None, + "user": None, + } + ScopeState.from_scope(scope).cookies = cookies + return ASGIConnection[Any, Any, Any, Any](scope=scope) + + +def _wrap_app_to_add_state(app: ASGIApp) -> ASGIApp: + """Wrap an ASGI app to add state to the scope. + + Litestar depends on `state` being present in the ASGI connection scope. Scope state is optional in the ASGI spec, + however, the Litestar app always ensures it is present so that it can be depended on internally. + + When the ASGI app that is passed to the test client is _not_ a Litestar app, we need to add + state to the scope, because httpx does not do this for us. + + This assists us in testing Litestar components that rely on state being present in the scope, without having + to create a Litestar app for every test case. + + Args: + app: The ASGI app to wrap. + + Returns: + The wrapped ASGI app. + """ + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + scope["state"] = {} + await app(scope, receive, send) + + return wrapped + + +class BaseTestClient(Generic[T]): + __test__ = False + blocking_portal: BlockingPortal + + __slots__ = ( + "app", + "base_url", + "backend", + "backend_options", + "session_config", + "_session_backend", + "cookies", + ) + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + cookies: CookieTypes | None = None, + ) -> None: + if "." not in base_url: + warn( + f"The base_url {base_url!r} might cause issues. Try adding a domain name such as .local: " + f"'{base_url}.local'", + UserWarning, + stacklevel=1, + ) + + self._session_backend: BaseSessionBackend | None = None + if session_config: + self._session_backend = session_config._backend_class(config=session_config) + + if not isinstance(app, Litestar): + app = _wrap_app_to_add_state(app) # type: ignore[assignment] + + self.app = cast("T", app) # type: ignore[redundant-cast] # pyright needs this + + self.base_url = base_url + self.backend = backend + self.backend_options = backend_options + self.cookies = cookies + + @property + def session_backend(self) -> BaseSessionBackend[Any]: + if not self._session_backend: + raise ImproperlyConfiguredException( + "Session has not been initialized for this TestClient instance. You can" + "do so by passing a configuration object to TestClient: TestClient(app=app, session_config=...)" + ) + return self._session_backend + + @contextmanager + def portal(self) -> Generator[BlockingPortal, None, None]: + """Get a BlockingPortal. + + Returns: + A contextmanager for a BlockingPortal. + """ + if hasattr(self, "blocking_portal"): + yield self.blocking_portal + else: + with start_blocking_portal( + backend=self.backend, backend_options=dict(self.backend_options or {}) + ) as portal: + yield portal + + async def _set_session_data(self, data: dict[str, Any]) -> None: + mutable_headers = MutableScopeHeaders() + connection = fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ) + session_id = self.session_backend.get_session_id(connection) + connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues] + await self.session_backend.store_in_message( + scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection + ) + response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers) + + cookies = Cookies(CookieJar()) + cookies.extract_cookies(response) + self.cookies.update(cookies) # type: ignore[union-attr] + + async def _get_session_data(self) -> dict[str, Any]: + return await self.session_backend.load_from_connection( + connection=fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ), + ) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py new file mode 100644 index 0000000..d907056 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar +from urllib.parse import urljoin + +from httpx import USE_CLIENT_DEFAULT, Client, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + from litestar.testing.websocket_test_session import WebSocketTestSession + + +T = TypeVar("T", bound=ASGIApp) + + +class TestClient(Client, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: ExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """A client implementation providing a context manager for testing applications. + + Args: + app: The instance of :class:`Litestar <litestar.app.Litestar>` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + + Client.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore[arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + def __enter__(self) -> Self: + with ExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + + return self + + def __exit__(self, *args: Any) -> None: + self.exit_stack.close() + + def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def get( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def websocket_connect( + self, + url: str, + subprotocols: Sequence[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> WebSocketTestSession: + """Sends a GET request to establish a websocket connection. + + Args: + url: Request URL. + subprotocols: Websocket subprotocols. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance. + """ + url = urljoin("ws://testserver", url) + default_headers: dict[str, str] = {} + default_headers.setdefault("connection", "upgrade") + default_headers.setdefault("sec-websocket-key", "testserver==") + default_headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + try: + Client.request( + self, + "GET", + url, + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] + params=params, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + except ConnectionUpgradeExceptionError as exc: + return exc.session + + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.set_session_data({"foo": "bar"}) + assert client.get("/test").json() == {"foo": "bar"} + + """ + with self.portal() as portal: + portal.call(self._set_session_data, data) + + def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.post("/test") + assert client.get_session_data() == {"foo": "bar"} + + """ + with self.portal() as portal: + return portal.call(self._get_session_data) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/helpers.py b/venv/lib/python3.11/site-packages/litestar/testing/helpers.py new file mode 100644 index 0000000..5ac59af --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/helpers.py @@ -0,0 +1,561 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Sequence + +from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar +from litestar.controller import Controller +from litestar.events import SimpleEventEmitter +from litestar.testing.client import AsyncTestClient, TestClient +from litestar.types import Empty +from litestar.utils.predicates import is_class_and_subclass + +if TYPE_CHECKING: + from contextlib import AbstractAsyncContextManager + + from litestar import Request, Response, WebSocket + from litestar.config.allowed_hosts import AllowedHostsConfig + from litestar.config.app import ExperimentalFeatures + from litestar.config.compression import CompressionConfig + from litestar.config.cors import CORSConfig + from litestar.config.csrf import CSRFConfig + from litestar.config.response_cache import ResponseCacheConfig + from litestar.datastructures import CacheControlHeader, ETag, State + from litestar.dto import AbstractDTO + from litestar.events import BaseEventEmitterBackend, EventListener + from litestar.logging.config import BaseLoggingConfig + from litestar.middleware.session.base import BaseBackendConfig + from litestar.openapi.config import OpenAPIConfig + from litestar.openapi.spec import SecurityRequirement + from litestar.plugins import PluginProtocol + from litestar.static_files.config import StaticFilesConfig + from litestar.stores.base import Store + from litestar.stores.registry import StoreRegistry + from litestar.template.config import TemplateConfig + from litestar.types import ( + AfterExceptionHookHandler, + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeMessageSendHookHandler, + BeforeRequestHookHandler, + ControllerRouterHandler, + Dependencies, + EmptyType, + ExceptionHandlersMap, + Guard, + LifespanHook, + Middleware, + OnAppInitHandler, + ParametersMap, + ResponseCookies, + ResponseHeaders, + TypeEncodersMap, + ) + + +def create_test_client( + route_handlers: ControllerRouterHandler | Sequence[ControllerRouterHandler] | None = None, + *, + after_exception: Sequence[AfterExceptionHookHandler] | None = None, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + allowed_hosts: Sequence[str] | AllowedHostsConfig | None = None, + backend: Literal["asyncio", "trio"] = "asyncio", + backend_options: Mapping[str, Any] | None = None, + base_url: str = "http://testserver.local", + before_request: BeforeRequestHookHandler | None = None, + before_send: Sequence[BeforeMessageSendHookHandler] | None = None, + cache_control: CacheControlHeader | None = None, + compression_config: CompressionConfig | None = None, + cors_config: CORSConfig | None = None, + csrf_config: CSRFConfig | None = None, + debug: bool = True, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + event_emitter_backend: type[BaseEventEmitterBackend] = SimpleEventEmitter, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + listeners: Sequence[EventListener] | None = None, + logging_config: BaseLoggingConfig | EmptyType | None = Empty, + middleware: Sequence[Middleware] | None = None, + multipart_form_part_limit: int = 1000, + on_app_init: Sequence[OnAppInitHandler] | None = None, + on_shutdown: Sequence[LifespanHook] | None = None, + on_startup: Sequence[LifespanHook] | None = None, + openapi_config: OpenAPIConfig | None = DEFAULT_OPENAPI_CONFIG, + opt: Mapping[str, Any] | None = None, + parameters: ParametersMap | None = None, + plugins: Sequence[PluginProtocol] | None = None, + lifespan: list[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] | None = None, + raise_server_exceptions: bool = True, + pdb_on_exception: bool | None = None, + request_class: type[Request] | None = None, + response_cache_config: ResponseCacheConfig | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + root_path: str = "", + security: Sequence[SecurityRequirement] | None = None, + session_config: BaseBackendConfig | None = None, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + state: State | None = None, + static_files_config: Sequence[StaticFilesConfig] | None = None, + stores: StoreRegistry | dict[str, Store] | None = None, + tags: Sequence[str] | None = None, + template_config: TemplateConfig | None = None, + timeout: float | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + experimental_features: list[ExperimentalFeatures] | None = None, +) -> TestClient[Litestar]: + """Create a Litestar app instance and initializes it. + + :class:`TestClient <litestar.testing.TestClient>` with it. + + Notes: + - This function should be called as a context manager to ensure async startup and shutdown are + handled correctly. + + Examples: + .. code-block:: python + + from litestar import get + from litestar.testing import create_test_client + + + @get("/some-path") + def my_handler() -> dict[str, str]: + return {"hello": "world"} + + + def test_my_handler() -> None: + with create_test_client(my_handler) as client: + response = client.get("/some-path") + assert response.json() == {"hello": "world"} + + Args: + route_handlers: A single handler or a sequence of route handlers, which can include instances of + :class:`Router <litestar.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or + any function decorated by the route handler decorators. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for underlying the test client to raise server exceptions instead of wrapping them + in an HTTP response. + root_path: Path prefix for requests. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + after_exception: A sequence of :class:`exception hook handlers <.types.AfterExceptionHookHandler>`. This + hook is called after an exception occurs. In difference to exception handlers, it is not meant to + return a response - only to process the exception (e.g. log it, send it to Sentry etc.). + after_request: A sync or async function executed after the route handler function returned and the response + object has been resolved. Receives the response object. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + allowed_hosts: A sequence of allowed hosts, or an + :class:`AllowedHostsConfig <.config.allowed_hosts.AllowedHostsConfig>` instance. Enables the builtin + allowed hosts middleware. + before_request: A sync or async function called immediately before calling the route handler. Receives the + :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + before_send: A sequence of :class:`before send hook handlers <.types.BeforeMessageSendHookHandler>`. Called + when the ASGI send function is called. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <litestar.datastructures.CacheControlHeader>` to add to route handlers of + this app. Can be overridden by route handlers. + compression_config: Configures compression behaviour of the application, this enabled a builtin or user + defined Compression middleware. + cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. + debug: If ``True``, app errors rendered as HTML with a stack trace. + dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + Can be overridden by route handlers. + event_emitter_backend: A subclass of + :class:`BaseEventEmitterBackend <.events.emitter.BaseEventEmitterBackend>`. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + lifespan: A list of callables returning async context managers, wrapping the lifespan of the ASGI application + listeners: A sequence of :class:`EventListener <.events.listener.EventListener>`. + logging_config: A subclass of :class:`BaseLoggingConfig <.logging.config.BaseLoggingConfig>`. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request. This limit + is intended to protect from DoS attacks. + on_app_init: A sequence of :class:`OnAppInitHandler <.types.OnAppInitHandler>` instances. Handlers receive + an instance of :class:`AppConfig <.config.app.AppConfig>` that will have been initially populated with + the parameters passed to :class:`Litestar <litestar.app.Litestar>`, and must return an instance of same. + If more than one handler is registered they are called in the order they are provided. + on_shutdown: A sequence of :class:`LifespanHook <.types.LifespanHook>` called during application + shutdown. + on_startup: A sequence of :class:`LifespanHook <litestar.types.LifespanHook>` called during + application startup. + openapi_config: Defaults to :attr:`DEFAULT_OPENAPI_CONFIG` + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <litestar.connection.request.Request>` or + :class:`ASGI Scope <.types.Scope>`. + parameters: A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application + paths. + pdb_on_exception: Drop into the PDB when an exception occurs. + plugins: Sequence of plugins. + request_class: An optional subclass of :class:`Request <.connection.Request>` to use for http connections. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the app's default + response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>`. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + response_cache_config: Configures caching behavior of the application. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + route_handlers: A sequence of route handlers, which can include instances of + :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any + callable decorated by the route handler decorators. + security: A sequence of dicts that will be added to the schema of all route handlers in the application. + See + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for details. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + state: An optional :class:`State <.datastructures.State>` for application state. + static_files_config: A sequence of :class:`StaticFilesConfig <.static_files.StaticFilesConfig>` + stores: Central registry of :class:`Store <.stores.base.Store>` that will be available throughout the + application. If this is a dictionary to it will be passed to a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`. If it is a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`, this instance will be used directly. + tags: A sequence of string tags that will be appended to the schema of all route handlers under the + application. + template_config: An instance of :class:`TemplateConfig <.template.TemplateConfig>` + timeout: Request timeout + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket + connections. + experimental_features: An iterable of experimental features to enable + + + Returns: + An instance of :class:`TestClient <.testing.TestClient>` with a created app instance. + """ + route_handlers = () if route_handlers is None else route_handlers + if is_class_and_subclass(route_handlers, Controller) or not isinstance(route_handlers, Sequence): + route_handlers = (route_handlers,) + + app = Litestar( + after_exception=after_exception, + after_request=after_request, + after_response=after_response, + allowed_hosts=allowed_hosts, + before_request=before_request, + before_send=before_send, + cache_control=cache_control, + compression_config=compression_config, + cors_config=cors_config, + csrf_config=csrf_config, + debug=debug, + dependencies=dependencies, + dto=dto, + etag=etag, + lifespan=lifespan, + event_emitter_backend=event_emitter_backend, + exception_handlers=exception_handlers, + guards=guards, + include_in_schema=include_in_schema, + listeners=listeners, + logging_config=logging_config, + middleware=middleware, + multipart_form_part_limit=multipart_form_part_limit, + on_app_init=on_app_init, + on_shutdown=on_shutdown, + on_startup=on_startup, + openapi_config=openapi_config, + opt=opt, + parameters=parameters, + pdb_on_exception=pdb_on_exception, + plugins=plugins, + request_class=request_class, + response_cache_config=response_cache_config, + response_class=response_class, + response_cookies=response_cookies, + response_headers=response_headers, + return_dto=return_dto, + route_handlers=route_handlers, + security=security, + signature_namespace=signature_namespace, + signature_types=signature_types, + state=state, + static_files_config=static_files_config, + stores=stores, + tags=tags, + template_config=template_config, + type_encoders=type_encoders, + websocket_class=websocket_class, + experimental_features=experimental_features, + ) + + return TestClient[Litestar]( + app=app, + backend=backend, + backend_options=backend_options, + base_url=base_url, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + session_config=session_config, + timeout=timeout, + ) + + +def create_async_test_client( + route_handlers: ControllerRouterHandler | Sequence[ControllerRouterHandler] | None = None, + *, + after_exception: Sequence[AfterExceptionHookHandler] | None = None, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + allowed_hosts: Sequence[str] | AllowedHostsConfig | None = None, + backend: Literal["asyncio", "trio"] = "asyncio", + backend_options: Mapping[str, Any] | None = None, + base_url: str = "http://testserver.local", + before_request: BeforeRequestHookHandler | None = None, + before_send: Sequence[BeforeMessageSendHookHandler] | None = None, + cache_control: CacheControlHeader | None = None, + compression_config: CompressionConfig | None = None, + cors_config: CORSConfig | None = None, + csrf_config: CSRFConfig | None = None, + debug: bool = True, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + event_emitter_backend: type[BaseEventEmitterBackend] = SimpleEventEmitter, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + lifespan: list[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] | None = None, + listeners: Sequence[EventListener] | None = None, + logging_config: BaseLoggingConfig | EmptyType | None = Empty, + middleware: Sequence[Middleware] | None = None, + multipart_form_part_limit: int = 1000, + on_app_init: Sequence[OnAppInitHandler] | None = None, + on_shutdown: Sequence[LifespanHook] | None = None, + on_startup: Sequence[LifespanHook] | None = None, + openapi_config: OpenAPIConfig | None = DEFAULT_OPENAPI_CONFIG, + opt: Mapping[str, Any] | None = None, + parameters: ParametersMap | None = None, + pdb_on_exception: bool | None = None, + plugins: Sequence[PluginProtocol] | None = None, + raise_server_exceptions: bool = True, + request_class: type[Request] | None = None, + response_cache_config: ResponseCacheConfig | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + root_path: str = "", + security: Sequence[SecurityRequirement] | None = None, + session_config: BaseBackendConfig | None = None, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + state: State | None = None, + static_files_config: Sequence[StaticFilesConfig] | None = None, + stores: StoreRegistry | dict[str, Store] | None = None, + tags: Sequence[str] | None = None, + template_config: TemplateConfig | None = None, + timeout: float | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + experimental_features: list[ExperimentalFeatures] | None = None, +) -> AsyncTestClient[Litestar]: + """Create a Litestar app instance and initializes it. + + :class:`AsyncTestClient <litestar.testing.AsyncTestClient>` with it. + + Notes: + - This function should be called as a context manager to ensure async startup and shutdown are + handled correctly. + + Examples: + .. code-block:: python + + from litestar import get + from litestar.testing import create_async_test_client + + + @get("/some-path") + def my_handler() -> dict[str, str]: + return {"hello": "world"} + + + async def test_my_handler() -> None: + async with create_async_test_client(my_handler) as client: + response = await client.get("/some-path") + assert response.json() == {"hello": "world"} + + Args: + route_handlers: A single handler or a sequence of route handlers, which can include instances of + :class:`Router <litestar.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or + any function decorated by the route handler decorators. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for underlying the test client to raise server exceptions instead of wrapping them + in an HTTP response. + root_path: Path prefix for requests. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + after_exception: A sequence of :class:`exception hook handlers <.types.AfterExceptionHookHandler>`. This + hook is called after an exception occurs. In difference to exception handlers, it is not meant to + return a response - only to process the exception (e.g. log it, send it to Sentry etc.). + after_request: A sync or async function executed after the route handler function returned and the response + object has been resolved. Receives the response object. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + allowed_hosts: A sequence of allowed hosts, or an + :class:`AllowedHostsConfig <.config.allowed_hosts.AllowedHostsConfig>` instance. Enables the builtin + allowed hosts middleware. + before_request: A sync or async function called immediately before calling the route handler. Receives the + :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + before_send: A sequence of :class:`before send hook handlers <.types.BeforeMessageSendHookHandler>`. Called + when the ASGI send function is called. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <litestar.datastructures.CacheControlHeader>` to add to route handlers of + this app. Can be overridden by route handlers. + compression_config: Configures compression behaviour of the application, this enabled a builtin or user + defined Compression middleware. + cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. + debug: If ``True``, app errors rendered as HTML with a stack trace. + dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + Can be overridden by route handlers. + event_emitter_backend: A subclass of + :class:`BaseEventEmitterBackend <.events.emitter.BaseEventEmitterBackend>`. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + lifespan: A list of callables returning async context managers, wrapping the lifespan of the ASGI application + listeners: A sequence of :class:`EventListener <.events.listener.EventListener>`. + logging_config: A subclass of :class:`BaseLoggingConfig <.logging.config.BaseLoggingConfig>`. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request. This limit + is intended to protect from DoS attacks. + on_app_init: A sequence of :class:`OnAppInitHandler <.types.OnAppInitHandler>` instances. Handlers receive + an instance of :class:`AppConfig <.config.app.AppConfig>` that will have been initially populated with + the parameters passed to :class:`Litestar <litestar.app.Litestar>`, and must return an instance of same. + If more than one handler is registered they are called in the order they are provided. + on_shutdown: A sequence of :class:`LifespanHook <.types.LifespanHook>` called during application + shutdown. + on_startup: A sequence of :class:`LifespanHook <litestar.types.LifespanHook>` called during + application startup. + openapi_config: Defaults to :attr:`DEFAULT_OPENAPI_CONFIG` + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <litestar.connection.request.Request>` or + :class:`ASGI Scope <.types.Scope>`. + parameters: A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application + paths. + pdb_on_exception: Drop into the PDB when an exception occurs. + plugins: Sequence of plugins. + request_class: An optional subclass of :class:`Request <.connection.Request>` to use for http connections. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the app's default + response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>`. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + response_cache_config: Configures caching behavior of the application. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + route_handlers: A sequence of route handlers, which can include instances of + :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any + callable decorated by the route handler decorators. + security: A sequence of dicts that will be added to the schema of all route handlers in the application. + See + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for details. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + state: An optional :class:`State <.datastructures.State>` for application state. + static_files_config: A sequence of :class:`StaticFilesConfig <.static_files.StaticFilesConfig>` + stores: Central registry of :class:`Store <.stores.base.Store>` that will be available throughout the + application. If this is a dictionary to it will be passed to a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`. If it is a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`, this instance will be used directly. + tags: A sequence of string tags that will be appended to the schema of all route handlers under the + application. + template_config: An instance of :class:`TemplateConfig <.template.TemplateConfig>` + timeout: Request timeout + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket + connections. + experimental_features: An iterable of experimental features to enable + + Returns: + An instance of :class:`AsyncTestClient <litestar.testing.AsyncTestClient>` with a created app instance. + """ + route_handlers = () if route_handlers is None else route_handlers + if is_class_and_subclass(route_handlers, Controller) or not isinstance(route_handlers, Sequence): + route_handlers = (route_handlers,) + + app = Litestar( + after_exception=after_exception, + after_request=after_request, + after_response=after_response, + allowed_hosts=allowed_hosts, + before_request=before_request, + before_send=before_send, + cache_control=cache_control, + compression_config=compression_config, + cors_config=cors_config, + csrf_config=csrf_config, + debug=debug, + dependencies=dependencies, + dto=dto, + etag=etag, + event_emitter_backend=event_emitter_backend, + exception_handlers=exception_handlers, + guards=guards, + include_in_schema=include_in_schema, + lifespan=lifespan, + listeners=listeners, + logging_config=logging_config, + middleware=middleware, + multipart_form_part_limit=multipart_form_part_limit, + on_app_init=on_app_init, + on_shutdown=on_shutdown, + on_startup=on_startup, + openapi_config=openapi_config, + opt=opt, + parameters=parameters, + pdb_on_exception=pdb_on_exception, + plugins=plugins, + request_class=request_class, + response_cache_config=response_cache_config, + response_class=response_class, + response_cookies=response_cookies, + response_headers=response_headers, + return_dto=return_dto, + route_handlers=route_handlers, + security=security, + signature_namespace=signature_namespace, + signature_types=signature_types, + state=state, + static_files_config=static_files_config, + stores=stores, + tags=tags, + template_config=template_config, + type_encoders=type_encoders, + websocket_class=websocket_class, + experimental_features=experimental_features, + ) + + return AsyncTestClient[Litestar]( + app=app, + backend=backend, + backend_options=backend_options, + base_url=base_url, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + session_config=session_config, + timeout=timeout, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/life_span_handler.py b/venv/lib/python3.11/site-packages/litestar/testing/life_span_handler.py new file mode 100644 index 0000000..8ee7d22 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/life_span_handler.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from math import inf +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast + +from anyio import create_memory_object_stream +from anyio.streams.stapled import StapledObjectStream + +from litestar.testing.client.base import BaseTestClient + +if TYPE_CHECKING: + from litestar.types import ( + LifeSpanReceiveMessage, # noqa: F401 + LifeSpanSendMessage, + LifeSpanShutdownEvent, + LifeSpanStartupEvent, + ) + +T = TypeVar("T", bound=BaseTestClient) + + +class LifeSpanHandler(Generic[T]): + __slots__ = "stream_send", "stream_receive", "client", "task" + + def __init__(self, client: T) -> None: + self.client = client + self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type] + self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type] + + with self.client.portal() as portal: + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + async def receive(self) -> LifeSpanSendMessage: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return cast("LifeSpanSendMessage", message) + + async def wait_startup(self) -> None: + event: LifeSpanStartupEvent = {"type": "lifespan.startup"} + await self.stream_receive.send(event) + + message = await self.receive() + if message["type"] not in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ): + raise RuntimeError( + "Received unexpected ASGI message type. Expected 'lifespan.startup.complete' or " + f"'lifespan.startup.failed'. Got {message['type']!r}", + ) + if message["type"] == "lifespan.startup.failed": + await self.receive() + + async def wait_shutdown(self) -> None: + async with self.stream_send: + lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"} + await self.stream_receive.send(lifespan_shutdown_event) + + message = await self.receive() + if message["type"] not in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ): + raise RuntimeError( + "Received unexpected ASGI message type. Expected 'lifespan.shutdown.complete' or " + f"'lifespan.shutdown.failed'. Got {message['type']!r}", + ) + if message["type"] == "lifespan.shutdown.failed": + await self.receive() + + async def lifespan(self) -> None: + scope = {"type": "lifespan"} + try: + await self.client.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + await self.stream_send.send(None) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/request_factory.py b/venv/lib/python3.11/site-packages/litestar/testing/request_factory.py new file mode 100644 index 0000000..ccb29c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/request_factory.py @@ -0,0 +1,565 @@ +from __future__ import annotations + +import json +from functools import partial +from typing import TYPE_CHECKING, Any, cast +from urllib.parse import urlencode + +from httpx._content import encode_json as httpx_encode_json +from httpx._content import encode_multipart_data, encode_urlencoded_data + +from litestar import delete, patch, post, put +from litestar.app import Litestar +from litestar.connection import Request +from litestar.enums import HttpMethod, ParamType, RequestEncodingType, ScopeType +from litestar.handlers.http_handlers import get +from litestar.serialization import decode_json, default_serializer, encode_json +from litestar.types import DataContainerType, HTTPScope, RouteHandlerType +from litestar.types.asgi_types import ASGIVersion +from litestar.utils import get_serializer_from_scope +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from httpx._types import FileTypes + + from litestar.datastructures.cookie import Cookie + from litestar.handlers.http_handlers import HTTPRouteHandler + +_decorator_http_method_map: dict[HttpMethod, type[HTTPRouteHandler]] = { + HttpMethod.GET: get, + HttpMethod.POST: post, + HttpMethod.DELETE: delete, + HttpMethod.PATCH: patch, + HttpMethod.PUT: put, +} + + +def _create_default_route_handler( + http_method: HttpMethod, handler_kwargs: dict[str, Any] | None, app: Litestar +) -> HTTPRouteHandler: + handler_decorator = _decorator_http_method_map[http_method] + + def _default_route_handler() -> None: ... + + handler = handler_decorator("/", sync_to_thread=False, **(handler_kwargs or {}))(_default_route_handler) + handler.owner = app + return handler + + +def _create_default_app() -> Litestar: + return Litestar(route_handlers=[]) + + +class RequestFactory: + """Factory to create :class:`Request <litestar.connection.Request>` instances.""" + + __slots__ = ( + "app", + "server", + "port", + "root_path", + "scheme", + "handler_kwargs", + "serializer", + ) + + def __init__( + self, + app: Litestar | None = None, + server: str = "test.org", + port: int = 3000, + root_path: str = "", + scheme: str = "http", + handler_kwargs: dict[str, Any] | None = None, + ) -> None: + """Initialize ``RequestFactory`` + + Args: + app: An instance of :class:`Litestar <litestar.app.Litestar>` to set as ``request.scope["app"]``. + server: The server's domain. + port: The server's port. + root_path: Root path for the server. + scheme: Scheme for the server. + handler_kwargs: Kwargs to pass to the route handler created for the request + + Examples: + .. code-block:: python + + from litestar import Litestar + from litestar.enums import RequestEncodingType + from litestar.testing import RequestFactory + + from tests import PersonFactory + + my_app = Litestar(route_handlers=[]) + my_server = "litestar.org" + + # Create a GET request + query_params = {"id": 1} + get_user_request = RequestFactory(app=my_app, server=my_server).get( + "/person", query_params=query_params + ) + + # Create a POST request + new_person = PersonFactory.build() + create_user_request = RequestFactory(app=my_app, server=my_server).post( + "/person", data=person + ) + + # Create a request with a special header + headers = {"header1": "value1"} + request_with_header = RequestFactory(app=my_app, server=my_server).get( + "/person", query_params=query_params, headers=headers + ) + + # Create a request with a media type + request_with_media_type = RequestFactory(app=my_app, server=my_server).post( + "/person", data=person, request_media_type=RequestEncodingType.MULTI_PART + ) + + """ + + self.app = app if app is not None else _create_default_app() + self.server = server + self.port = port + self.root_path = root_path + self.scheme = scheme + self.handler_kwargs = handler_kwargs + self.serializer = partial(default_serializer, type_encoders=self.app.type_encoders) + + def _create_scope( + self, + path: str, + http_method: HttpMethod, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> HTTPScope: + """Create the scope for the :class:`Request <litestar.connection.Request>`. + + Args: + path: The request's path. + http_method: The request's HTTP method. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A dictionary that can be passed as a scope to the :class:`Request <litestar.connection.Request>` ctor. + """ + if session is None: + session = {} + + if state is None: + state = {} + + if path_params is None: + path_params = {} + + return HTTPScope( + type=ScopeType.HTTP, + method=http_method.value, + scheme=self.scheme, + server=(self.server, self.port), + root_path=self.root_path.rstrip("/"), + path=path, + headers=[], + app=self.app, + session=session, + user=user, + auth=auth, + query_string=urlencode(query_params, doseq=True).encode() if query_params else b"", + path_params=path_params, + client=(self.server, self.port), + state=state, + asgi=ASGIVersion(spec_version="3.0", version="3.0"), + http_version=http_version or "1.1", + raw_path=path.encode("ascii"), + route_handler=route_handler + or _create_default_route_handler(http_method, self.handler_kwargs, app=self.app), + extensions={}, + ) + + @classmethod + def _create_cookie_header(cls, headers: dict[str, str], cookies: list[Cookie] | str | None = None) -> None: + """Create the cookie header and add it to the ``headers`` dictionary. + + Args: + headers: A dictionary of headers, the cookie header will be added to it. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + """ + if not cookies: + return + + if isinstance(cookies, list): + cookie_header = "; ".join(cookie.to_header(header="") for cookie in cookies) + headers[ParamType.COOKIE] = cookie_header + elif isinstance(cookies, str): + headers[ParamType.COOKIE] = cookies + + def _build_headers( + self, + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + ) -> list[tuple[bytes, bytes]]: + """Build a list of encoded headers that can be passed to the request scope. + + Args: + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + + Returns: + A list of encoded headers that can be passed to the request scope. + """ + headers = headers or {} + self._create_cookie_header(headers, cookies) + return [ + ((key.lower()).encode("latin-1", errors="ignore"), value.encode("latin-1", errors="ignore")) + for key, value in headers.items() + ] + + def _create_request_with_data( + self, + http_method: HttpMethod, + path: str, + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + files: dict[str, FileTypes] | list[tuple[str, FileTypes]] | None = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a :class:`Request <litestar.connection.Request>` instance that has body (data) + + Args: + http_method: The request's HTTP method. + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]` + auth: A value for `request.scope["auth"]` + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + files: A dictionary of files to be sent with the request. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + scope = self._create_scope( + path=path, + http_method=http_method, + session=session, + user=user, + auth=auth, + query_params=query_params, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + headers = headers or {} + body = b"" + if data: + data = json.loads(encode_json(data, serializer=get_serializer_from_scope(scope))) + + if request_media_type == RequestEncodingType.JSON: + encoding_headers, stream = httpx_encode_json(data) + elif request_media_type == RequestEncodingType.MULTI_PART: + encoding_headers, stream = encode_multipart_data( # type: ignore[assignment] + cast("dict[str, Any]", data), files=files or [], boundary=None + ) + else: + encoding_headers, stream = encode_urlencoded_data(decode_json(value=encode_json(data))) + headers.update(encoding_headers) + for chunk in stream: + body += chunk + ScopeState.from_scope(scope).body = body + self._create_cookie_header(headers, cookies) + scope["headers"] = self._build_headers(headers) + return Request(scope=scope) + + def get( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a GET :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + scope = self._create_scope( + path=path, + http_method=HttpMethod.GET, + session=session, + user=user, + auth=auth, + query_params=query_params, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + scope["headers"] = self._build_headers(headers, cookies) + return Request(scope=scope) + + def post( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a POST :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + return self._create_request_with_data( + auth=auth, + cookies=cookies, + data=data, + headers=headers, + http_method=HttpMethod.POST, + path=path, + query_params=query_params, + request_media_type=request_media_type, + session=session, + user=user, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + def put( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a PUT :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + return self._create_request_with_data( + auth=auth, + cookies=cookies, + data=data, + headers=headers, + http_method=HttpMethod.PUT, + path=path, + query_params=query_params, + request_media_type=request_media_type, + session=session, + user=user, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + def patch( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a PATCH :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + return self._create_request_with_data( + auth=auth, + cookies=cookies, + data=data, + headers=headers, + http_method=HttpMethod.PATCH, + path=path, + query_params=query_params, + request_media_type=request_media_type, + session=session, + user=user, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + def delete( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a POST :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + scope = self._create_scope( + path=path, + http_method=HttpMethod.DELETE, + session=session, + user=user, + auth=auth, + query_params=query_params, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + scope["headers"] = self._build_headers(headers, cookies) + return Request(scope=scope) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/transport.py b/venv/lib/python3.11/site-packages/litestar/testing/transport.py new file mode 100644 index 0000000..ffa76a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/transport.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from io import BytesIO +from types import GeneratorType +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, Union, cast +from urllib.parse import unquote + +from anyio import Event +from httpx import ByteStream, Response + +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.testing.websocket_test_session import WebSocketTestSession + +if TYPE_CHECKING: + from httpx import Request + + from litestar.testing.client import AsyncTestClient, TestClient + from litestar.types import ( + HTTPDisconnectEvent, + HTTPRequestEvent, + Message, + Receive, + ReceiveMessage, + Send, + WebSocketScope, + ) + + +T = TypeVar("T", bound=Union["AsyncTestClient", "TestClient"]) + + +class ConnectionUpgradeExceptionError(Exception): + def __init__(self, session: WebSocketTestSession) -> None: + self.session = session + + +class SendReceiveContext(TypedDict): + request_complete: bool + response_complete: Event + raw_kwargs: dict[str, Any] + response_started: bool + template: str | None + context: Any | None + + +class TestClientTransport(Generic[T]): + def __init__( + self, + client: T, + raise_server_exceptions: bool = True, + root_path: str = "", + ) -> None: + self.client = client + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + + @staticmethod + def create_receive(request: Request, context: SendReceiveContext) -> Receive: + async def receive() -> ReceiveMessage: + if context["request_complete"]: + if not context["response_complete"].is_set(): + await context["response_complete"].wait() + disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"} + return disconnect_event + + body = cast("Union[bytes, str, GeneratorType]", (request.read() or b"")) + request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False} + if isinstance(body, GeneratorType): # pragma: no cover + try: + chunk = body.send(None) + request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + request_event["more_body"] = True + except StopIteration: + context["request_complete"] = True + else: + context["request_complete"] = True + request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8") + return request_event + + return receive + + @staticmethod + def create_send(request: Request, context: SendReceiveContext) -> Send: + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101 + context["raw_kwargs"]["status_code"] = message["status"] + context["raw_kwargs"]["headers"] = [ + (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", []) + ] + context["response_started"] = True + elif message["type"] == "http.response.body": + assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101 + assert not context[ # noqa: S101 + "response_complete" + ].is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + context["raw_kwargs"]["stream"].write(body) + if not more_body: + context["raw_kwargs"]["stream"].seek(0) + context["response_complete"].set() + elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover + context["template"] = message["template"] # type: ignore[unreachable] + context["context"] = message["context"] + + return send + + def parse_request(self, request: Request) -> dict[str, Any]: + scheme = request.url.scheme + netloc = unquote(request.url.netloc.decode(encoding="ascii")) + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + default_port = 433 if scheme in {"https", "wss"} else 80 + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}") + + headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())] + + return { + "type": "websocket" if scheme in {"ws", "wss"} else "http", + "path": unquote(path), + "raw_path": raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ("testclient", 50000), + "server": (host, port), + } + + def handle_request(self, request: Request) -> Response: + scope = self.parse_request(request=request) + if scope["type"] == "websocket": + scope.update( + subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] + ) + session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] + raise ConnectionUpgradeExceptionError(session) + + scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) + + raw_kwargs: dict[str, Any] = {"stream": BytesIO()} + + try: + with self.client.portal() as portal: + response_complete = portal.call(Event) + context: SendReceiveContext = { + "response_complete": response_complete, + "request_complete": False, + "raw_kwargs": raw_kwargs, + "response_started": False, + "template": None, + "context": None, + } + portal.call( + self.client.app, + scope, + self.create_receive(request=request, context=context), + self.create_send(request=request, context=context), + ) + except BaseException as exc: # noqa: BLE001 + if self.raise_server_exceptions: + raise exc + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + else: + if not context["response_started"]: # pragma: no cover + if self.raise_server_exceptions: + assert context["response_started"], "TestClient did not receive any response." # noqa: S101 + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + + stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) + response = Response(**raw_kwargs, stream=stream, request=request) + setattr(response, "template", context["template"]) + setattr(response, "context", context["context"]) + return response + + async def handle_async_request(self, request: Request) -> Response: + return self.handle_request(request=request) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/websocket_test_session.py b/venv/lib/python3.11/site-packages/litestar/testing/websocket_test_session.py new file mode 100644 index 0000000..292e8a9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/websocket_test_session.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from contextlib import ExitStack +from queue import Queue +from typing import TYPE_CHECKING, Any, Literal, cast + +from anyio import sleep + +from litestar.exceptions import WebSocketDisconnect +from litestar.serialization import decode_json, decode_msgpack, encode_json, encode_msgpack +from litestar.status_codes import WS_1000_NORMAL_CLOSURE + +if TYPE_CHECKING: + from litestar.testing.client.sync_client import TestClient + from litestar.types import ( + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveMessage, + WebSocketScope, + WebSocketSendMessage, + ) + + +__all__ = ("WebSocketTestSession",) + + +class WebSocketTestSession: + exit_stack: ExitStack + + def __init__( + self, + client: TestClient[Any], + scope: WebSocketScope, + ) -> None: + self.client = client + self.scope = scope + self.accepted_subprotocol: str | None = None + self.receive_queue: Queue[WebSocketReceiveMessage] = Queue() + self.send_queue: Queue[WebSocketSendMessage | BaseException] = Queue() + self.extra_headers: list[tuple[bytes, bytes]] | None = None + + def __enter__(self) -> WebSocketTestSession: + self.exit_stack = ExitStack() + + portal = self.exit_stack.enter_context(self.client.portal()) + + try: + portal.start_task_soon(self.do_asgi_call) + event: WebSocketConnectEvent = {"type": "websocket.connect"} + self.receive_queue.put(event) + + message = self.receive(timeout=self.client.timeout.read) + self.accepted_subprotocol = cast("str | None", message.get("subprotocol", None)) + self.extra_headers = cast("list[tuple[bytes, bytes]] | None", message.get("headers", None)) + return self + except Exception: + self.exit_stack.close() + raise + + def __exit__(self, *args: Any) -> None: + try: + self.close() + finally: + self.exit_stack.close() + while not self.send_queue.empty(): + message = self.send_queue.get() + if isinstance(message, BaseException): + raise message + + async def do_asgi_call(self) -> None: + """The sub-thread in which the websocket session runs.""" + + async def receive() -> WebSocketReceiveMessage: + while self.receive_queue.empty(): + await sleep(0) + return self.receive_queue.get() + + async def send(message: WebSocketSendMessage) -> None: + if message["type"] == "websocket.accept": + headers = message.get("headers", []) + if headers: + headers_list = list(self.scope["headers"]) + headers_list.extend(headers) + self.scope["headers"] = headers_list + subprotocols = cast("str | None", message.get("subprotocols")) + if subprotocols: # pragma: no cover + self.scope["subprotocols"].append(subprotocols) + self.send_queue.put(message) + + try: + await self.client.app(self.scope, receive, send) + except BaseException as exc: + self.send_queue.put(exc) + raise + + def send(self, data: str | bytes, mode: Literal["text", "binary"] = "text", encoding: str = "utf-8") -> None: + """Sends a "receive" event. This is the inverse of the ASGI send method. + + Args: + data: Either a string or a byte string. + mode: The key to use - ``text`` or ``bytes`` + encoding: The encoding to use when encoding or decoding data. + + Returns: + None. + """ + if mode == "text": + data = data.decode(encoding) if isinstance(data, bytes) else data + text_event: WebSocketReceiveMessage = {"type": "websocket.receive", "text": data} # type: ignore[assignment] + self.receive_queue.put(text_event) + else: + data = data if isinstance(data, bytes) else data.encode(encoding) + binary_event: WebSocketReceiveMessage = {"type": "websocket.receive", "bytes": data} # type: ignore[assignment] + self.receive_queue.put(binary_event) + + def send_text(self, data: str, encoding: str = "utf-8") -> None: + """Sends the data using the ``text`` key. + + Args: + data: Data to send. + encoding: Encoding to use. + + Returns: + None + """ + self.send(data=data, encoding=encoding) + + def send_bytes(self, data: bytes, encoding: str = "utf-8") -> None: + """Sends the data using the ``bytes`` key. + + Args: + data: Data to send. + encoding: Encoding to use. + + Returns: + None + """ + self.send(data=data, mode="binary", encoding=encoding) + + def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None: + """Sends the given data as JSON. + + Args: + data: The data to send. + mode: Either ``text`` or ``binary`` + + Returns: + None + """ + self.send(encode_json(data), mode=mode) + + def send_msgpack(self, data: Any) -> None: + """Sends the given data as MessagePack. + + Args: + data: The data to send. + + Returns: + None + """ + self.send(encode_msgpack(data), mode="binary") + + def close(self, code: int = WS_1000_NORMAL_CLOSURE) -> None: + """Sends an 'websocket.disconnect' event. + + Args: + code: status code for closing the connection. + + Returns: + None. + """ + event: WebSocketDisconnectEvent = {"type": "websocket.disconnect", "code": code} + self.receive_queue.put(event) + + def receive(self, block: bool = True, timeout: float | None = None) -> WebSocketSendMessage: + """This is the base receive method. + + Args: + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Notes: + - you can use one of the other receive methods to extract the data from the message. + + Returns: + A websocket message. + """ + message = cast("WebSocketSendMessage", self.send_queue.get(block=block, timeout=timeout)) + + if isinstance(message, BaseException): + raise message + + if message["type"] == "websocket.close": + raise WebSocketDisconnect( + detail=cast("str", message.get("reason", "")), + code=message.get("code", WS_1000_NORMAL_CLOSURE), + ) + return message + + def receive_text(self, block: bool = True, timeout: float | None = None) -> str: + """Receive data in ``text`` mode and return a string + + Args: + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Returns: + A string value. + """ + message = self.receive(block=block, timeout=timeout) + return cast("str", message.get("text", "")) + + def receive_bytes(self, block: bool = True, timeout: float | None = None) -> bytes: + """Receive data in ``binary`` mode and return bytes + + Args: + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Returns: + A string value. + """ + message = self.receive(block=block, timeout=timeout) + return cast("bytes", message.get("bytes", b"")) + + def receive_json( + self, mode: Literal["text", "binary"] = "text", block: bool = True, timeout: float | None = None + ) -> Any: + """Receive data in either ``text`` or ``binary`` mode and decode it as JSON. + + Args: + mode: Either ``text`` or ``binary`` + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Returns: + An arbitrary value + """ + message = self.receive(block=block, timeout=timeout) + + if mode == "text": + return decode_json(cast("str", message.get("text", ""))) + + return decode_json(cast("bytes", message.get("bytes", b""))) + + def receive_msgpack(self, block: bool = True, timeout: float | None = None) -> Any: + message = self.receive(block=block, timeout=timeout) + return decode_msgpack(cast("bytes", message.get("bytes", b""))) diff --git a/venv/lib/python3.11/site-packages/litestar/types/__init__.py b/venv/lib/python3.11/site-packages/litestar/types/__init__.py new file mode 100644 index 0000000..90e3192 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__init__.py @@ -0,0 +1,169 @@ +from .asgi_types import ( + ASGIApp, + ASGIVersion, + BaseScope, + HTTPDisconnectEvent, + HTTPReceiveMessage, + HTTPRequestEvent, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + HTTPScope, + HTTPSendMessage, + HTTPServerPushEvent, + LifeSpanReceive, + LifeSpanReceiveMessage, + LifeSpanScope, + LifeSpanSend, + LifeSpanSendMessage, + LifeSpanShutdownCompleteEvent, + LifeSpanShutdownEvent, + LifeSpanShutdownFailedEvent, + LifeSpanStartupCompleteEvent, + LifeSpanStartupEvent, + LifeSpanStartupFailedEvent, + Message, + Method, + Receive, + ReceiveMessage, + Scope, + ScopeSession, + Send, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketReceiveMessage, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, + WebSocketScope, + WebSocketSendEvent, + WebSocketSendMessage, +) +from .builtin_types import TypedDictClass +from .callable_types import ( + AfterExceptionHookHandler, + AfterRequestHookHandler, + AfterResponseHookHandler, + AnyCallable, + AnyGenerator, + AsyncAnyCallable, + BeforeMessageSendHookHandler, + BeforeRequestHookHandler, + CacheKeyBuilder, + ExceptionHandler, + GetLogger, + Guard, + LifespanHook, + OnAppInitHandler, + OperationIDCreator, + Serializer, +) +from .composite_types import ( + Dependencies, + ExceptionHandlersMap, + Middleware, + ParametersMap, + PathType, + ResponseCookies, + ResponseHeaders, + Scopes, + TypeDecodersSequence, + TypeEncodersMap, +) +from .empty import Empty, EmptyType +from .file_types import FileInfo, FileSystemProtocol +from .helper_types import AnyIOBackend, MaybePartial, OptionalSequence, SSEData, StreamType, SyncOrAsyncUnion +from .internal_types import ControllerRouterHandler, ReservedKwargs, RouteHandlerMapItem, RouteHandlerType +from .protocols import DataclassProtocol, Logger +from .serialization import DataContainerType, LitestarEncodableType + +__all__ = ( + "ASGIApp", + "ASGIVersion", + "AfterExceptionHookHandler", + "AfterRequestHookHandler", + "AfterResponseHookHandler", + "AnyCallable", + "AnyGenerator", + "AnyIOBackend", + "AsyncAnyCallable", + "BaseScope", + "BeforeMessageSendHookHandler", + "BeforeRequestHookHandler", + "CacheKeyBuilder", + "ControllerRouterHandler", + "DataContainerType", + "DataclassProtocol", + "Dependencies", + "Empty", + "EmptyType", + "ExceptionHandler", + "ExceptionHandlersMap", + "FileInfo", + "FileSystemProtocol", + "GetLogger", + "Guard", + "HTTPDisconnectEvent", + "HTTPReceiveMessage", + "HTTPRequestEvent", + "HTTPResponseBodyEvent", + "HTTPResponseStartEvent", + "HTTPScope", + "HTTPSendMessage", + "HTTPServerPushEvent", + "LifeSpanReceive", + "LifeSpanReceiveMessage", + "LifeSpanScope", + "LifeSpanSend", + "LifeSpanSendMessage", + "LifeSpanShutdownCompleteEvent", + "LifeSpanShutdownEvent", + "LifeSpanShutdownFailedEvent", + "LifeSpanStartupCompleteEvent", + "LifeSpanStartupEvent", + "LifeSpanStartupFailedEvent", + "LifespanHook", + "LitestarEncodableType", + "Logger", + "MaybePartial", + "Message", + "Method", + "Middleware", + "OnAppInitHandler", + "OperationIDCreator", + "OptionalSequence", + "ParametersMap", + "PathType", + "Receive", + "ReceiveMessage", + "ReservedKwargs", + "ResponseCookies", + "ResponseHeaders", + "RouteHandlerMapItem", + "RouteHandlerType", + "Scope", + "ScopeSession", + "Scopes", + "Send", + "Serializer", + "StreamType", + "SSEData", + "SyncOrAsyncUnion", + "TypeDecodersSequence", + "TypeEncodersMap", + "TypedDictClass", + "WebSocketAcceptEvent", + "WebSocketCloseEvent", + "WebSocketConnectEvent", + "WebSocketDisconnectEvent", + "WebSocketReceiveEvent", + "WebSocketReceiveMessage", + "WebSocketReceiveMessage", + "WebSocketResponseBodyEvent", + "WebSocketResponseStartEvent", + "WebSocketScope", + "WebSocketSendEvent", + "WebSocketSendMessage", + "WebSocketSendMessage", +) diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6f64e92 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/asgi_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/asgi_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..cbed524 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/asgi_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/builtin_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/builtin_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..42a633c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/builtin_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/callable_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/callable_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e956ed6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/callable_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/composite_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/composite_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4761e1c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/composite_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/empty.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/empty.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1ce4b48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/empty.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/file_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/file_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c2a5ffd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/file_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/helper_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/helper_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e0ee6eb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/helper_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/internal_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/internal_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8d0ac7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/internal_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/protocols.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/protocols.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1caf78a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/protocols.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/__pycache__/serialization.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/serialization.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3bd8dc5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/__pycache__/serialization.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/types/asgi_types.py b/venv/lib/python3.11/site-packages/litestar/types/asgi_types.py new file mode 100644 index 0000000..0a6f77f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/asgi_types.py @@ -0,0 +1,343 @@ +"""Includes code adapted from https://github.com/django/asgiref/blob/main/asgiref/typing.py. + +Copyright (c) Django Software Foundation and individual contributors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of Django nor the names of its contributors may be used + to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Literal, + Tuple, + TypedDict, + Union, +) + +from litestar.enums import HttpMethod + +__all__ = ( + "ASGIApp", + "ASGIVersion", + "BaseScope", + "HeaderScope", + "HTTPDisconnectEvent", + "HTTPReceiveMessage", + "HTTPRequestEvent", + "HTTPResponseBodyEvent", + "HTTPResponseStartEvent", + "HTTPScope", + "HTTPSendMessage", + "HTTPServerPushEvent", + "LifeSpanReceive", + "LifeSpanReceiveMessage", + "LifeSpanScope", + "LifeSpanSend", + "LifeSpanSendMessage", + "LifeSpanShutdownCompleteEvent", + "LifeSpanShutdownEvent", + "LifeSpanShutdownFailedEvent", + "LifeSpanStartupCompleteEvent", + "LifeSpanStartupEvent", + "LifeSpanStartupFailedEvent", + "Message", + "Method", + "RawHeaders", + "RawHeadersList", + "Receive", + "ReceiveMessage", + "Scope", + "ScopeSession", + "Send", + "WebSocketAcceptEvent", + "WebSocketCloseEvent", + "WebSocketConnectEvent", + "WebSocketDisconnectEvent", + "WebSocketMode", + "WebSocketReceiveEvent", + "WebSocketReceiveMessage", + "WebSocketResponseBodyEvent", + "WebSocketResponseStartEvent", + "WebSocketScope", + "WebSocketSendEvent", + "WebSocketSendMessage", +) + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from litestar.app import Litestar + from litestar.enums import ScopeType + from litestar.types.empty import EmptyType + + from .internal_types import RouteHandlerType + from .serialization import DataContainerType + +Method: TypeAlias = Union[Literal["GET", "POST", "DELETE", "PATCH", "PUT", "HEAD", "TRACE", "OPTIONS"], HttpMethod] +ScopeSession: TypeAlias = "EmptyType | Dict[str, Any] | DataContainerType | None" + + +class ASGIVersion(TypedDict): + """ASGI spec version.""" + + spec_version: str + version: Literal["3.0"] + + +class HeaderScope(TypedDict): + """Base class for ASGI-scopes that supports headers.""" + + headers: RawHeaders + + +class BaseScope(HeaderScope): + """Base ASGI-scope.""" + + app: Litestar + asgi: ASGIVersion + auth: Any + client: tuple[str, int] | None + extensions: dict[str, dict[object, object]] | None + http_version: str + path: str + path_params: dict[str, str] + query_string: bytes + raw_path: bytes + root_path: str + route_handler: RouteHandlerType + scheme: str + server: tuple[str, int | None] | None + session: ScopeSession + state: dict[str, Any] + user: Any + + +class HTTPScope(BaseScope): + """HTTP-ASGI-scope.""" + + method: Method + type: Literal[ScopeType.HTTP] + + +class WebSocketScope(BaseScope): + """WebSocket-ASGI-scope.""" + + subprotocols: list[str] + type: Literal[ScopeType.WEBSOCKET] + + +class LifeSpanScope(TypedDict): + """Lifespan-ASGI-scope.""" + + app: Litestar + asgi: ASGIVersion + type: Literal["lifespan"] + + +class HTTPRequestEvent(TypedDict): + """ASGI `http.request` event.""" + + type: Literal["http.request"] + body: bytes + more_body: bool + + +class HTTPResponseStartEvent(HeaderScope): + """ASGI `http.response.start` event.""" + + type: Literal["http.response.start"] + status: int + + +class HTTPResponseBodyEvent(TypedDict): + """ASGI `http.response.body` event.""" + + type: Literal["http.response.body"] + body: bytes + more_body: bool + + +class HTTPServerPushEvent(HeaderScope): + """ASGI `http.response.push` event.""" + + type: Literal["http.response.push"] + path: str + + +class HTTPDisconnectEvent(TypedDict): + """ASGI `http.disconnect` event.""" + + type: Literal["http.disconnect"] + + +class WebSocketConnectEvent(TypedDict): + """ASGI `websocket.connect` event.""" + + type: Literal["websocket.connect"] + + +class WebSocketAcceptEvent(HeaderScope): + """ASGI `websocket.accept` event.""" + + type: Literal["websocket.accept"] + subprotocol: str | None + + +class WebSocketReceiveEvent(TypedDict): + """ASGI `websocket.receive` event.""" + + type: Literal["websocket.receive"] + bytes: bytes | None + text: str | None + + +class WebSocketSendEvent(TypedDict): + """ASGI `websocket.send` event.""" + + type: Literal["websocket.send"] + bytes: bytes | None + text: str | None + + +class WebSocketResponseStartEvent(HeaderScope): + """ASGI `websocket.http.response.start` event.""" + + type: Literal["websocket.http.response.start"] + status: int + + +class WebSocketResponseBodyEvent(TypedDict): + """ASGI `websocket.http.response.body` event.""" + + type: Literal["websocket.http.response.body"] + body: bytes + more_body: bool + + +class WebSocketDisconnectEvent(TypedDict): + """ASGI `websocket.disconnect` event.""" + + type: Literal["websocket.disconnect"] + code: int + + +class WebSocketCloseEvent(TypedDict): + """ASGI `websocket.close` event.""" + + type: Literal["websocket.close"] + code: int + reason: str | None + + +class LifeSpanStartupEvent(TypedDict): + """ASGI `lifespan.startup` event.""" + + type: Literal["lifespan.startup"] + + +class LifeSpanShutdownEvent(TypedDict): + """ASGI `lifespan.shutdown` event.""" + + type: Literal["lifespan.shutdown"] + + +class LifeSpanStartupCompleteEvent(TypedDict): + """ASGI `lifespan.startup.complete` event.""" + + type: Literal["lifespan.startup.complete"] + + +class LifeSpanStartupFailedEvent(TypedDict): + """ASGI `lifespan.startup.failed` event.""" + + type: Literal["lifespan.startup.failed"] + message: str + + +class LifeSpanShutdownCompleteEvent(TypedDict): + """ASGI `lifespan.shutdown.complete` event.""" + + type: Literal["lifespan.shutdown.complete"] + + +class LifeSpanShutdownFailedEvent(TypedDict): + """ASGI `lifespan.shutdown.failed` event.""" + + type: Literal["lifespan.shutdown.failed"] + message: str + + +HTTPReceiveMessage: TypeAlias = Union[ + HTTPRequestEvent, + HTTPDisconnectEvent, +] +WebSocketReceiveMessage: TypeAlias = Union[ + WebSocketConnectEvent, + WebSocketReceiveEvent, + WebSocketDisconnectEvent, +] +LifeSpanReceiveMessage: TypeAlias = Union[ + LifeSpanStartupEvent, + LifeSpanShutdownEvent, +] +HTTPSendMessage: TypeAlias = Union[ + HTTPResponseStartEvent, + HTTPResponseBodyEvent, + HTTPServerPushEvent, + HTTPDisconnectEvent, +] +WebSocketSendMessage: TypeAlias = Union[ + WebSocketAcceptEvent, + WebSocketSendEvent, + WebSocketResponseStartEvent, + WebSocketResponseBodyEvent, + WebSocketCloseEvent, +] +LifeSpanSendMessage: TypeAlias = Union[ + LifeSpanStartupCompleteEvent, + LifeSpanStartupFailedEvent, + LifeSpanShutdownCompleteEvent, + LifeSpanShutdownFailedEvent, +] +LifeSpanReceive: TypeAlias = Callable[..., Awaitable[LifeSpanReceiveMessage]] +LifeSpanSend: TypeAlias = Callable[[LifeSpanSendMessage], Awaitable[None]] +Message: TypeAlias = Union[HTTPSendMessage, WebSocketSendMessage] +ReceiveMessage: TypeAlias = Union[HTTPReceiveMessage, WebSocketReceiveMessage] +Scope: TypeAlias = Union[HTTPScope, WebSocketScope] +Receive: TypeAlias = Callable[..., Awaitable[Union[HTTPReceiveMessage, WebSocketReceiveMessage]]] +Send: TypeAlias = Callable[[Message], Awaitable[None]] +ASGIApp: TypeAlias = Callable[[Scope, Receive, Send], Awaitable[None]] +RawHeaders: TypeAlias = Iterable[Tuple[bytes, bytes]] +RawHeadersList: TypeAlias = List[Tuple[bytes, bytes]] +WebSocketMode: TypeAlias = Literal["text", "binary"] diff --git a/venv/lib/python3.11/site-packages/litestar/types/builtin_types.py b/venv/lib/python3.11/site-packages/litestar/types/builtin_types.py new file mode 100644 index 0000000..335dedd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/builtin_types.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Type, Union + +from typing_extensions import _TypedDictMeta # type: ignore[attr-defined] + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +__all__ = ( + "NoneType", + "UnionType", + "UnionTypes", + "TypedDictClass", +) + +NoneType: type[None] = type(None) + +try: + from types import UnionType # type: ignore[attr-defined] +except ImportError: + UnionType: TypeAlias = Union # type: ignore[no-redef] + +UnionTypes = {UnionType, Union} +TypedDictClass: TypeAlias = Type[_TypedDictMeta] diff --git a/venv/lib/python3.11/site-packages/litestar/types/callable_types.py b/venv/lib/python3.11/site-packages/litestar/types/callable_types.py new file mode 100644 index 0000000..36055d7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/callable_types.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from litestar.app import Litestar + from litestar.config.app import AppConfig + from litestar.connection.base import ASGIConnection + from litestar.connection.request import Request + from litestar.handlers.base import BaseRouteHandler + from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.response.base import Response + from litestar.types.asgi_types import ASGIApp, Message, Method, Scope + from litestar.types.helper_types import SyncOrAsyncUnion + from litestar.types.internal_types import PathParameterDefinition + from litestar.types.protocols import Logger + +ExceptionT = TypeVar("ExceptionT", bound=Exception) + +AfterExceptionHookHandler: TypeAlias = "Callable[[ExceptionT, Scope], SyncOrAsyncUnion[None]]" +AfterRequestHookHandler: TypeAlias = ( + "Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]" +) +AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" +AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]] +AnyCallable: TypeAlias = Callable[..., Any] +AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]" +BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]" +BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" +CacheKeyBuilder: TypeAlias = "Callable[[Request], str]" +ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]" +ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]" +GetLogger: TypeAlias = "Callable[..., Logger]" +Guard: TypeAlias = "Callable[[ASGIConnection, BaseRouteHandler], SyncOrAsyncUnion[None]]" +LifespanHook: TypeAlias = "Callable[[Litestar], SyncOrAsyncUnion[Any]] | Callable[[], SyncOrAsyncUnion[Any]]" +OnAppInitHandler: TypeAlias = "Callable[[AppConfig], AppConfig]" +OperationIDCreator: TypeAlias = "Callable[[HTTPRouteHandler, Method, list[str | PathParameterDefinition]], str]" +Serializer: TypeAlias = Callable[[Any], Any] diff --git a/venv/lib/python3.11/site-packages/litestar/types/composite_types.py b/venv/lib/python3.11/site-packages/litestar/types/composite_types.py new file mode 100644 index 0000000..afd905f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/composite_types.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + Literal, + Mapping, + MutableMapping, + Sequence, + Tuple, + Type, + Union, +) + +__all__ = ( + "Dependencies", + "ExceptionHandlersMap", + "Middleware", + "ParametersMap", + "PathType", + "ResponseCookies", + "ResponseHeaders", + "Scopes", + "TypeEncodersMap", +) + + +if TYPE_CHECKING: + from os import PathLike + from pathlib import Path + + from typing_extensions import TypeAlias + + from litestar.datastructures.cookie import Cookie + from litestar.datastructures.response_header import ResponseHeader + from litestar.di import Provide + from litestar.enums import ScopeType + from litestar.middleware.base import DefineMiddleware, MiddlewareProtocol + from litestar.params import ParameterKwarg + + from .asgi_types import ASGIApp + from .callable_types import AnyCallable, ExceptionHandler + +Dependencies: TypeAlias = "Mapping[str, Union[Provide, AnyCallable]]" +ExceptionHandlersMap: TypeAlias = "MutableMapping[Union[int, Type[Exception]], ExceptionHandler]" +Middleware: TypeAlias = "Union[Callable[..., ASGIApp], DefineMiddleware, Iterator[Tuple[ASGIApp, Dict[str, Any]]], Type[MiddlewareProtocol]]" +ParametersMap: TypeAlias = "Mapping[str, ParameterKwarg]" +PathType: TypeAlias = "Union[Path, PathLike, str]" +ResponseCookies: TypeAlias = "Union[Sequence[Cookie], Mapping[str, str]]" +ResponseHeaders: TypeAlias = "Union[Sequence[ResponseHeader], Mapping[str, str]]" +Scopes: TypeAlias = "set[Literal[ScopeType.HTTP, ScopeType.WEBSOCKET]]" +TypeDecodersSequence: TypeAlias = "Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]" +TypeEncodersMap: TypeAlias = "Mapping[Any, Callable[[Any], Any]]" diff --git a/venv/lib/python3.11/site-packages/litestar/types/empty.py b/venv/lib/python3.11/site-packages/litestar/types/empty.py new file mode 100644 index 0000000..dee9bc1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/empty.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +__all__ = ("Empty", "EmptyType") + +from enum import Enum +from typing import Final, Literal + + +class _EmptyEnum(Enum): + """A sentinel enum used as placeholder.""" + + EMPTY = 0 + + +EmptyType = Literal[_EmptyEnum.EMPTY] +Empty: Final = _EmptyEnum.EMPTY diff --git a/venv/lib/python3.11/site-packages/litestar/types/file_types.py b/venv/lib/python3.11/site-packages/litestar/types/file_types.py new file mode 100644 index 0000000..7848721 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/file_types.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + Awaitable, + Literal, + Protocol, + TypedDict, + overload, +) + +__all__ = ("FileInfo", "FileSystemProtocol") + + +if TYPE_CHECKING: + from _typeshed import OpenBinaryMode, OpenTextMode + from anyio import AsyncFile + from typing_extensions import NotRequired + + from litestar.types.composite_types import PathType + + +class FileInfo(TypedDict): + """File information gathered from a file system.""" + + created: float + """Created time stamp, equal to 'stat_result.st_ctime'.""" + destination: NotRequired[bytes | None] + """Output of loading a symbolic link.""" + gid: int + """Group ID of owner.""" + ino: int + """inode value.""" + islink: bool + """True if the file is a symbolic link.""" + mode: int + """Protection mode.""" + mtime: float + """Modified time stamp.""" + name: str + """The path of the file.""" + nlink: int + """Number of hard links.""" + size: int + """Total size, in bytes.""" + type: Literal["file", "directory", "other"] + """The type of the file system object.""" + uid: int + """User ID of owner.""" + + +class FileSystemProtocol(Protocol): + """Base protocol used to interact with a file-system. + + This protocol is commensurable with the file systems + exported by the `fsspec <https://filesystem-spec.readthedocs.io/en/latest/>` library. + """ + + def info(self, path: PathType, **kwargs: Any) -> FileInfo | Awaitable[FileInfo]: + """Retrieve information about a given file path. + + Args: + path: A file path. + **kwargs: Any additional kwargs. + + Returns: + A dictionary of file info. + """ + ... + + @overload + def open( + self, + file: PathType, + mode: OpenBinaryMode, + buffering: int = -1, + ) -> IO[bytes] | Awaitable[AsyncFile[bytes]]: ... + + @overload + def open( + self, + file: PathType, + mode: OpenTextMode, + buffering: int = -1, + ) -> IO[str] | Awaitable[AsyncFile[str]]: ... + + def open( # pyright: ignore + self, + file: PathType, + mode: str, + buffering: int = -1, + ) -> IO[AnyStr] | Awaitable[AsyncFile[AnyStr]]: + """Return a file-like object from the filesystem. + + Notes: + - The return value must function correctly in a context ``with`` block. + + Args: + file: Path to the target file. + mode: Mode, similar to the built ``open``. + buffering: Buffer size. + """ + ... diff --git a/venv/lib/python3.11/site-packages/litestar/types/helper_types.py b/venv/lib/python3.11/site-packages/litestar/types/helper_types.py new file mode 100644 index 0000000..588ae54 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/helper_types.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Dict, + Iterable, + Iterator, + Literal, + Optional, + Sequence, + TypeVar, + Union, +) + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from litestar.response.sse import ServerSentEventMessage + + +T = TypeVar("T") + +__all__ = ("OptionalSequence", "SyncOrAsyncUnion", "AnyIOBackend", "StreamType", "MaybePartial", "SSEData") + +OptionalSequence: TypeAlias = Optional[Sequence[T]] +"""Types 'T' as union of Sequence[T] and None.""" + +SyncOrAsyncUnion: TypeAlias = Union[T, Awaitable[T]] +"""Types 'T' as a union of T and awaitable T.""" + + +AnyIOBackend: TypeAlias = Literal["asyncio", "trio"] +"""Anyio backend names.""" + +StreamType: TypeAlias = Union[Iterable[T], Iterator[T], AsyncIterable[T], AsyncIterator[T]] +"""A stream type.""" + +MaybePartial: TypeAlias = Union[T, partial] +"""A potentially partial callable.""" + +SSEData: TypeAlias = Union[int, str, bytes, Dict[str, Any], "ServerSentEventMessage"] +"""A type alias for SSE data.""" diff --git a/venv/lib/python3.11/site-packages/litestar/types/internal_types.py b/venv/lib/python3.11/site-packages/litestar/types/internal_types.py new file mode 100644 index 0000000..d473c22 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/internal_types.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, NamedTuple + +from litestar.utils.deprecation import warn_deprecation + +__all__ = ( + "ControllerRouterHandler", + "PathParameterDefinition", + "PathParameterDefinition", + "ReservedKwargs", + "RouteHandlerMapItem", + "RouteHandlerType", +) + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from litestar.app import Litestar + from litestar.controller import Controller + from litestar.handlers.asgi_handlers import ASGIRouteHandler + from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.handlers.websocket_handlers import WebsocketRouteHandler + from litestar.router import Router + from litestar.template import TemplateConfig + from litestar.template.config import EngineType + from litestar.types import Method + +ReservedKwargs: TypeAlias = Literal["request", "socket", "headers", "query", "cookies", "state", "data"] +RouteHandlerType: TypeAlias = "HTTPRouteHandler | WebsocketRouteHandler | ASGIRouteHandler" +ControllerRouterHandler: TypeAlias = "type[Controller] | RouteHandlerType | Router | Callable[..., Any]" +RouteHandlerMapItem: TypeAlias = 'dict[Method | Literal["websocket", "asgi"], RouteHandlerType]' +TemplateConfigType: TypeAlias = "TemplateConfig[EngineType]" + +# deprecated +_LitestarType: TypeAlias = "Litestar" + + +class PathParameterDefinition(NamedTuple): + """Path parameter tuple.""" + + name: str + full: str + type: type + parser: Callable[[str], Any] | None + + +def __getattr__(name: str) -> Any: + if name == "LitestarType": + warn_deprecation( + "2.2.1", + "LitestarType", + "import", + removal_in="3.0.0", + alternative="Litestar", + ) + return _LitestarType + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/venv/lib/python3.11/site-packages/litestar/types/protocols.py b/venv/lib/python3.11/site-packages/litestar/types/protocols.py new file mode 100644 index 0000000..4e509ee --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/protocols.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from typing import Any, ClassVar, Collection, Iterable, Protocol, TypeVar, runtime_checkable + +__all__ = ( + "DataclassProtocol", + "InstantiableCollection", + "Logger", +) + + +class Logger(Protocol): + """Logger protocol.""" + + def debug(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'DEBUG' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def info(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'INFO' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def warning(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'WARNING' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def warn(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'WARN' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def error(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'ERROR' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def fatal(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'FATAL' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def exception(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Log a message with level 'ERROR' on this logger. The arguments are interpreted as for debug(). Exception info + is added to the logging message. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def critical(self, event: str, *args: Any, **kwargs: Any) -> Any: + """Output a log message at 'INFO' level. + + Args: + event: Log message. + *args: Any args. + **kwargs: Any kwargs. + """ + + def setLevel(self, level: int) -> None: # noqa: N802 + """Set the log level + + Args: + level: Log level to set as an integer + + Returns: + None + """ + + +@runtime_checkable +class DataclassProtocol(Protocol): + """Protocol for instance checking dataclasses""" + + __dataclass_fields__: ClassVar[dict[str, Any]] + + +T_co = TypeVar("T_co", covariant=True) + + +@runtime_checkable +class InstantiableCollection(Collection[T_co], Protocol[T_co]): # pyright: ignore + """A protocol for instantiable collection types.""" + + def __init__(self, iterable: Iterable[T_co], /) -> None: ... diff --git a/venv/lib/python3.11/site-packages/litestar/types/serialization.py b/venv/lib/python3.11/site-packages/litestar/types/serialization.py new file mode 100644 index 0000000..0f61e10 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/types/serialization.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections import deque + from collections.abc import Collection + from datetime import date, datetime, time + from decimal import Decimal + from enum import Enum, IntEnum + from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, + ) + from pathlib import Path, PurePath + from re import Pattern + from uuid import UUID + + from msgspec import Raw, Struct + from msgspec.msgpack import Ext + from typing_extensions import TypeAlias + + from litestar.types import DataclassProtocol, TypedDictClass + + try: + from pydantic import BaseModel + except ImportError: + BaseModel = Any # type: ignore[assignment, misc] + + try: + from attrs import AttrsInstance + except ImportError: + AttrsInstance = Any # type: ignore[assignment, misc] + +__all__ = ( + "LitestarEncodableType", + "EncodableBuiltinType", + "EncodableBuiltinCollectionType", + "EncodableStdLibType", + "EncodableStdLibIPType", + "EncodableMsgSpecType", + "DataContainerType", +) + +EncodableBuiltinType: TypeAlias = "None | bool | int | float | str | bytes | bytearray" +EncodableBuiltinCollectionType: TypeAlias = "list | tuple | set | frozenset | dict | Collection" +EncodableStdLibType: TypeAlias = ( + "date | datetime | deque | time | UUID | Decimal | Enum | IntEnum | DataclassProtocol | Path | PurePath | Pattern" +) +EncodableStdLibIPType: TypeAlias = ( + "IPv4Address | IPv4Interface | IPv4Network | IPv6Address | IPv6Interface | IPv6Network" +) +EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct" +LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | BaseModel | AttrsInstance" # pyright: ignore +DataContainerType: TypeAlias = "Struct | BaseModel | AttrsInstance | TypedDictClass | DataclassProtocol" # pyright: ignore diff --git a/venv/lib/python3.11/site-packages/litestar/typing.py b/venv/lib/python3.11/site-packages/litestar/typing.py new file mode 100644 index 0000000..3a27557 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/typing.py @@ -0,0 +1,636 @@ +from __future__ import annotations + +from collections import abc, deque +from copy import deepcopy +from dataclasses import dataclass, is_dataclass, replace +from inspect import Parameter, Signature +from typing import ( + Any, + AnyStr, + Callable, + Collection, + ForwardRef, + Literal, + Mapping, + Protocol, + Sequence, + TypeVar, + cast, +) + +from msgspec import UnsetType +from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.openapi.spec import Example +from litestar.params import BodyKwarg, DependencyKwarg, KwargDefinition, ParameterKwarg +from litestar.types import Empty +from litestar.types.builtin_types import NoneType, UnionTypes +from litestar.utils.predicates import ( + is_annotated_type, + is_any, + is_class_and_subclass, + is_generic, + is_non_string_iterable, + is_non_string_sequence, + is_union, +) +from litestar.utils.typing import ( + get_instantiable_origin, + get_safe_generic_origin, + get_type_hints_with_generics_resolved, + make_non_optional_union, + unwrap_annotation, +) + +__all__ = ("FieldDefinition",) + +T = TypeVar("T", bound=KwargDefinition) + + +class _KwargMetaExtractor(Protocol): + @staticmethod + def matches(annotation: Any, name: str | None, default: Any) -> bool: ... + + @staticmethod + def extract(annotation: Any, default: Any) -> Any: ... + + +_KWARG_META_EXTRACTORS: set[_KwargMetaExtractor] = set() + + +def _unpack_predicate(value: Any) -> dict[str, Any]: + try: + from annotated_types import Predicate + + if isinstance(value, Predicate): + if value.func == str.islower: + return {"lower_case": True} + if value.func == str.isupper: + return {"upper_case": True} + if value.func == str.isascii: + return {"pattern": "[[:ascii:]]"} + if value.func == str.isdigit: + return {"pattern": "[[:digit:]]"} + except ImportError: + pass + + return {} + + +def _parse_metadata(value: Any, is_sequence_container: bool, extra: dict[str, Any] | None) -> dict[str, Any]: + """Parse metadata from a value. + + Args: + value: A metadata value from annotation, namely anything stored under Annotated[x, metadata...] + is_sequence_container: Whether the type is a sequence container (list, tuple etc...) + extra: Extra key values to parse. + + Returns: + A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. + """ + extra = { + **cast("dict[str, Any]", extra or getattr(value, "extra", None) or {}), + **(getattr(value, "json_schema_extra", None) or {}), + } + example_list: list[Any] | None + if example := extra.pop("example", None): + example_list = [Example(value=example)] + elif examples := getattr(value, "examples", None): + example_list = [Example(value=example) for example in cast("list[str]", examples)] + else: + example_list = None + + return { + k: v + for k, v in { + "gt": getattr(value, "gt", None), + "ge": getattr(value, "ge", None), + "lt": getattr(value, "lt", None), + "le": getattr(value, "le", None), + "multiple_of": getattr(value, "multiple_of", None), + "min_length": None if is_sequence_container else getattr(value, "min_length", None), + "max_length": None if is_sequence_container else getattr(value, "max_length", None), + "description": getattr(value, "description", None), + "examples": example_list, + "title": getattr(value, "title", None), + "lower_case": getattr(value, "to_lower", None), + "upper_case": getattr(value, "to_upper", None), + "pattern": getattr(value, "regex", getattr(value, "pattern", None)), + "min_items": getattr(value, "min_items", getattr(value, "min_length", None)) + if is_sequence_container + else None, + "max_items": getattr(value, "max_items", getattr(value, "max_length", None)) + if is_sequence_container + else None, + "const": getattr(value, "const", None) is not None, + **extra, + }.items() + if v is not None + } + + +def _traverse_metadata( + metadata: Sequence[Any], is_sequence_container: bool, extra: dict[str, Any] | None +) -> dict[str, Any]: + """Recursively traverse metadata from a value. + + Args: + metadata: A list of metadata values from annotation, namely anything stored under Annotated[x, metadata...] + is_sequence_container: Whether the container is a sequence container (list, tuple etc...) + extra: Extra key values to parse. + + Returns: + A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. + """ + constraints: dict[str, Any] = {} + for value in metadata: + if isinstance(value, (list, set, frozenset, deque)): + constraints.update( + _traverse_metadata( + metadata=cast("Sequence[Any]", value), is_sequence_container=is_sequence_container, extra=extra + ) + ) + elif is_annotated_type(value) and (type_args := [v for v in get_args(value) if v is not None]): + # annotated values can be nested inside other annotated values + # this behaviour is buggy in python 3.8, hence we need to guard here. + if len(type_args) > 1: + constraints.update( + _traverse_metadata(metadata=type_args[1:], is_sequence_container=is_sequence_container, extra=extra) + ) + elif unpacked_predicate := _unpack_predicate(value): + constraints.update(unpacked_predicate) + else: + constraints.update(_parse_metadata(value=value, is_sequence_container=is_sequence_container, extra=extra)) + return constraints + + +def _create_metadata_from_type( + metadata: Sequence[Any], model: type[T], annotation: Any, extra: dict[str, Any] | None +) -> tuple[T | None, dict[str, Any]]: + is_sequence_container = is_non_string_sequence(annotation) + result = _traverse_metadata(metadata=metadata, is_sequence_container=is_sequence_container, extra=extra) + + constraints = {k: v for k, v in result.items() if k in dir(model)} + extra = {k: v for k, v in result.items() if k not in constraints} + return model(**constraints) if constraints else None, extra + + +@dataclass(frozen=True) +class FieldDefinition: + """Represents a function parameter or type annotation.""" + + __slots__ = ( + "annotation", + "args", + "default", + "extra", + "inner_types", + "instantiable_origin", + "kwarg_definition", + "metadata", + "name", + "origin", + "raw", + "safe_generic_origin", + "type_wrappers", + ) + + raw: Any + """The annotation exactly as received.""" + annotation: Any + """The annotation with any "wrapper" types removed, e.g. Annotated.""" + type_wrappers: tuple[type, ...] + """A set of all "wrapper" types, e.g. Annotated.""" + origin: Any + """The result of calling ``get_origin(annotation)`` after unwrapping Annotated, e.g. list.""" + args: tuple[Any, ...] + """The result of calling ``get_args(annotation)`` after unwrapping Annotated, e.g. (int,).""" + metadata: tuple[Any, ...] + """Any metadata associated with the annotation via ``Annotated``.""" + instantiable_origin: Any + """An equivalent type to ``origin`` that can be safely instantiated. E.g., ``Sequence`` -> ``list``.""" + safe_generic_origin: Any + """An equivalent type to ``origin`` that can be safely used as a generic type across all supported Python versions. + + This is to serve safely rebuilding a generic outer type with different args at runtime. + """ + inner_types: tuple[FieldDefinition, ...] + """The type's generic args parsed as ``FieldDefinition``, if applicable.""" + default: Any + """Default value of the field.""" + extra: dict[str, Any] + """A mapping of extra values.""" + kwarg_definition: KwargDefinition | DependencyKwarg | None + """Kwarg Parameter.""" + name: str + """Field name.""" + + def __deepcopy__(self, memo: dict[str, Any]) -> Self: + return type(self)(**{attr: deepcopy(getattr(self, attr)) for attr in self.__slots__}) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FieldDefinition): + return False + + if self.origin: + return self.origin == other.origin and self.inner_types == other.inner_types + + return self.annotation == other.annotation # type: ignore[no-any-return] + + def __hash__(self) -> int: + return hash((self.name, self.raw, self.annotation, self.origin, self.inner_types)) + + @classmethod + def _extract_metadata( + cls, annotation: Any, name: str | None, default: Any, metadata: tuple[Any, ...], extra: dict[str, Any] | None + ) -> tuple[KwargDefinition | None, dict[str, Any]]: + model = BodyKwarg if name == "data" else ParameterKwarg + + for extractor in _KWARG_META_EXTRACTORS: + if extractor.matches(annotation=annotation, name=name, default=default): + return _create_metadata_from_type( + extractor.extract(annotation=annotation, default=default), + model=model, + annotation=annotation, + extra=extra, + ) + + if any(isinstance(arg, KwargDefinition) for arg in get_args(annotation)): + return next(arg for arg in get_args(annotation) if isinstance(arg, KwargDefinition)), extra or {} + + if metadata: + return _create_metadata_from_type(metadata=metadata, model=model, annotation=annotation, extra=extra) + + return None, {} + + @property + def has_default(self) -> bool: + """Check if the field has a default value. + + Returns: + True if the default is not Empty or Ellipsis otherwise False. + """ + return self.default is not Empty and self.default is not Ellipsis + + @property + def is_non_string_iterable(self) -> bool: + """Check if the field type is an Iterable. + + If ``self.annotation`` is an optional union, only the non-optional members of the union are evaluated. + + See: https://github.com/litestar-org/litestar/issues/1106 + """ + annotation = self.annotation + if self.is_optional: + annotation = make_non_optional_union(annotation) + return is_non_string_iterable(annotation) + + @property + def is_non_string_sequence(self) -> bool: + """Check if the field type is a non-string Sequence. + + If ``self.annotation`` is an optional union, only the non-optional members of the union are evaluated. + + See: https://github.com/litestar-org/litestar/issues/1106 + """ + annotation = self.annotation + if self.is_optional: + annotation = make_non_optional_union(annotation) + return is_non_string_sequence(annotation) + + @property + def is_any(self) -> bool: + """Check if the field type is Any.""" + return is_any(self.annotation) + + @property + def is_generic(self) -> bool: + """Check if the field type is a custom class extending Generic.""" + return is_generic(self.annotation) + + @property + def is_simple_type(self) -> bool: + """Check if the field type is a singleton value (e.g. int, str etc.).""" + return not ( + self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable + ) + + @property + def is_parameter_field(self) -> bool: + """Check if the field type is a parameter kwarg value.""" + return isinstance(self.kwarg_definition, ParameterKwarg) + + @property + def is_const(self) -> bool: + """Check if the field is defined as constant value.""" + return bool(self.kwarg_definition and getattr(self.kwarg_definition, "const", False)) + + @property + def is_required(self) -> bool: + """Check if the field should be marked as a required parameter.""" + if Required in self.type_wrappers: # type: ignore[comparison-overlap] + return True + + if NotRequired in self.type_wrappers or UnsetType in self.args: # type: ignore[comparison-overlap] + return False + + if isinstance(self.kwarg_definition, ParameterKwarg) and self.kwarg_definition.required is not None: + return self.kwarg_definition.required + + return not self.is_optional and not self.is_any and (not self.has_default or self.default is None) + + @property + def is_annotated(self) -> bool: + """Check if the field type is Annotated.""" + return bool(self.metadata) + + @property + def is_literal(self) -> bool: + """Check if the field type is Literal.""" + return self.origin is Literal + + @property + def is_forward_ref(self) -> bool: + """Whether the annotation is a forward reference or not.""" + return isinstance(self.annotation, (str, ForwardRef)) + + @property + def is_mapping(self) -> bool: + """Whether the annotation is a mapping or not.""" + return self.is_subclass_of(Mapping) + + @property + def is_tuple(self) -> bool: + """Whether the annotation is a ``tuple`` or not.""" + return self.is_subclass_of(tuple) + + @property + def is_type_var(self) -> bool: + """Whether the annotation is a TypeVar or not.""" + return isinstance(self.annotation, TypeVar) + + @property + def is_union(self) -> bool: + """Whether the annotation is a union type or not.""" + return self.origin in UnionTypes + + @property + def is_optional(self) -> bool: + """Whether the annotation is Optional or not.""" + return bool(self.is_union and NoneType in self.args) + + @property + def is_none_type(self) -> bool: + """Whether the annotation is NoneType or not.""" + return self.annotation is NoneType + + @property + def is_collection(self) -> bool: + """Whether the annotation is a collection type or not.""" + return self.is_subclass_of(Collection) + + @property + def is_non_string_collection(self) -> bool: + """Whether the annotation is a non-string collection type or not.""" + return self.is_collection and not self.is_subclass_of((str, bytes)) + + @property + def bound_types(self) -> tuple[FieldDefinition, ...] | None: + """A tuple of bound types - if the annotation is a TypeVar with bound types, otherwise None.""" + if self.is_type_var and (bound := getattr(self.annotation, "__bound__", None)): + if is_union(bound): + return tuple(FieldDefinition.from_annotation(t) for t in get_args(bound)) + return (FieldDefinition.from_annotation(bound),) + return None + + @property + def generic_types(self) -> tuple[FieldDefinition, ...] | None: + """A tuple of generic types passed into the annotation - if its generic.""" + if not (bases := getattr(self.annotation, "__orig_bases__", None)): + return None + args: list[FieldDefinition] = [] + for base_args in [getattr(base, "__args__", ()) for base in bases]: + for arg in base_args: + field_definition = FieldDefinition.from_annotation(arg) + if field_definition.generic_types: + args.extend(field_definition.generic_types) + else: + args.append(field_definition) + return tuple(args) + + @property + def is_dataclass_type(self) -> bool: + """Whether the annotation is a dataclass type or not.""" + + return is_dataclass(cast("type", self.origin or self.annotation)) + + @property + def is_typeddict_type(self) -> bool: + """Whether the type is TypedDict or not.""" + + return is_typeddict(self.origin or self.annotation) + + @property + def type_(self) -> Any: + """The type of the annotation with all the wrappers removed, including the generic types.""" + + return self.origin or self.annotation + + def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: + """Whether the annotation is a subclass of the given type. + + Where ``self.annotation`` is a union type, this method will return ``True`` when all members of the union are + a subtype of ``cl``, otherwise, ``False``. + + Args: + cl: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. + + Returns: + Whether the annotation is a subtype of the given type(s). + """ + if self.origin: + if self.origin in UnionTypes: + return all(t.is_subclass_of(cl) for t in self.inner_types) + + return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl) + + if self.annotation is AnyStr: + return is_class_and_subclass(str, cl) or is_class_and_subclass(bytes, cl) + + return self.annotation is not Any and not self.is_type_var and is_class_and_subclass(self.annotation, cl) + + def has_inner_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: + """Whether any generic args are a subclass of the given type. + + Args: + cl: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. + + Returns: + Whether any of the type's generic args are a subclass of the given type. + """ + return any(t.is_subclass_of(cl) for t in self.inner_types) + + def get_type_hints(self, *, include_extras: bool = False, resolve_generics: bool = False) -> dict[str, Any]: + """Get the type hints for the annotation. + + Args: + include_extras: Flag to indicate whether to include ``Annotated[T, ...]`` or not. + resolve_generics: Flag to indicate whether to resolve the generic types in the type hints or not. + + Returns: + The type hints. + """ + + if self.origin is not None or self.is_generic: + if resolve_generics: + return get_type_hints_with_generics_resolved(self.annotation, include_extras=include_extras) + return get_type_hints(self.origin or self.annotation, include_extras=include_extras) + + return get_type_hints(self.annotation, include_extras=include_extras) + + @classmethod + def from_annotation(cls, annotation: Any, **kwargs: Any) -> FieldDefinition: + """Initialize FieldDefinition. + + Args: + annotation: The type annotation. This should be extracted from the return of + ``get_type_hints(..., include_extras=True)`` so that forward references are resolved and recursive + ``Annotated`` types are flattened. + **kwargs: Additional keyword arguments to pass to the ``FieldDefinition`` constructor. + + Returns: + FieldDefinition + """ + + unwrapped, metadata, wrappers = unwrap_annotation(annotation if annotation is not Empty else Any) + origin = get_origin(unwrapped) + + args = () if origin is abc.Callable else get_args(unwrapped) + + if not kwargs.get("kwarg_definition"): + if isinstance(kwargs.get("default"), (KwargDefinition, DependencyKwarg)): + kwargs["kwarg_definition"] = kwargs.pop("default") + elif any(isinstance(v, (KwargDefinition, DependencyKwarg)) for v in metadata): + kwargs["kwarg_definition"] = next( # pragma: no cover + # see https://github.com/nedbat/coveragepy/issues/475 + v + for v in metadata + if isinstance(v, (KwargDefinition, DependencyKwarg)) + ) + metadata = tuple(v for v in metadata if not isinstance(v, (KwargDefinition, DependencyKwarg))) + elif (extra := kwargs.get("extra", {})) and "kwarg_definition" in extra: + kwargs["kwarg_definition"] = extra.pop("kwarg_definition") + else: + kwargs["kwarg_definition"], kwargs["extra"] = cls._extract_metadata( + annotation=annotation, + name=kwargs.get("name", ""), + default=kwargs.get("default", Empty), + metadata=metadata, + extra=kwargs.get("extra"), + ) + + kwargs.setdefault("annotation", unwrapped) + kwargs.setdefault("args", args) + kwargs.setdefault("default", Empty) + kwargs.setdefault("extra", {}) + kwargs.setdefault("inner_types", tuple(FieldDefinition.from_annotation(arg) for arg in args)) + kwargs.setdefault("instantiable_origin", get_instantiable_origin(origin, unwrapped)) + kwargs.setdefault("kwarg_definition", None) + kwargs.setdefault("metadata", metadata) + kwargs.setdefault("name", "") + kwargs.setdefault("origin", origin) + kwargs.setdefault("raw", annotation) + kwargs.setdefault("safe_generic_origin", get_safe_generic_origin(origin, unwrapped)) + kwargs.setdefault("type_wrappers", wrappers) + + instance = FieldDefinition(**kwargs) + if not instance.has_default and instance.kwarg_definition: + return replace(instance, default=instance.kwarg_definition.default) + + return instance + + @classmethod + def from_kwarg( + cls, + annotation: Any, + name: str, + default: Any = Empty, + inner_types: tuple[FieldDefinition, ...] | None = None, + kwarg_definition: KwargDefinition | DependencyKwarg | None = None, + extra: dict[str, Any] | None = None, + ) -> FieldDefinition: + """Create a new FieldDefinition instance. + + Args: + annotation: The type of the kwarg. + name: Field name. + default: A default value. + inner_types: A tuple of FieldDefinition instances representing the inner types, if any. + kwarg_definition: Kwarg Parameter. + extra: A mapping of extra values. + + Returns: + FieldDefinition instance. + """ + + return cls.from_annotation( + annotation, + name=name, + default=default, + **{ + k: v + for k, v in { + "inner_types": inner_types, + "kwarg_definition": kwarg_definition, + "extra": extra, + }.items() + if v is not None + }, + ) + + @classmethod + def from_parameter(cls, parameter: Parameter, fn_type_hints: dict[str, Any]) -> FieldDefinition: + """Initialize ParsedSignatureParameter. + + Args: + parameter: inspect.Parameter + fn_type_hints: mapping of names to types. Should be result of ``get_type_hints()``, preferably via the + :attr:``get_fn_type_hints() <.utils.signature_parsing.get_fn_type_hints>`` helper. + + Returns: + ParsedSignatureParameter. + + """ + from litestar.datastructures import ImmutableState + + try: + annotation = fn_type_hints[parameter.name] + except KeyError as e: + raise ImproperlyConfiguredException( + f"'{parameter.name}' does not have a type annotation. If it should receive any value, use 'Any'." + ) from e + + if parameter.name == "state" and not issubclass(annotation, ImmutableState): + raise ImproperlyConfiguredException( + f"The type annotation `{annotation}` is an invalid type for the 'state' reserved kwarg. " + "It must be typed to a subclass of `litestar.datastructures.ImmutableState` or " + "`litestar.datastructures.State`." + ) + + return FieldDefinition.from_kwarg( + annotation=annotation, + name=parameter.name, + default=Empty if parameter.default is Signature.empty else parameter.default, + ) + + def match_predicate_recursively(self, predicate: Callable[[FieldDefinition], bool]) -> bool: + """Recursively test the passed in predicate against the field and any of its inner fields. + + Args: + predicate: A callable that receives a field definition instance as an arg and returns a boolean. + + Returns: + A boolean. + """ + return predicate(self) or any(t.match_predicate_recursively(predicate) for t in self.inner_types) diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__init__.py b/venv/lib/python3.11/site-packages/litestar/utils/__init__.py new file mode 100644 index 0000000..3f62792 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__init__.py @@ -0,0 +1,86 @@ +from typing import Any + +from litestar.utils.deprecation import deprecated, warn_deprecation + +from .helpers import get_enum_string_value, get_name, unique_name_for_scope, url_quote +from .path import join_paths, normalize_path +from .predicates import ( + _is_sync_or_async_generator, + is_annotated_type, + is_any, + is_async_callable, + is_attrs_class, + is_class_and_subclass, + is_class_var, + is_dataclass_class, + is_dataclass_instance, + is_generic, + is_mapping, + is_non_string_iterable, + is_non_string_sequence, + is_optional_union, + is_undefined_sentinel, + is_union, +) +from .scope import ( # type: ignore[attr-defined] + _delete_litestar_scope_state, + _get_litestar_scope_state, + _set_litestar_scope_state, + get_serializer_from_scope, +) +from .sequence import find_index, unique +from .sync import AsyncIteratorWrapper, ensure_async_callable +from .typing import get_origin_or_inner_type, make_non_optional_union + +__all__ = ( + "ensure_async_callable", + "AsyncIteratorWrapper", + "deprecated", + "find_index", + "get_enum_string_value", + "get_name", + "get_origin_or_inner_type", + "get_serializer_from_scope", + "is_annotated_type", + "is_any", + "is_async_callable", + "is_attrs_class", + "is_class_and_subclass", + "is_class_var", + "is_dataclass_class", + "is_dataclass_instance", + "is_generic", + "is_mapping", + "is_non_string_iterable", + "is_non_string_sequence", + "is_optional_union", + "is_undefined_sentinel", + "is_union", + "join_paths", + "make_non_optional_union", + "normalize_path", + "unique", + "unique_name_for_scope", + "url_quote", + "warn_deprecation", +) + +_deprecated_names = { + "get_litestar_scope_state": _get_litestar_scope_state, + "set_litestar_scope_state": _set_litestar_scope_state, + "delete_litestar_scope_state": _delete_litestar_scope_state, + "is_sync_or_async_generator": _is_sync_or_async_generator, +} + + +def __getattr__(name: str) -> Any: + if name in _deprecated_names: + warn_deprecation( + deprecated_name=f"litestar.utils.{name}", + version="2.4", + kind="import", + removal_in="3.0", + info=f"'litestar.utils.{name}' is deprecated.", + ) + return globals()["_deprecated_names"][name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # pragma: no cover diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..da9ad02 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/compat.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/compat.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..39d39d6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/compat.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/dataclass.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/dataclass.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..addf140 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/dataclass.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/deprecation.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/deprecation.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ddf3b7c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/deprecation.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/empty.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/empty.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2b358ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/empty.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/helpers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/helpers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9db59c4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/helpers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/module_loader.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/module_loader.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f32ca7e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/module_loader.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/path.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/path.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6602b02 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/path.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/predicates.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/predicates.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..983250b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/predicates.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/sequence.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/sequence.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b2f8c1a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/sequence.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/signature.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/signature.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d81050f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/signature.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/sync.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/sync.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8a18dfb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/sync.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/typing.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/typing.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..53db129 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/typing.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/version.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/version.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f656ef4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/version.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/warnings.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/warnings.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3fa1cf1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/__pycache__/warnings.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/compat.py b/venv/lib/python3.11/site-packages/litestar/utils/compat.py new file mode 100644 index 0000000..384db76 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/compat.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from litestar.types import Empty, EmptyType + +__all__ = ("async_next",) + + +if TYPE_CHECKING: + from typing import Any, AsyncGenerator + +T = TypeVar("T") +D = TypeVar("D") + +try: + async_next = anext # type: ignore[name-defined] +except NameError: + + async def async_next(gen: AsyncGenerator[T, Any], default: D | EmptyType = Empty) -> T | D: + """Backwards compatibility shim for Python<3.10.""" + try: + return await gen.__anext__() + except StopAsyncIteration as exc: + if default is not Empty: + return default + raise exc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/dataclass.py b/venv/lib/python3.11/site-packages/litestar/utils/dataclass.py new file mode 100644 index 0000000..597465d --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/dataclass.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from dataclasses import Field, fields +from typing import TYPE_CHECKING + +from litestar.types import Empty +from litestar.utils.predicates import is_dataclass_instance + +if TYPE_CHECKING: + from typing import AbstractSet, Any, Iterable + + from litestar.types.protocols import DataclassProtocol + +__all__ = ( + "extract_dataclass_fields", + "extract_dataclass_items", + "simple_asdict", +) + + +def extract_dataclass_fields( + dt: DataclassProtocol, + exclude_none: bool = False, + exclude_empty: bool = False, + include: AbstractSet[str] | None = None, + exclude: AbstractSet[str] | None = None, +) -> tuple[Field[Any], ...]: + """Extract dataclass fields. + + Args: + dt: A dataclass instance. + exclude_none: Whether to exclude None values. + exclude_empty: Whether to exclude Empty values. + include: An iterable of fields to include. + exclude: An iterable of fields to exclude. + + + Returns: + A tuple of dataclass fields. + """ + include = include or set() + exclude = exclude or set() + + if common := (include & exclude): + raise ValueError(f"Fields {common} are both included and excluded.") + + dataclass_fields: Iterable[Field[Any]] = fields(dt) + if exclude_none: + dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not None) + if exclude_empty: + dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not Empty) + if include: + dataclass_fields = (field for field in dataclass_fields if field.name in include) + if exclude: + dataclass_fields = (field for field in dataclass_fields if field.name not in exclude) + + return tuple(dataclass_fields) + + +def extract_dataclass_items( + dt: DataclassProtocol, + exclude_none: bool = False, + exclude_empty: bool = False, + include: AbstractSet[str] | None = None, + exclude: AbstractSet[str] | None = None, +) -> tuple[tuple[str, Any], ...]: + """Extract dataclass name, value pairs. + + Unlike the 'asdict' method exports by the stlib, this function does not pickle values. + + Args: + dt: A dataclass instance. + exclude_none: Whether to exclude None values. + exclude_empty: Whether to exclude Empty values. + include: An iterable of fields to include. + exclude: An iterable of fields to exclude. + + Returns: + A tuple of key/value pairs. + """ + dataclass_fields = extract_dataclass_fields(dt, exclude_none, exclude_empty, include, exclude) + return tuple((field.name, getattr(dt, field.name)) for field in dataclass_fields) + + +def simple_asdict( + obj: DataclassProtocol, + exclude_none: bool = False, + exclude_empty: bool = False, + convert_nested: bool = True, + exclude: set[str] | None = None, +) -> dict[str, Any]: + """Convert a dataclass to a dictionary. + + This method has important differences to the standard library version: + - it does not deepcopy values + - it does not recurse into collections + + Args: + obj: A dataclass instance. + exclude_none: Whether to exclude None values. + exclude_empty: Whether to exclude Empty values. + convert_nested: Whether to recursively convert nested dataclasses. + exclude: An iterable of fields to exclude. + + Returns: + A dictionary of key/value pairs. + """ + ret = {} + for field in extract_dataclass_fields(obj, exclude_none, exclude_empty, exclude=exclude): + value = getattr(obj, field.name) + if is_dataclass_instance(value) and convert_nested: + ret[field.name] = simple_asdict(value, exclude_none, exclude_empty) + else: + ret[field.name] = getattr(obj, field.name) + return ret diff --git a/venv/lib/python3.11/site-packages/litestar/utils/deprecation.py b/venv/lib/python3.11/site-packages/litestar/utils/deprecation.py new file mode 100644 index 0000000..b1b8725 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/deprecation.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import inspect +from functools import wraps +from typing import Callable, Literal, TypeVar +from warnings import warn + +from typing_extensions import ParamSpec + +__all__ = ("deprecated", "warn_deprecation") + + +T = TypeVar("T") +P = ParamSpec("P") +DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"] + + +def warn_deprecation( + version: str, + deprecated_name: str, + kind: DeprecatedKind, + *, + removal_in: str | None = None, + alternative: str | None = None, + info: str | None = None, + pending: bool = False, +) -> None: + """Warn about a call to a (soon to be) deprecated function. + + Args: + version: Litestar version where the deprecation will occur + deprecated_name: Name of the deprecated function + removal_in: Litestar version where the deprecated function will be removed + alternative: Name of a function that should be used instead + info: Additional information + pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning`` + kind: Type of the deprecated thing + """ + parts = [] + + if kind == "import": + access_type = "Import of" + elif kind in {"function", "method"}: + access_type = "Call to" + else: + access_type = "Use of" + + if pending: + parts.append(f"{access_type} {kind} awaiting deprecation {deprecated_name!r}") + else: + parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") + + parts.extend( + ( + f"Deprecated in litestar {version}", + f"This {kind} will be removed in {removal_in or 'the next major version'}", + ) + ) + if alternative: + parts.append(f"Use {alternative!r} instead") + + if info: + parts.append(info) + + text = ". ".join(parts) + warning_class = PendingDeprecationWarning if pending else DeprecationWarning + + warn(text, warning_class, stacklevel=2) + + +def deprecated( + version: str, + *, + removal_in: str | None = None, + alternative: str | None = None, + info: str | None = None, + pending: bool = False, + kind: Literal["function", "method", "classmethod", "property"] | None = None, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """Create a decorator wrapping a function, method or property with a warning call about a (pending) deprecation. + + Args: + version: Litestar version where the deprecation will occur + removal_in: Litestar version where the deprecated function will be removed + alternative: Name of a function that should be used instead + info: Additional information + pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning`` + kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure + out if it's a function or method + + Returns: + A decorator wrapping the function call with a warning + """ + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecation( + version=version, + deprecated_name=func.__name__, + info=info, + alternative=alternative, + pending=pending, + removal_in=removal_in, + kind=kind or ("method" if inspect.ismethod(func) else "function"), + ) + return func(*args, **kwargs) + + return wrapped + + return decorator diff --git a/venv/lib/python3.11/site-packages/litestar/utils/empty.py b/venv/lib/python3.11/site-packages/litestar/utils/empty.py new file mode 100644 index 0000000..cdde871 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/empty.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from litestar.types.empty import Empty + +if TYPE_CHECKING: + from litestar.types.empty import EmptyType + +ValueT = TypeVar("ValueT") +DefaultT = TypeVar("DefaultT") + + +def value_or_default(value: ValueT | EmptyType, default: DefaultT) -> ValueT | DefaultT: + """Return `value` handling the case where it is empty. + + If `value` is `Empty`, `default` is returned. + + Args: + value: The value to check. + default: The default value to return if `value` is `Empty`. + + Returns: + The value or default value. + """ + return default if value is Empty else value diff --git a/venv/lib/python3.11/site-packages/litestar/utils/helpers.py b/venv/lib/python3.11/site-packages/litestar/utils/helpers.py new file mode 100644 index 0000000..c25fe35 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/helpers.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from enum import Enum +from functools import partial +from typing import TYPE_CHECKING, TypeVar, cast +from urllib.parse import quote + +from litestar.utils.typing import get_origin_or_inner_type + +if TYPE_CHECKING: + from collections.abc import Container + + from litestar.types import MaybePartial + +__all__ = ( + "get_enum_string_value", + "get_name", + "unwrap_partial", + "url_quote", + "unique_name_for_scope", +) + +T = TypeVar("T") + + +def get_name(value: object) -> str: + """Get the ``__name__`` of an object. + + Args: + value: An arbitrary object. + + Returns: + A name string. + """ + + name = getattr(value, "__name__", None) + if name is not None: + return cast("str", name) + + # On Python 3.8 and 3.9, Foo[int] does not have the __name__ attribute. + if origin := get_origin_or_inner_type(value): + return cast("str", origin.__name__) + + return type(value).__name__ + + +def get_enum_string_value(value: Enum | str) -> str: + """Return the string value of a string enum. + + See: https://github.com/litestar-org/litestar/pull/633#issuecomment-1286519267 + + Args: + value: An enum or string. + + Returns: + A string. + """ + return value.value if isinstance(value, Enum) else value # type: ignore[no-any-return] + + +def unwrap_partial(value: MaybePartial[T]) -> T: + """Unwraps a partial, returning the underlying callable. + + Args: + value: A partial function. + + Returns: + Callable + """ + from litestar.utils.sync import AsyncCallable + + return cast("T", value.func if isinstance(value, (partial, AsyncCallable)) else value) + + +def url_quote(value: str | bytes) -> str: + """Quote a URL. + + Args: + value: A URL. + + Returns: + A quoted URL. + """ + return quote(value, safe="/#%[]=:;$&()+,!?*@'~") + + +def unique_name_for_scope(base_name: str, scope: Container[str]) -> str: + """Create a name derived from ``base_name`` that's unique within ``scope``""" + i = 0 + while True: + if (unique_name := f"{base_name}_{i}") not in scope: + return unique_name + i += 1 + + +def get_exception_group() -> type[BaseException]: + """Get the exception group class with version compatibility.""" + try: + return cast("type[BaseException]", ExceptionGroup) # type:ignore[name-defined] + except NameError: + from exceptiongroup import ExceptionGroup as _ExceptionGroup # pyright: ignore + + return cast("type[BaseException]", _ExceptionGroup) diff --git a/venv/lib/python3.11/site-packages/litestar/utils/module_loader.py b/venv/lib/python3.11/site-packages/litestar/utils/module_loader.py new file mode 100644 index 0000000..09dbf9f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/module_loader.py @@ -0,0 +1,92 @@ +"""General utility functions.""" + +from __future__ import annotations + +import os.path +import sys +from importlib import import_module +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import ModuleType + +__all__ = ( + "import_string", + "module_to_os_path", +) + + +def module_to_os_path(dotted_path: str = "app") -> Path: + """Find Module to OS Path. + + Return a path to the base directory of the project or the module + specified by `dotted_path`. + + Args: + dotted_path: The path to the module. Defaults to "app". + + Raises: + TypeError: The module could not be found. + + Returns: + Path: The path to the module. + """ + try: + if (src := find_spec(dotted_path)) is None: # pragma: no cover + raise TypeError(f"Couldn't find the path for {dotted_path}") + except ModuleNotFoundError as e: + raise TypeError(f"Couldn't find the path for {dotted_path}") from e + + return Path(str(src.origin).rsplit(os.path.sep + "__init__.py", maxsplit=1)[0]) + + +def import_string(dotted_path: str) -> Any: + """Dotted Path Import. + + Import a dotted module path and return the attribute/class designated by the + last name in the path. Raise ImportError if the import failed. + + Args: + dotted_path: The path of the module to import. + + Raises: + ImportError: Could not import the module. + + Returns: + object: The imported object. + """ + + def _is_loaded(module: ModuleType | None) -> bool: + spec = getattr(module, "__spec__", None) + initializing = getattr(spec, "_initializing", False) + return bool(module and spec and not initializing) + + def _cached_import(module_path: str, class_name: str) -> Any: + """Import and cache a class from a module. + + Args: + module_path: dotted path to module. + class_name: Class or function name. + + Returns: + object: The imported class or function + """ + # Check whether module is loaded and fully initialized. + module = sys.modules.get(module_path) + if not _is_loaded(module): + module = import_module(module_path) + return getattr(module, class_name) + + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as e: + msg = "%s doesn't look like a module path" + raise ImportError(msg, dotted_path) from e + + try: + return _cached_import(module_path, class_name) + except AttributeError as e: + msg = "Module '%s' does not define a '%s' attribute/class" + raise ImportError(msg, module_path, class_name) from e diff --git a/venv/lib/python3.11/site-packages/litestar/utils/path.py b/venv/lib/python3.11/site-packages/litestar/utils/path.py new file mode 100644 index 0000000..76b43af --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/path.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import re +from typing import Iterable + +__all__ = ("join_paths", "normalize_path") + + +multi_slash_pattern = re.compile("//+") + + +def normalize_path(path: str) -> str: + """Normalize a given path by ensuring it starts with a slash and does not end with a slash. + + Args: + path: Path string + + Returns: + Path string + """ + path = path.strip("/") + path = f"/{path}" + return multi_slash_pattern.sub("/", path) + + +def join_paths(paths: Iterable[str]) -> str: + """Normalize and joins path fragments. + + Args: + paths: An iterable of path fragments. + + Returns: + A normalized joined path string. + """ + return normalize_path("/".join(paths)) diff --git a/venv/lib/python3.11/site-packages/litestar/utils/predicates.py b/venv/lib/python3.11/site-packages/litestar/utils/predicates.py new file mode 100644 index 0000000..11d5f79 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/predicates.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +from asyncio import iscoroutinefunction +from collections import defaultdict, deque +from collections.abc import Iterable as CollectionsIterable +from dataclasses import is_dataclass +from inspect import isasyncgenfunction, isclass, isgeneratorfunction +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ClassVar, + DefaultDict, + Deque, + Dict, + FrozenSet, + Generic, + Iterable, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + TypeVar, +) + +from typing_extensions import ( + ParamSpec, + TypeGuard, + _AnnotatedAlias, + get_args, +) + +from litestar.constants import UNDEFINED_SENTINELS +from litestar.types import Empty +from litestar.types.builtin_types import NoneType, UnionTypes +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.helpers import unwrap_partial +from litestar.utils.typing import get_origin_or_inner_type + +if TYPE_CHECKING: + from litestar.types.callable_types import AnyGenerator + from litestar.types.protocols import DataclassProtocol + +try: + import attrs +except ImportError: + attrs = Empty # type: ignore[assignment] + +__all__ = ( + "is_annotated_type", + "is_any", + "is_async_callable", + "is_attrs_class", + "is_class_and_subclass", + "is_class_var", + "is_dataclass_class", + "is_dataclass_instance", + "is_generic", + "is_mapping", + "is_non_string_iterable", + "is_non_string_sequence", + "is_optional_union", + "is_undefined_sentinel", + "is_union", +) + +P = ParamSpec("P") +T = TypeVar("T") + + +def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[T]]]: + """Extend :func:`asyncio.iscoroutinefunction` to additionally detect async :func:`functools.partial` objects and + class instances with ``async def __call__()`` defined. + + Args: + value: Any + + Returns: + Bool determining if type of ``value`` is an awaitable. + """ + value = unwrap_partial(value) + + return iscoroutinefunction(value) or ( + callable(value) and iscoroutinefunction(value.__call__) # type: ignore[operator] + ) + + +def is_dataclass_instance(obj: Any) -> TypeGuard[DataclassProtocol]: + """Check if an object is a dataclass instance. + + Args: + obj: An object to check. + + Returns: + True if the object is a dataclass instance. + """ + return hasattr(type(obj), "__dataclass_fields__") + + +def is_dataclass_class(annotation: Any) -> TypeGuard[type[DataclassProtocol]]: + """Wrap :func:`is_dataclass <dataclasses.is_dataclass>` in a :data:`typing.TypeGuard`. + + Args: + annotation: tested to determine if instance or type of :class:`dataclasses.dataclass`. + + Returns: + ``True`` if instance or type of ``dataclass``. + """ + try: + origin = get_origin_or_inner_type(annotation) + annotation = origin or annotation + + return isclass(annotation) and is_dataclass(annotation) + except TypeError: # pragma: no cover + return False + + +def is_class_and_subclass(annotation: Any, type_or_type_tuple: type[T] | tuple[type[T], ...]) -> TypeGuard[type[T]]: + """Return ``True`` if ``value`` is a ``class`` and is a subtype of ``t_type``. + + See https://github.com/litestar-org/litestar/issues/367 + + Args: + annotation: The value to check if is class and subclass of ``t_type``. + type_or_type_tuple: Type used for :func:`issubclass` check of ``value`` + + Returns: + bool + """ + origin = get_origin_or_inner_type(annotation) + if not origin and not isclass(annotation): + return False + try: + return issubclass(origin or annotation, type_or_type_tuple) + except TypeError: # pragma: no cover + return False + + +def is_generic(annotation: Any) -> bool: + """Given a type annotation determine if the annotation is a generic class. + + Args: + annotation: A type. + + Returns: + True if the annotation is a subclass of :data:`Generic <typing.Generic>` otherwise ``False``. + """ + return is_class_and_subclass(annotation, Generic) # type: ignore[arg-type] + + +def is_mapping(annotation: Any) -> TypeGuard[Mapping[Any, Any]]: + """Given a type annotation determine if the annotation is a mapping type. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type can be cast as :class:`Mapping <typing.Mapping>`. + """ + _type = get_origin_or_inner_type(annotation) or annotation + return isclass(_type) and issubclass(_type, (dict, defaultdict, DefaultDict, Mapping)) + + +def is_non_string_iterable(annotation: Any) -> TypeGuard[Iterable[Any]]: + """Given a type annotation determine if the annotation is an iterable. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type can be cast as :class:`Iterable <typing.Iterable>` that is not a string. + """ + origin = get_origin_or_inner_type(annotation) + if not origin and not isclass(annotation): + return False + try: + return not issubclass(origin or annotation, (str, bytes)) and ( + issubclass(origin or annotation, (Iterable, CollectionsIterable, Dict, dict, Mapping)) + or is_non_string_sequence(annotation) + ) + except TypeError: # pragma: no cover + return False + + +def is_non_string_sequence(annotation: Any) -> TypeGuard[Sequence[Any]]: + """Given a type annotation determine if the annotation is a sequence. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type can be cast as :class`Sequence <typing.Sequence>` that is not a string. + """ + origin = get_origin_or_inner_type(annotation) + if not origin and not isclass(annotation): + return False + try: + return not issubclass(origin or annotation, (str, bytes)) and issubclass( + origin or annotation, + ( # type: ignore[arg-type] + Tuple, + List, + Set, + FrozenSet, + Deque, + Sequence, + list, + tuple, + deque, + set, + frozenset, + ), + ) + except TypeError: # pragma: no cover + return False + + +def is_any(annotation: Any) -> TypeGuard[Any]: + """Given a type annotation determine if the annotation is Any. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type is :data:`Any <typing.Any>`. + """ + return ( + annotation is Any + or getattr(annotation, "_name", "") == "typing.Any" + or (get_origin_or_inner_type(annotation) in UnionTypes and Any in get_args(annotation)) + ) + + +def is_union(annotation: Any) -> bool: + """Given a type annotation determine if the annotation infers an optional union. + + Args: + annotation: A type. + + Returns: + A boolean determining whether the type is :data:`Union typing.Union>`. + """ + return get_origin_or_inner_type(annotation) in UnionTypes + + +def is_optional_union(annotation: Any) -> TypeGuard[Any | None]: + """Given a type annotation determine if the annotation infers an optional union. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type is :data:`Union typing.Union>` with a + None value or :data:`Optional <typing.Optional>` which is equivalent. + """ + origin = get_origin_or_inner_type(annotation) + return origin is Optional or ( + get_origin_or_inner_type(annotation) in UnionTypes and NoneType in get_args(annotation) + ) + + +def is_attrs_class(annotation: Any) -> TypeGuard[type[attrs.AttrsInstance]]: # pyright: ignore + """Given a type annotation determine if the annotation is a class that includes an attrs attribute. + + Args: + annotation: A type. + + Returns: + A typeguard determining whether the type is an attrs class. + """ + return attrs.has(annotation) if attrs is not Empty else False # type: ignore[comparison-overlap] + + +def is_class_var(annotation: Any) -> bool: + """Check if the given annotation is a ClassVar. + + Args: + annotation: A type annotation + + Returns: + A boolean. + """ + annotation = get_origin_or_inner_type(annotation) or annotation + return annotation is ClassVar + + +def _is_sync_or_async_generator(obj: Any) -> TypeGuard[AnyGenerator]: + """Check if the given annotation is a sync or async generator. + + Args: + obj: type to be tested for sync or async generator. + + Returns: + A boolean. + """ + return isgeneratorfunction(obj) or isasyncgenfunction(obj) + + +def is_annotated_type(annotation: Any) -> bool: + """Check if the given annotation is an Annotated. + + Args: + annotation: A type annotation + + Returns: + A boolean. + """ + return isinstance(annotation, _AnnotatedAlias) and getattr(annotation, "__args__", None) is not None + + +def is_undefined_sentinel(value: Any) -> bool: + """Check if the given value is the undefined sentinel. + + Args: + value: A value to be tested for undefined sentinel. + + Returns: + A boolean. + """ + return any(v is value for v in UNDEFINED_SENTINELS) + + +_deprecated_names = {"is_sync_or_async_generator": _is_sync_or_async_generator} + + +def __getattr__(name: str) -> Any: + if name in _deprecated_names: + warn_deprecation( + deprecated_name=f"litestar.utils.scope.{name}", + version="2.4", + kind="import", + removal_in="3.0", + info=f"'litestar.utils.predicates.{name}' is deprecated.", + ) + return globals()["_deprecated_names"][name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # pragma: no cover diff --git a/venv/lib/python3.11/site-packages/litestar/utils/scope/__init__.py b/venv/lib/python3.11/site-packages/litestar/utils/scope/__init__.py new file mode 100644 index 0000000..e5757d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/scope/__init__.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.serialization import get_serializer +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.scope.state import delete_litestar_scope_state as _delete_litestar_scope_state +from litestar.utils.scope.state import get_litestar_scope_state as _get_litestar_scope_state +from litestar.utils.scope.state import set_litestar_scope_state as _set_litestar_scope_state + +if TYPE_CHECKING: + from litestar.types import Scope, Serializer + +__all__ = ("get_serializer_from_scope",) + + +def get_serializer_from_scope(scope: Scope) -> Serializer: + """Return a serializer given a scope object. + + Args: + scope: The ASGI connection scope. + + Returns: + A serializer function + """ + route_handler = scope["route_handler"] + app = scope["app"] + + if hasattr(route_handler, "resolve_type_encoders"): + type_encoders = route_handler.resolve_type_encoders() + else: + type_encoders = app.type_encoders or {} + + if response_class := ( + route_handler.resolve_response_class() # pyright: ignore + if hasattr(route_handler, "resolve_response_class") + else app.response_class + ): + type_encoders = {**type_encoders, **(response_class.type_encoders or {})} + + return get_serializer(type_encoders) + + +_deprecated_names = { + "get_litestar_scope_state": _get_litestar_scope_state, + "set_litestar_scope_state": _set_litestar_scope_state, + "delete_litestar_scope_state": _delete_litestar_scope_state, +} + + +def __getattr__(name: str) -> Any: + if name in _deprecated_names: + warn_deprecation( + deprecated_name=f"litestar.utils.scope.{name}", + version="2.4", + kind="import", + removal_in="3.0", + info=f"'litestar.utils.scope.{name}' is deprecated. The Litestar scope state is private and should not be " + f"used. Plugin authors should maintain their own scope state namespace.", + ) + return globals()["_deprecated_names"][name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # pragma: no cover diff --git a/venv/lib/python3.11/site-packages/litestar/utils/scope/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/scope/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ee218d5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/scope/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/scope/__pycache__/state.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/utils/scope/__pycache__/state.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4615004 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/scope/__pycache__/state.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/utils/scope/state.py b/venv/lib/python3.11/site-packages/litestar/utils/scope/state.py new file mode 100644 index 0000000..bed4394 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/scope/state.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Final + +from litestar.types import Empty, EmptyType +from litestar.utils.empty import value_or_default + +if TYPE_CHECKING: + from typing_extensions import Self + + from litestar.datastructures import URL, Accept, Headers + from litestar.types.asgi_types import Scope + +CONNECTION_STATE_KEY: Final = "_ls_connection_state" + + +@dataclass +class ScopeState: + """An object for storing connection state. + + This is an internal API, and subject to change without notice. + + All types are a union with `EmptyType` and are seeded with the `Empty` value. + """ + + __slots__ = ( + "accept", + "base_url", + "body", + "content_type", + "cookies", + "csrf_token", + "dependency_cache", + "do_cache", + "form", + "headers", + "is_cached", + "json", + "log_context", + "msgpack", + "parsed_query", + "response_compressed", + "session_id", + "url", + "_compat_ns", + ) + + def __init__(self) -> None: + self.accept = Empty + self.base_url = Empty + self.body = Empty + self.content_type = Empty + self.cookies = Empty + self.csrf_token = Empty + self.dependency_cache = Empty + self.do_cache = Empty + self.form = Empty + self.headers = Empty + self.is_cached = Empty + self.json = Empty + self.log_context: dict[str, Any] = {} + self.msgpack = Empty + self.parsed_query = Empty + self.response_compressed = Empty + self.session_id = Empty + self.url = Empty + self._compat_ns: dict[str, Any] = {} + + accept: Accept | EmptyType + base_url: URL | EmptyType + body: bytes | EmptyType + content_type: tuple[str, dict[str, str]] | EmptyType + cookies: dict[str, str] | EmptyType + csrf_token: str | EmptyType + dependency_cache: dict[str, Any] | EmptyType + do_cache: bool | EmptyType + form: dict[str, str | list[str]] | EmptyType + headers: Headers | EmptyType + is_cached: bool | EmptyType + json: Any | EmptyType + log_context: dict[str, Any] + msgpack: Any | EmptyType + parsed_query: tuple[tuple[str, str], ...] | EmptyType + response_compressed: bool | EmptyType + session_id: str | None | EmptyType + url: URL | EmptyType + _compat_ns: dict[str, Any] + + @classmethod + def from_scope(cls, scope: Scope) -> Self: + """Create a new `ConnectionState` object from a scope. + + Object is cached in the scope's state under the `SCOPE_STATE_NAMESPACE` key. + + Args: + scope: The ASGI connection scope. + + Returns: + A `ConnectionState` object. + """ + base_scope_state = scope.setdefault("state", {}) + if (state := base_scope_state.get(CONNECTION_STATE_KEY)) is None: + state = base_scope_state[CONNECTION_STATE_KEY] = cls() + return state + + +def get_litestar_scope_state(scope: Scope, key: str, default: Any = None, pop: bool = False) -> Any: + """Get an internal value from connection scope state. + + Args: + scope: The connection scope. + key: Key to get from internal namespace in scope state. + default: Default value to return. + pop: Boolean flag dictating whether the value should be deleted from the state. + + Returns: + Value mapped to ``key`` in internal connection scope namespace. + """ + scope_state = ScopeState.from_scope(scope) + try: + val = value_or_default(getattr(scope_state, key), default) + if pop: + setattr(scope_state, key, Empty) + return val + except AttributeError: + if pop: + return scope_state._compat_ns.pop(key, default) + return scope_state._compat_ns.get(key, default) + + +def set_litestar_scope_state(scope: Scope, key: str, value: Any) -> None: + """Set an internal value in connection scope state. + + Args: + scope: The connection scope. + key: Key to set under internal namespace in scope state. + value: Value for key. + """ + scope_state = ScopeState.from_scope(scope) + if hasattr(scope_state, key): + setattr(scope_state, key, value) + else: + scope_state._compat_ns[key] = value + + +def delete_litestar_scope_state(scope: Scope, key: str) -> None: + """Delete an internal value from connection scope state. + + Args: + scope: The connection scope. + key: Key to set under internal namespace in scope state. + """ + scope_state = ScopeState.from_scope(scope) + if hasattr(scope_state, key): + setattr(scope_state, key, Empty) + else: + del scope_state._compat_ns[key] diff --git a/venv/lib/python3.11/site-packages/litestar/utils/sequence.py b/venv/lib/python3.11/site-packages/litestar/utils/sequence.py new file mode 100644 index 0000000..01ef1a8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/sequence.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Callable, Sequence, TypeVar + +__all__ = ("find_index", "unique") + + +T = TypeVar("T") + + +def find_index(target_list: Sequence[T], predicate: Callable[[T], bool]) -> int: + """Find element in list given a key and value. + + List elements can be dicts or classes + """ + return next((i for i, element in enumerate(target_list) if predicate(element)), -1) + + +def unique(value: Sequence[T]) -> list[T]: + """Return all unique values in a given sequence or iterator.""" + try: + return list(set(value)) + except TypeError: + output: list[T] = [] + for element in value: + if element not in output: + output.append(element) + return output diff --git a/venv/lib/python3.11/site-packages/litestar/utils/signature.py b/venv/lib/python3.11/site-packages/litestar/utils/signature.py new file mode 100644 index 0000000..eb58599 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/signature.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import sys +import typing +from copy import deepcopy +from dataclasses import dataclass, replace +from inspect import Signature, getmembers, isclass, ismethod +from itertools import chain +from typing import TYPE_CHECKING, Any, Union + +from typing_extensions import Annotated, Self, get_args, get_origin, get_type_hints + +from litestar import connection, datastructures, types +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types import Empty +from litestar.typing import FieldDefinition +from litestar.utils.typing import unwrap_annotation + +if TYPE_CHECKING: + from typing import Sequence + + from litestar.types import AnyCallable + +if sys.version_info < (3, 11): + from typing import _get_defaults # type: ignore[attr-defined] +else: + + def _get_defaults(_: Any) -> Any: ... + + +__all__ = ( + "add_types_to_signature_namespace", + "get_fn_type_hints", + "ParsedSignature", +) + +_GLOBAL_NAMES = { + namespace: export + for namespace, export in chain( + tuple(getmembers(types)), tuple(getmembers(connection)), tuple(getmembers(datastructures)) + ) + if namespace[0].isupper() and namespace in chain(types.__all__, connection.__all__, datastructures.__all__) # pyright: ignore +} +"""A mapping of names used for handler signature forward-ref resolution. + +This allows users to include these names within an `if TYPE_CHECKING:` block in their handler module. +""" + + +def _unwrap_implicit_optional_hints(defaults: dict[str, Any], hints: dict[str, Any]) -> dict[str, Any]: + """Unwrap implicit optional hints. + + On Python<3.11, if a function parameter annotation has a ``None`` default, it is unconditionally wrapped in an + ``Optional`` type. + + If the annotation is not annotated, then any nested unions are flattened, e.g.,: + + .. code-block:: python + + def foo(a: Optional[Union[str, int]] = None): ... + + ...will become `Union[str, int, NoneType]`. + + However, if the annotation is annotated, then we end up with an optional union around the annotated type, e.g.,: + + .. code-block:: python + + def foo(a: Annotated[Optional[Union[str, int]], ...] = None): ... + + ... becomes `Union[Annotated[Union[str, int, NoneType], ...], NoneType]` + + This function makes the latter case consistent with the former by either removing the outer union if it is redundant + or flattening the union if it is not. The latter case would become `Annotated[Union[str, int, NoneType], ...]`. + + Args: + defaults: Mapping of names to default values. + hints: Mapping of names to types. + + Returns: + Mapping of names to types. + """ + + def _is_two_arg_optional(origin_: Any, args_: Any) -> bool: + """Check if a type is a two-argument optional type. + + If the type has been wrapped in `Optional` by `get_type_hints()` it will always be a union of a type and + `NoneType`. + + See: https://github.com/litestar-org/litestar/pull/2516 + """ + return origin_ is Union and len(args_) == 2 and args_[1] is type(None) + + def _is_any_optional(origin_: Any, args_: tuple[Any, ...]) -> bool: + """Detect if a type is a union with `NoneType`. + + After detecting that a type is a two-argument optional type, this function can be used to detect if the + inner type is a union with `NoneType` at all. + + We only want to perform the unwrapping of the optional union if the inner type is optional as well. + """ + return origin_ is Union and any(arg is type(None) for arg in args_) + + for name, default in defaults.items(): + if default is not None: + continue + + hint = hints[name] + origin = get_origin(hint) + args = get_args(hint) + + if _is_two_arg_optional(origin, args): + unwrapped_inner, meta, wrappers = unwrap_annotation(args[0]) + + if Annotated not in wrappers: + continue + + inner_args = get_args(unwrapped_inner) + + if not _is_any_optional(get_origin(unwrapped_inner), inner_args): + # this is where hint is like `Union[Annotated[Union[str, int], ...], NoneType]`, we add the outer union + # into the inner one, and re-wrap with Annotated + union_args = (*(inner_args or (unwrapped_inner,)), type(None)) + # calling `__class_getitem__` directly as in earlier py vers it is a syntax error to unpack into + # the getitem brackets, e.g., Annotated[T, *meta]. + hints[name] = Annotated.__class_getitem__((Union[union_args], *meta)) # type: ignore[attr-defined] + continue + + # this is where hint is like `Union[Annotated[Union[str, NoneType], ...], NoneType]`, we remove the + # redundant outer union + hints[name] = args[0] + return hints + + +def get_fn_type_hints(fn: Any, namespace: dict[str, Any] | None = None) -> dict[str, Any]: + """Resolve type hints for ``fn``. + + Args: + fn: Callable that is being inspected + namespace: Extra names for resolution of forward references. + + Returns: + Mapping of names to types. + """ + fn_to_inspect: Any = fn + + module_name = fn_to_inspect.__module__ + + if isclass(fn_to_inspect): + fn_to_inspect = fn_to_inspect.__init__ + + # detect objects that are not functions and that have a `__call__` method + if callable(fn_to_inspect) and ismethod(fn_to_inspect.__call__): + fn_to_inspect = fn_to_inspect.__call__ + + # inspect the underlying function for methods + if hasattr(fn_to_inspect, "__func__"): + fn_to_inspect = fn_to_inspect.__func__ + + # Order important. If a litestar name has been overridden in the function module, we want + # to use that instead of the litestar one. + namespace = { + **_GLOBAL_NAMES, + **vars(typing), + **vars(sys.modules[module_name]), + **(namespace or {}), + } + hints = get_type_hints(fn_to_inspect, globalns=namespace, include_extras=True) + + if sys.version_info < (3, 11): + # see https://github.com/litestar-org/litestar/pull/2516 + defaults = _get_defaults(fn_to_inspect) + hints = _unwrap_implicit_optional_hints(defaults, hints) + + return hints + + +@dataclass(frozen=True) +class ParsedSignature: + """Parsed signature. + + This object is the primary source of handler/dependency signature information. + + The only post-processing that occurs is the conversion of any forward referenced type annotations. + """ + + __slots__ = ("parameters", "return_type", "original_signature") + + parameters: dict[str, FieldDefinition] + """A mapping of parameter names to ParsedSignatureParameter instances.""" + return_type: FieldDefinition + """The return annotation of the callable.""" + original_signature: Signature + """The raw signature as returned by :func:`inspect.signature`""" + + def __deepcopy__(self, memo: dict[str, Any]) -> Self: + return type(self)( + parameters={k: deepcopy(v) for k, v in self.parameters.items()}, + return_type=deepcopy(self.return_type), + original_signature=deepcopy(self.original_signature), + ) + + @classmethod + def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self: + """Parse a function signature. + + Args: + fn: Any callable. + signature_namespace: mapping of names to types for forward reference resolution + + Returns: + ParsedSignature + """ + signature = Signature.from_callable(fn) + fn_type_hints = get_fn_type_hints(fn, namespace=signature_namespace) + + return cls.from_signature(signature, fn_type_hints) + + @classmethod + def from_signature(cls, signature: Signature, fn_type_hints: dict[str, type]) -> Self: + """Parse an :class:`inspect.Signature` instance. + + Args: + signature: An :class:`inspect.Signature` instance. + fn_type_hints: mapping of types + + Returns: + ParsedSignature + """ + + parameters = tuple( + FieldDefinition.from_parameter(parameter=parameter, fn_type_hints=fn_type_hints) + for name, parameter in signature.parameters.items() + if name not in ("self", "cls") + ) + + return_type = FieldDefinition.from_annotation(fn_type_hints.get("return", Any)) + + return cls( + parameters={p.name: p for p in parameters}, + return_type=return_type if "return" in fn_type_hints else replace(return_type, annotation=Empty), + original_signature=signature, + ) + + +def add_types_to_signature_namespace( + signature_types: Sequence[Any], signature_namespace: dict[str, Any] +) -> dict[str, Any]: + """Add types to ith signature namespace mapping. + + Types are added mapped to their `__name__` attribute. + + Args: + signature_types: A list of types to add to the signature namespace. + signature_namespace: The signature namespace to add types to. + + Raises: + ImproperlyConfiguredException: If a type is already defined in the signature namespace. + AttributeError: If a type does not have a `__name__` attribute. + + Returns: + The updated signature namespace. + """ + for typ in signature_types: + if (name := typ.__name__) in signature_namespace: + raise ImproperlyConfiguredException(f"Type '{name}' is already defined in the signature namespace") + signature_namespace[name] = typ + return signature_namespace diff --git a/venv/lib/python3.11/site-packages/litestar/utils/sync.py b/venv/lib/python3.11/site-packages/litestar/utils/sync.py new file mode 100644 index 0000000..02acabf --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/sync.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import ( + AsyncGenerator, + Awaitable, + Callable, + Generic, + Iterable, + Iterator, + TypeVar, +) + +from typing_extensions import ParamSpec + +from litestar.concurrency import sync_to_thread +from litestar.utils.predicates import is_async_callable + +__all__ = ("ensure_async_callable", "AsyncIteratorWrapper", "AsyncCallable", "is_async_callable") + + +P = ParamSpec("P") +T = TypeVar("T") + + +def ensure_async_callable(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: + """Ensure that ``fn`` is an asynchronous callable. + If it is an asynchronous, return the original object, else wrap it in an + ``AsyncCallable`` + """ + if is_async_callable(fn): + return fn + return AsyncCallable(fn) # pyright: ignore + + +class AsyncCallable: + """Wrap a given callable to be called in a thread pool using + ``anyio.to_thread.run_sync``, keeping a reference to the original callable as + :attr:`func` + """ + + def __init__(self, fn: Callable[P, T]) -> None: # pyright: ignore + self.func = fn + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: # pyright: ignore + return sync_to_thread(self.func, *args, **kwargs) # pyright: ignore + + +class AsyncIteratorWrapper(Generic[T]): + """Asynchronous generator, wrapping an iterable or iterator.""" + + __slots__ = ("iterator", "generator") + + def __init__(self, iterator: Iterator[T] | Iterable[T]) -> None: + """Take a sync iterator or iterable and yields values from it asynchronously. + + Args: + iterator: A sync iterator or iterable. + """ + self.iterator = iterator if isinstance(iterator, Iterator) else iter(iterator) + self.generator = self._async_generator() + + def _call_next(self) -> T: + try: + return next(self.iterator) + except StopIteration as e: + raise ValueError from e + + async def _async_generator(self) -> AsyncGenerator[T, None]: + while True: + try: + yield await sync_to_thread(self._call_next) + except ValueError: + return + + def __aiter__(self) -> AsyncIteratorWrapper[T]: + return self + + async def __anext__(self) -> T: + return await self.generator.__anext__() diff --git a/venv/lib/python3.11/site-packages/litestar/utils/typing.py b/venv/lib/python3.11/site-packages/litestar/utils/typing.py new file mode 100644 index 0000000..9da6c2a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/typing.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import re +from collections import abc, defaultdict, deque +from typing import ( + AbstractSet, + Any, + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Awaitable, + Collection, + Container, + Coroutine, + DefaultDict, + Deque, + Dict, + FrozenSet, + Generator, + ItemsView, + Iterable, + Iterator, + KeysView, + List, + Mapping, + MappingView, + MutableMapping, + MutableSequence, + MutableSet, + Reversible, + Sequence, + Set, + Tuple, + TypeVar, + Union, + ValuesView, + cast, +) + +from typing_extensions import Annotated, NotRequired, Required, get_args, get_origin, get_type_hints + +from litestar.types.builtin_types import NoneType, UnionTypes + +__all__ = ( + "get_instantiable_origin", + "get_origin_or_inner_type", + "get_safe_generic_origin", + "instantiable_type_mapping", + "make_non_optional_union", + "safe_generic_origin_map", + "unwrap_annotation", +) + + +T = TypeVar("T") +UnionT = TypeVar("UnionT", bound="Union") + +tuple_types_regex = re.compile( + "^" + + "|".join( + [*[repr(x) for x in (List, Sequence, Iterable, Iterator, Tuple, Deque)], "tuple", "list", "collections.deque"] + ) +) + +instantiable_type_mapping = { + AbstractSet: set, + DefaultDict: defaultdict, + Deque: deque, + Dict: dict, + FrozenSet: frozenset, + List: list, + Mapping: dict, + MutableMapping: dict, + MutableSequence: list, + MutableSet: set, + Sequence: list, + Set: set, + Tuple: tuple, + abc.Mapping: dict, + abc.MutableMapping: dict, + abc.MutableSequence: list, + abc.MutableSet: set, + abc.Sequence: list, + abc.Set: set, + defaultdict: defaultdict, + deque: deque, + dict: dict, + frozenset: frozenset, + list: list, + set: set, + tuple: tuple, +} + +safe_generic_origin_map = { + set: AbstractSet, + defaultdict: DefaultDict, + deque: Deque, + dict: Dict, + frozenset: FrozenSet, + list: List, + tuple: Tuple, + abc.Mapping: Mapping, + abc.MutableMapping: MutableMapping, + abc.MutableSequence: MutableSequence, + abc.MutableSet: MutableSet, + abc.Sequence: Sequence, + abc.Set: AbstractSet, + abc.Collection: Collection, + abc.Container: Container, + abc.ItemsView: ItemsView, + abc.KeysView: KeysView, + abc.MappingView: MappingView, + abc.ValuesView: ValuesView, + abc.Iterable: Iterable, + abc.Iterator: Iterator, + abc.Generator: Generator, + abc.Reversible: Reversible, + abc.Coroutine: Coroutine, + abc.AsyncGenerator: AsyncGenerator, + abc.AsyncIterable: AsyncIterable, + abc.AsyncIterator: AsyncIterator, + abc.Awaitable: Awaitable, + **{union_t: Union for union_t in UnionTypes}, +} +"""A mapping of types to equivalent types that are safe to be used as generics across all Python versions. + +This is necessary because occasionally we want to rebuild a generic outer type with different args, and types such as +``collections.abc.Mapping``, are not valid generic types in Python 3.8. +""" + +wrapper_type_set = {Annotated, Required, NotRequired} +"""Types that always contain a wrapped type annotation as their first arg.""" + + +def normalize_type_annotation(annotation: Any) -> Any: + """Normalize a type annotation to a standard form.""" + return instantiable_type_mapping.get(annotation, annotation) + + +def make_non_optional_union(annotation: UnionT | None) -> UnionT: + """Make a :data:`Union <typing.Union>` type that excludes ``NoneType``. + + Args: + annotation: A type annotation. + + Returns: + The union with all original members, except ``NoneType``. + """ + args = tuple(tp for tp in get_args(annotation) if tp is not NoneType) + return cast("UnionT", Union[args]) # pyright: ignore + + +def unwrap_annotation(annotation: Any) -> tuple[Any, tuple[Any, ...], set[Any]]: + """Remove "wrapper" annotation types, such as ``Annotated``, ``Required``, and ``NotRequired``. + + Note: + ``annotation`` should have been retrieved from :func:`get_type_hints()` with ``include_extras=True``. This + ensures that any nested ``Annotated`` types are flattened according to the PEP 593 specification. + + Args: + annotation: A type annotation. + + Returns: + A tuple of the unwrapped annotation and any ``Annotated`` metadata, and a set of any wrapper types encountered. + """ + origin = get_origin(annotation) + wrappers = set() + metadata = [] + while origin in wrapper_type_set: + wrappers.add(origin) + annotation, *meta = get_args(annotation) + metadata.extend(meta) + origin = get_origin(annotation) + return annotation, tuple(metadata), wrappers + + +def get_origin_or_inner_type(annotation: Any) -> Any: + """Get origin or unwrap it. Returns None for non-generic types. + + Args: + annotation: A type annotation. + + Returns: + Any type. + """ + origin = get_origin(annotation) + if origin in wrapper_type_set: + inner, _, _ = unwrap_annotation(annotation) + # we need to recursively call here 'get_origin_or_inner_type' because we might be dealing + # with a generic type alias e.g. Annotated[dict[str, list[int]] + origin = get_origin_or_inner_type(inner) + return instantiable_type_mapping.get(origin, origin) + + +def get_safe_generic_origin(origin_type: Any, annotation: Any) -> Any: + """Get a type that is safe to use as a generic type across all supported Python versions. + + If a builtin collection type is annotated without generic args, e.g, ``a: dict``, then the origin type will be + ``None``. In this case, we can use the annotation to determine the correct generic type, if one exists. + + Args: + origin_type: A type - would be the return value of :func:`get_origin()`. + annotation: Type annotation associated with the origin type. Should be unwrapped from any wrapper types, such + as ``Annotated``. + + Returns: + The ``typing`` module equivalent of the given type, if it exists. Otherwise, the original type is returned. + """ + if origin_type is None: + return safe_generic_origin_map.get(annotation) + return safe_generic_origin_map.get(origin_type, origin_type) + + +def get_instantiable_origin(origin_type: Any, annotation: Any) -> Any: + """Get a type that is safe to instantiate for the given origin type. + + If a builtin collection type is annotated without generic args, e.g, ``a: dict``, then the origin type will be + ``None``. In this case, we can use the annotation to determine the correct instantiable type, if one exists. + + Args: + origin_type: A type - would be the return value of :func:`get_origin()`. + annotation: Type annotation associated with the origin type. Should be unwrapped from any wrapper types, such + as ``Annotated``. + + Returns: + A builtin type that is safe to instantiate for the given origin type. + """ + if origin_type is None: + return instantiable_type_mapping.get(annotation) + return instantiable_type_mapping.get(origin_type, origin_type) + + +def get_type_hints_with_generics_resolved( + annotation: Any, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, + include_extras: bool = False, + type_hints: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Get the type hints for the given object after resolving the generic types as much as possible. + + Args: + annotation: A type annotation. + globalns: The global namespace. + localns: The local namespace. + include_extras: A flag indicating whether to include the ``Annotated[T, ...]`` or not. + type_hints: Already resolved type hints + """ + origin = get_origin(annotation) + + if origin is None: + # Implies the generic types have not been specified in the annotation + if type_hints is None: # pragma: no cover + type_hints = get_type_hints(annotation, globalns=globalns, localns=localns, include_extras=include_extras) + typevar_map = {p: p for p in annotation.__parameters__} + else: + if type_hints is None: # pragma: no cover + type_hints = get_type_hints(origin, globalns=globalns, localns=localns, include_extras=include_extras) + # the __parameters__ is only available on the origin itself and not the annotation + typevar_map = dict(zip(origin.__parameters__, get_args(annotation))) + + return {n: _substitute_typevars(type_, typevar_map) for n, type_ in type_hints.items()} + + +def _substitute_typevars(obj: Any, typevar_map: Mapping[Any, Any]) -> Any: + if params := getattr(obj, "__parameters__", None): + args = tuple(_substitute_typevars(typevar_map.get(p, p), typevar_map) for p in params) + return obj[args] + + if isinstance(obj, TypeVar): + # If there's a mapped type for the TypeVar already, then it should be returned instead + # of considering __constraints__ or __bound__. For a generic `Foo[T]`, if Foo[int] is given + # then int should be returned and if `Foo` is given then the __bounds__ and __constraints__ + # should be considered. + if (type_ := typevar_map.get(obj, None)) is not None and not isinstance(type_, TypeVar): + return type_ + + if obj.__bound__ is not None: + return obj.__bound__ + + if obj.__constraints__: + return Union[obj.__constraints__] # pyright: ignore + + return obj diff --git a/venv/lib/python3.11/site-packages/litestar/utils/version.py b/venv/lib/python3.11/site-packages/litestar/utils/version.py new file mode 100644 index 0000000..d7974eb --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/version.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import re +import sys +from typing import Literal, NamedTuple + +__all__ = ("Version", "get_version", "parse_version") + + +if sys.version_info >= (3, 10): + import importlib.metadata as importlib_metadata +else: + import importlib_metadata + + +_ReleaseLevel = Literal["alpha", "beta", "rc", "final"] +_PRE_RELEASE_TAGS = {"alpha", "a", "beta", "b", "rc"} +_PRE_RELEASE_TAGS_CONVERSIONS: dict[str, _ReleaseLevel] = {"a": "alpha", "b": "beta"} + +_VERSION_PARTS_RE = re.compile(r"(\d+|[a-z]+|\.)") + + +class Version(NamedTuple): + """Litestar version information""" + + major: int + minor: int + patch: int + release_level: _ReleaseLevel + serial: int + + def formatted(self, short: bool = False) -> str: + version = f"{self.major}.{self.minor}.{self.patch}" + if not short: + version += f"{self.release_level}{self.serial}" + return version + + +def parse_version(raw_version: str) -> Version: + """Parse a version string into a :class:`Version`""" + parts = [p for p in _VERSION_PARTS_RE.split(raw_version) if p and p != "."] + release_level: _ReleaseLevel = "final" + serial = "0" + + if len(parts) == 3: + major, minor, patch = parts + elif len(parts) == 5: + major, minor, patch, release_level, serial = parts # type: ignore[assignment] + if release_level not in _PRE_RELEASE_TAGS: + raise ValueError(f"Invalid release level: {release_level}") + release_level = _PRE_RELEASE_TAGS_CONVERSIONS.get(release_level, release_level) + else: + raise ValueError(f"Invalid version: {raw_version}") + + return Version( + major=int(major), minor=int(minor), patch=int(patch), release_level=release_level, serial=int(serial) + ) + + +def get_version() -> Version: + """Get the version of the installed litestar package""" + return parse_version(importlib_metadata.version("litestar")) diff --git a/venv/lib/python3.11/site-packages/litestar/utils/warnings.py b/venv/lib/python3.11/site-packages/litestar/utils/warnings.py new file mode 100644 index 0000000..e20484b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/utils/warnings.py @@ -0,0 +1,51 @@ +import os +import warnings + +from litestar.exceptions import LitestarWarning +from litestar.types import AnyCallable, AnyGenerator + + +def warn_implicit_sync_to_thread(source: AnyCallable, stacklevel: int = 2) -> None: + if os.getenv("LITESTAR_WARN_IMPLICIT_SYNC_TO_THREAD") == "0": + return + + warnings.warn( + f"Use of a synchronous callable {source} without setting sync_to_thread is " + "discouraged since synchronous callables can block the main thread if they " + "perform blocking operations. If the callable is guaranteed to be non-blocking, " + "you can set sync_to_thread=False to skip this warning, or set the environment" + "variable LITESTAR_WARN_IMPLICIT_SYNC_TO_THREAD=0 to disable warnings of this " + "type entirely.", + category=LitestarWarning, + stacklevel=stacklevel, + ) + + +def warn_sync_to_thread_with_async_callable(source: AnyCallable, stacklevel: int = 2) -> None: + if os.getenv("LITESTAR_WARN_SYNC_TO_THREAD_WITH_ASYNC") == "0": + return + + warnings.warn( + f"Use of an asynchronous callable {source} with sync_to_thread; sync_to_thread " + "has no effect on async callable. You can disable this warning by setting " + "LITESTAR_WARN_SYNC_TO_THREAD_WITH_ASYNC=0", + category=LitestarWarning, + stacklevel=stacklevel, + ) + + +def warn_sync_to_thread_with_generator(source: AnyGenerator, stacklevel: int = 2) -> None: + if os.getenv("LITESTAR_WARN_SYNC_TO_THREAD_WITH_GENERATOR") == "0": + return + + warnings.warn( + f"Use of generator {source} with sync_to_thread; sync_to_thread has no effect " + "on generators. You can disable this warning by setting " + "LITESTAR_WARN_SYNC_TO_THREAD_WITH_GENERATOR=0", + category=LitestarWarning, + stacklevel=stacklevel, + ) + + +def warn_pdb_on_exception(stacklevel: int = 2) -> None: + warnings.warn("Python Debugger on exception enabled", category=LitestarWarning, stacklevel=stacklevel) |