diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/handlers/base.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/handlers/base.py | 577 |
1 files changed, 577 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/base.py b/venv/lib/python3.11/site-packages/litestar/handlers/base.py new file mode 100644 index 0000000..9dbb70e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/base.py @@ -0,0 +1,577 @@ +from __future__ import annotations + +from copy import copy +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, cast + +from litestar._signature import SignatureModel +from litestar.config.app import ExperimentalFeatures +from litestar.di import Provide +from litestar.dto import DTOData +from litestar.exceptions import ImproperlyConfiguredException +from litestar.plugins import DIPlugin, PluginRegistry +from litestar.serialization import default_deserializer, default_serializer +from litestar.types import ( + Dependencies, + Empty, + ExceptionHandlersMap, + Guard, + Middleware, + TypeDecodersSequence, + TypeEncodersMap, +) +from litestar.typing import FieldDefinition +from litestar.utils import ensure_async_callable, get_name, normalize_path +from litestar.utils.helpers import unwrap_partial +from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace + +if TYPE_CHECKING: + from typing_extensions import Self + + from litestar.app import Litestar + from litestar.connection import ASGIConnection + from litestar.controller import Controller + from litestar.dto import AbstractDTO + from litestar.dto._backend import DTOBackend + from litestar.params import ParameterKwarg + from litestar.router import Router + from litestar.types import AnyCallable, AsyncAnyCallable, ExceptionHandler + from litestar.types.empty import EmptyType + +__all__ = ("BaseRouteHandler",) + + +class BaseRouteHandler: + """Base route handler. + + Serves as a subclass for all route handlers + """ + + __slots__ = ( + "_fn", + "_parsed_data_field", + "_parsed_fn_signature", + "_parsed_return_field", + "_resolved_data_dto", + "_resolved_dependencies", + "_resolved_guards", + "_resolved_layered_parameters", + "_resolved_return_dto", + "_resolved_signature_namespace", + "_resolved_type_decoders", + "_resolved_type_encoders", + "_signature_model", + "dependencies", + "dto", + "exception_handlers", + "guards", + "middleware", + "name", + "opt", + "owner", + "paths", + "return_dto", + "signature_namespace", + "type_decoders", + "type_encoders", + ) + + def __init__( + self, + path: str | Sequence[str] | None = None, + *, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + middleware: Sequence[Middleware] | None = None, + name: str | None = None, + opt: Mapping[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``HTTPRouteHandler``. + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature + modelling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + self._parsed_fn_signature: ParsedSignature | EmptyType = Empty + self._parsed_return_field: FieldDefinition | EmptyType = Empty + self._parsed_data_field: FieldDefinition | None | EmptyType = Empty + self._resolved_data_dto: type[AbstractDTO] | None | EmptyType = Empty + self._resolved_dependencies: dict[str, Provide] | EmptyType = Empty + self._resolved_guards: list[Guard] | EmptyType = Empty + self._resolved_layered_parameters: dict[str, FieldDefinition] | EmptyType = Empty + self._resolved_return_dto: type[AbstractDTO] | None | EmptyType = Empty + self._resolved_signature_namespace: dict[str, Any] | EmptyType = Empty + self._resolved_type_decoders: TypeDecodersSequence | EmptyType = Empty + self._resolved_type_encoders: TypeEncodersMap | EmptyType = Empty + self._signature_model: type[SignatureModel] | EmptyType = Empty + + self.dependencies = dependencies + self.dto = dto + self.exception_handlers = exception_handlers + self.guards = guards + self.middleware = middleware + self.name = name + self.opt = dict(opt or {}) + self.opt.update(**kwargs) + self.owner: Controller | Router | None = None + self.return_dto = return_dto + self.signature_namespace = add_types_to_signature_namespace( + signature_types or [], dict(signature_namespace or {}) + ) + self.type_decoders = type_decoders + self.type_encoders = type_encoders + + self.paths = ( + {normalize_path(p) for p in path} if path and isinstance(path, list) else {normalize_path(path or "/")} # type: ignore[arg-type] + ) + + def __call__(self, fn: AsyncAnyCallable) -> Self: + """Replace a function with itself.""" + self._fn = fn + return self + + @property + def handler_id(self) -> str: + """A unique identifier used for generation of DTOs.""" + return f"{self!s}::{sum(id(layer) for layer in self.ownership_layers)}" + + @property + def default_deserializer(self) -> Callable[[Any, Any], Any]: + """Get a default deserializer for the route handler. + + Returns: + A default deserializer for the route handler. + + """ + return partial(default_deserializer, type_decoders=self.resolve_type_decoders()) + + @property + def default_serializer(self) -> Callable[[Any], Any]: + """Get a default serializer for the route handler. + + Returns: + A default serializer for the route handler. + + """ + return partial(default_serializer, type_encoders=self.resolve_type_encoders()) + + @property + def signature_model(self) -> type[SignatureModel]: + """Get the signature model for the route handler. + + Returns: + A signature model for the route handler. + + """ + if self._signature_model is Empty: + self._signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=cast("AnyCallable", self.fn), + parsed_signature=self.parsed_fn_signature, + data_dto=self.resolve_data_dto(), + type_decoders=self.resolve_type_decoders(), + ) + return self._signature_model + + @property + def fn(self) -> AsyncAnyCallable: + """Get the handler function. + + Raises: + ImproperlyConfiguredException: if handler fn is not set. + + Returns: + Handler function + """ + if not hasattr(self, "_fn"): + raise ImproperlyConfiguredException("No callable has been registered for this handler") + return self._fn + + @property + def parsed_fn_signature(self) -> ParsedSignature: + """Return the parsed signature of the handler function. + + This method is memoized so the computation occurs only once. + + Returns: + A ParsedSignature instance + """ + if self._parsed_fn_signature is Empty: + self._parsed_fn_signature = ParsedSignature.from_fn( + unwrap_partial(self.fn), self.resolve_signature_namespace() + ) + + return self._parsed_fn_signature + + @property + def parsed_return_field(self) -> FieldDefinition: + if self._parsed_return_field is Empty: + self._parsed_return_field = self.parsed_fn_signature.return_type + return self._parsed_return_field + + @property + def parsed_data_field(self) -> FieldDefinition | None: + if self._parsed_data_field is Empty: + self._parsed_data_field = self.parsed_fn_signature.parameters.get("data") + return self._parsed_data_field + + @property + def handler_name(self) -> str: + """Get the name of the handler function. + + Raises: + ImproperlyConfiguredException: if handler fn is not set. + + Returns: + Name of the handler function + """ + return get_name(unwrap_partial(self.fn)) + + @property + def dependency_name_set(self) -> set[str]: + """Set of all dependency names provided in the handler's ownership layers.""" + layered_dependencies = (layer.dependencies or {} for layer in self.ownership_layers) + return {name for layer in layered_dependencies for name in layer} # pyright: ignore + + @property + def ownership_layers(self) -> list[Self | Controller | Router]: + """Return the handler layers from the app down to the route handler. + + ``app -> ... -> route handler`` + """ + layers = [] + + cur: Any = self + while cur: + layers.append(cur) + cur = cur.owner + + return list(reversed(layers)) + + @property + def app(self) -> Litestar: + return cast("Litestar", self.ownership_layers[0]) + + def resolve_type_encoders(self) -> TypeEncodersMap: + """Return a merged type_encoders mapping. + + This method is memoized so the computation occurs only once. + + Returns: + A dict of type encoders + """ + if self._resolved_type_encoders is Empty: + self._resolved_type_encoders = {} + + for layer in self.ownership_layers: + if type_encoders := getattr(layer, "type_encoders", None): + self._resolved_type_encoders.update(type_encoders) + return cast("TypeEncodersMap", self._resolved_type_encoders) + + def resolve_type_decoders(self) -> TypeDecodersSequence: + """Return a merged type_encoders mapping. + + This method is memoized so the computation occurs only once. + + Returns: + A dict of type encoders + """ + if self._resolved_type_decoders is Empty: + self._resolved_type_decoders = [] + + for layer in self.ownership_layers: + if type_decoders := getattr(layer, "type_decoders", None): + self._resolved_type_decoders.extend(list(type_decoders)) + return cast("TypeDecodersSequence", self._resolved_type_decoders) + + def resolve_layered_parameters(self) -> dict[str, FieldDefinition]: + """Return all parameters declared above the handler.""" + if self._resolved_layered_parameters is Empty: + parameter_kwargs: dict[str, ParameterKwarg] = {} + + for layer in self.ownership_layers: + parameter_kwargs.update(getattr(layer, "parameters", {}) or {}) + + self._resolved_layered_parameters = { + key: FieldDefinition.from_kwarg(name=key, annotation=parameter.annotation, kwarg_definition=parameter) + for key, parameter in parameter_kwargs.items() + } + + return self._resolved_layered_parameters + + def resolve_guards(self) -> list[Guard]: + """Return all guards in the handlers scope, starting from highest to current layer.""" + if self._resolved_guards is Empty: + self._resolved_guards = [] + + for layer in self.ownership_layers: + self._resolved_guards.extend(layer.guards or []) # pyright: ignore + + self._resolved_guards = cast( + "list[Guard]", [ensure_async_callable(guard) for guard in self._resolved_guards] + ) + + return self._resolved_guards + + def _get_plugin_registry(self) -> PluginRegistry | None: + from litestar.app import Litestar + + root_owner = self.ownership_layers[0] + if isinstance(root_owner, Litestar): + return root_owner.plugins + return None + + def resolve_dependencies(self) -> dict[str, Provide]: + """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope.""" + plugin_registry = self._get_plugin_registry() + if self._resolved_dependencies is Empty: + self._resolved_dependencies = {} + for layer in self.ownership_layers: + for key, provider in (layer.dependencies or {}).items(): + self._resolved_dependencies[key] = self._resolve_dependency( + key=key, provider=provider, plugin_registry=plugin_registry + ) + + return self._resolved_dependencies + + def _resolve_dependency( + self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None + ) -> Provide: + if not isinstance(provider, Provide): + provider = Provide(provider) + + if self._resolved_dependencies is not Empty: # pragma: no cover + self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider) + + if not getattr(provider, "parsed_fn_signature", None): + dependency = unwrap_partial(provider.dependency) + plugin: DIPlugin | None = None + if plugin_registry: + plugin = next( + (p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)), + None, + ) + if plugin: + signature, init_type_hints = plugin.get_typed_init(dependency) + provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints) + else: + provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace()) + + if not getattr(provider, "signature_model", None): + provider.signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=provider.dependency, + parsed_signature=provider.parsed_fn_signature, + data_dto=self.resolve_data_dto(), + type_decoders=self.resolve_type_decoders(), + ) + return provider + + def resolve_middleware(self) -> list[Middleware]: + """Build the middleware stack for the RouteHandler and return it. + + The middlewares are added from top to bottom (``app -> router -> controller -> route handler``) and then + reversed. + """ + resolved_middleware: list[Middleware] = [] + for layer in self.ownership_layers: + resolved_middleware.extend(layer.middleware or []) # pyright: ignore + return list(reversed(resolved_middleware)) + + def resolve_exception_handlers(self) -> ExceptionHandlersMap: + """Resolve the exception_handlers by starting from the route handler and moving up. + + This method is memoized so the computation occurs only once. + """ + resolved_exception_handlers: dict[int | type[Exception], ExceptionHandler] = {} + for layer in self.ownership_layers: + resolved_exception_handlers.update(layer.exception_handlers or {}) # pyright: ignore + return resolved_exception_handlers + + def resolve_opts(self) -> None: + """Build the route handler opt dictionary by going from top to bottom. + + When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the + layer closest to the response handler will take precedence. + """ + + opt: dict[str, Any] = {} + for layer in self.ownership_layers: + opt.update(layer.opt or {}) # pyright: ignore + + self.opt = opt + + def resolve_signature_namespace(self) -> dict[str, Any]: + """Build the route handler signature namespace dictionary by going from top to bottom. + + When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the + layer closest to the response handler will take precedence. + """ + if self._resolved_layered_parameters is Empty: + ns: dict[str, Any] = {} + for layer in self.ownership_layers: + ns.update(layer.signature_namespace) + + self._resolved_signature_namespace = ns + return cast("dict[str, Any]", self._resolved_signature_namespace) + + def _get_dto_backend_cls(self) -> type[DTOBackend] | None: + if ExperimentalFeatures.DTO_CODEGEN in self.app.experimental_features: + from litestar.dto._codegen_backend import DTOCodegenBackend + + return DTOCodegenBackend + return None + + def resolve_data_dto(self) -> type[AbstractDTO] | None: + """Resolve the data_dto by starting from the route handler and moving up. + If a handler is found it is returned, otherwise None is set. + This method is memoized so the computation occurs only once. + + Returns: + An optional :class:`DTO type <.dto.base_dto.AbstractDTO>` + """ + if self._resolved_data_dto is Empty: + if data_dtos := cast( + "list[type[AbstractDTO] | None]", + [layer.dto for layer in self.ownership_layers if layer.dto is not Empty], + ): + data_dto: type[AbstractDTO] | None = data_dtos[-1] + elif self.parsed_data_field and ( + plugins_for_data_type := [ + plugin + for plugin in self.app.plugins.serialization + if self.parsed_data_field.match_predicate_recursively(plugin.supports_type) + ] + ): + data_dto = plugins_for_data_type[0].create_dto_for_type(self.parsed_data_field) + else: + data_dto = None + + if self.parsed_data_field and data_dto: + data_dto.create_for_field_definition( + field_definition=self.parsed_data_field, + handler_id=self.handler_id, + backend_cls=self._get_dto_backend_cls(), + ) + + self._resolved_data_dto = data_dto + + return self._resolved_data_dto + + def resolve_return_dto(self) -> type[AbstractDTO] | None: + """Resolve the return_dto by starting from the route handler and moving up. + If a handler is found it is returned, otherwise None is set. + This method is memoized so the computation occurs only once. + + Returns: + An optional :class:`DTO type <.dto.base_dto.AbstractDTO>` + """ + if self._resolved_return_dto is Empty: + if return_dtos := cast( + "list[type[AbstractDTO] | None]", + [layer.return_dto for layer in self.ownership_layers if layer.return_dto is not Empty], + ): + return_dto: type[AbstractDTO] | None = return_dtos[-1] + elif plugins_for_return_type := [ + plugin + for plugin in self.app.plugins.serialization + if self.parsed_return_field.match_predicate_recursively(plugin.supports_type) + ]: + return_dto = plugins_for_return_type[0].create_dto_for_type(self.parsed_return_field) + else: + return_dto = self.resolve_data_dto() + + if return_dto and return_dto.is_supported_model_type_field(self.parsed_return_field): + return_dto.create_for_field_definition( + field_definition=self.parsed_return_field, + handler_id=self.handler_id, + backend_cls=self._get_dto_backend_cls(), + ) + self._resolved_return_dto = return_dto + else: + self._resolved_return_dto = None + + return self._resolved_return_dto + + async def authorize_connection(self, connection: ASGIConnection) -> None: + """Ensure the connection is authorized by running all the route guards in scope.""" + for guard in self.resolve_guards(): + await guard(connection, copy(self)) # type: ignore[misc] + + @staticmethod + def _validate_dependency_is_unique(dependencies: dict[str, Provide], key: str, provider: Provide) -> None: + """Validate that a given provider has not been already defined under a different key.""" + for dependency_key, value in dependencies.items(): + if provider == value: + raise ImproperlyConfiguredException( + f"Provider for key {key} is already defined under the different key {dependency_key}. " + f"If you wish to override a provider, it must have the same key." + ) + + def on_registration(self, app: Litestar) -> None: + """Called once per handler when the app object is instantiated. + + Args: + app: The :class:`Litestar<.app.Litestar>` app object. + + Returns: + None + """ + self._validate_handler_function() + self.resolve_dependencies() + self.resolve_guards() + self.resolve_middleware() + self.resolve_opts() + self.resolve_data_dto() + self.resolve_return_dto() + + def _validate_handler_function(self) -> None: + """Validate the route handler function once set by inspecting its return annotations.""" + if ( + self.parsed_data_field is not None + and self.parsed_data_field.is_subclass_of(DTOData) + and not self.resolve_data_dto() + ): + raise ImproperlyConfiguredException( + f"Handler function {self.handler_name} has a data parameter that is a subclass of DTOData but no " + "DTO has been registered for it." + ) + + def __str__(self) -> str: + """Return a unique identifier for the route handler. + + Returns: + A string + """ + target: type[AsyncAnyCallable] | AsyncAnyCallable # pyright: ignore + target = unwrap_partial(self.fn) + if not hasattr(target, "__qualname__"): + target = type(target) + return f"{target.__module__}.{target.__qualname__}" |