summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/handlers/base.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/handlers/base.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/handlers/base.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/handlers/base.py577
1 files changed, 577 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/base.py b/venv/lib/python3.11/site-packages/litestar/handlers/base.py
new file mode 100644
index 0000000..9dbb70e
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/handlers/base.py
@@ -0,0 +1,577 @@
+from __future__ import annotations
+
+from copy import copy
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, cast
+
+from litestar._signature import SignatureModel
+from litestar.config.app import ExperimentalFeatures
+from litestar.di import Provide
+from litestar.dto import DTOData
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.plugins import DIPlugin, PluginRegistry
+from litestar.serialization import default_deserializer, default_serializer
+from litestar.types import (
+ Dependencies,
+ Empty,
+ ExceptionHandlersMap,
+ Guard,
+ Middleware,
+ TypeDecodersSequence,
+ TypeEncodersMap,
+)
+from litestar.typing import FieldDefinition
+from litestar.utils import ensure_async_callable, get_name, normalize_path
+from litestar.utils.helpers import unwrap_partial
+from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace
+
+if TYPE_CHECKING:
+ from typing_extensions import Self
+
+ from litestar.app import Litestar
+ from litestar.connection import ASGIConnection
+ from litestar.controller import Controller
+ from litestar.dto import AbstractDTO
+ from litestar.dto._backend import DTOBackend
+ from litestar.params import ParameterKwarg
+ from litestar.router import Router
+ from litestar.types import AnyCallable, AsyncAnyCallable, ExceptionHandler
+ from litestar.types.empty import EmptyType
+
+__all__ = ("BaseRouteHandler",)
+
+
+class BaseRouteHandler:
+ """Base route handler.
+
+ Serves as a subclass for all route handlers
+ """
+
+ __slots__ = (
+ "_fn",
+ "_parsed_data_field",
+ "_parsed_fn_signature",
+ "_parsed_return_field",
+ "_resolved_data_dto",
+ "_resolved_dependencies",
+ "_resolved_guards",
+ "_resolved_layered_parameters",
+ "_resolved_return_dto",
+ "_resolved_signature_namespace",
+ "_resolved_type_decoders",
+ "_resolved_type_encoders",
+ "_signature_model",
+ "dependencies",
+ "dto",
+ "exception_handlers",
+ "guards",
+ "middleware",
+ "name",
+ "opt",
+ "owner",
+ "paths",
+ "return_dto",
+ "signature_namespace",
+ "type_decoders",
+ "type_encoders",
+ )
+
+ def __init__(
+ self,
+ path: str | Sequence[str] | None = None,
+ *,
+ dependencies: Dependencies | None = None,
+ dto: type[AbstractDTO] | None | EmptyType = Empty,
+ exception_handlers: ExceptionHandlersMap | None = None,
+ guards: Sequence[Guard] | None = None,
+ middleware: Sequence[Middleware] | None = None,
+ name: str | None = None,
+ opt: Mapping[str, Any] | None = None,
+ return_dto: type[AbstractDTO] | None | EmptyType = Empty,
+ signature_namespace: Mapping[str, Any] | None = None,
+ signature_types: Sequence[Any] | None = None,
+ type_decoders: TypeDecodersSequence | None = None,
+ type_encoders: TypeEncodersMap | None = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize ``HTTPRouteHandler``.
+
+ Args:
+ path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults
+ to ``/``
+ dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances.
+ dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and
+ validation of request data.
+ exception_handlers: A mapping of status codes and/or exception types to handler functions.
+ guards: A sequence of :class:`Guard <.types.Guard>` callables.
+ middleware: A sequence of :class:`Middleware <.types.Middleware>`.
+ name: A string identifying the route handler.
+ opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or
+ wherever you have access to :class:`Request <.connection.Request>` or
+ :class:`ASGI Scope <.types.Scope>`.
+ return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing
+ outbound response data.
+ signature_namespace: A mapping of names to types for use in forward reference resolution during signature
+ modelling.
+ signature_types: A sequence of types for use in forward reference resolution during signature modeling.
+ These types will be added to the signature namespace using their ``__name__`` attribute.
+ type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.
+ type_encoders: A mapping of types to callables that transform them into types supported for serialization.
+ **kwargs: Any additional kwarg - will be set in the opt dictionary.
+ """
+ self._parsed_fn_signature: ParsedSignature | EmptyType = Empty
+ self._parsed_return_field: FieldDefinition | EmptyType = Empty
+ self._parsed_data_field: FieldDefinition | None | EmptyType = Empty
+ self._resolved_data_dto: type[AbstractDTO] | None | EmptyType = Empty
+ self._resolved_dependencies: dict[str, Provide] | EmptyType = Empty
+ self._resolved_guards: list[Guard] | EmptyType = Empty
+ self._resolved_layered_parameters: dict[str, FieldDefinition] | EmptyType = Empty
+ self._resolved_return_dto: type[AbstractDTO] | None | EmptyType = Empty
+ self._resolved_signature_namespace: dict[str, Any] | EmptyType = Empty
+ self._resolved_type_decoders: TypeDecodersSequence | EmptyType = Empty
+ self._resolved_type_encoders: TypeEncodersMap | EmptyType = Empty
+ self._signature_model: type[SignatureModel] | EmptyType = Empty
+
+ self.dependencies = dependencies
+ self.dto = dto
+ self.exception_handlers = exception_handlers
+ self.guards = guards
+ self.middleware = middleware
+ self.name = name
+ self.opt = dict(opt or {})
+ self.opt.update(**kwargs)
+ self.owner: Controller | Router | None = None
+ self.return_dto = return_dto
+ self.signature_namespace = add_types_to_signature_namespace(
+ signature_types or [], dict(signature_namespace or {})
+ )
+ self.type_decoders = type_decoders
+ self.type_encoders = type_encoders
+
+ self.paths = (
+ {normalize_path(p) for p in path} if path and isinstance(path, list) else {normalize_path(path or "/")} # type: ignore[arg-type]
+ )
+
+ def __call__(self, fn: AsyncAnyCallable) -> Self:
+ """Replace a function with itself."""
+ self._fn = fn
+ return self
+
+ @property
+ def handler_id(self) -> str:
+ """A unique identifier used for generation of DTOs."""
+ return f"{self!s}::{sum(id(layer) for layer in self.ownership_layers)}"
+
+ @property
+ def default_deserializer(self) -> Callable[[Any, Any], Any]:
+ """Get a default deserializer for the route handler.
+
+ Returns:
+ A default deserializer for the route handler.
+
+ """
+ return partial(default_deserializer, type_decoders=self.resolve_type_decoders())
+
+ @property
+ def default_serializer(self) -> Callable[[Any], Any]:
+ """Get a default serializer for the route handler.
+
+ Returns:
+ A default serializer for the route handler.
+
+ """
+ return partial(default_serializer, type_encoders=self.resolve_type_encoders())
+
+ @property
+ def signature_model(self) -> type[SignatureModel]:
+ """Get the signature model for the route handler.
+
+ Returns:
+ A signature model for the route handler.
+
+ """
+ if self._signature_model is Empty:
+ self._signature_model = SignatureModel.create(
+ dependency_name_set=self.dependency_name_set,
+ fn=cast("AnyCallable", self.fn),
+ parsed_signature=self.parsed_fn_signature,
+ data_dto=self.resolve_data_dto(),
+ type_decoders=self.resolve_type_decoders(),
+ )
+ return self._signature_model
+
+ @property
+ def fn(self) -> AsyncAnyCallable:
+ """Get the handler function.
+
+ Raises:
+ ImproperlyConfiguredException: if handler fn is not set.
+
+ Returns:
+ Handler function
+ """
+ if not hasattr(self, "_fn"):
+ raise ImproperlyConfiguredException("No callable has been registered for this handler")
+ return self._fn
+
+ @property
+ def parsed_fn_signature(self) -> ParsedSignature:
+ """Return the parsed signature of the handler function.
+
+ This method is memoized so the computation occurs only once.
+
+ Returns:
+ A ParsedSignature instance
+ """
+ if self._parsed_fn_signature is Empty:
+ self._parsed_fn_signature = ParsedSignature.from_fn(
+ unwrap_partial(self.fn), self.resolve_signature_namespace()
+ )
+
+ return self._parsed_fn_signature
+
+ @property
+ def parsed_return_field(self) -> FieldDefinition:
+ if self._parsed_return_field is Empty:
+ self._parsed_return_field = self.parsed_fn_signature.return_type
+ return self._parsed_return_field
+
+ @property
+ def parsed_data_field(self) -> FieldDefinition | None:
+ if self._parsed_data_field is Empty:
+ self._parsed_data_field = self.parsed_fn_signature.parameters.get("data")
+ return self._parsed_data_field
+
+ @property
+ def handler_name(self) -> str:
+ """Get the name of the handler function.
+
+ Raises:
+ ImproperlyConfiguredException: if handler fn is not set.
+
+ Returns:
+ Name of the handler function
+ """
+ return get_name(unwrap_partial(self.fn))
+
+ @property
+ def dependency_name_set(self) -> set[str]:
+ """Set of all dependency names provided in the handler's ownership layers."""
+ layered_dependencies = (layer.dependencies or {} for layer in self.ownership_layers)
+ return {name for layer in layered_dependencies for name in layer} # pyright: ignore
+
+ @property
+ def ownership_layers(self) -> list[Self | Controller | Router]:
+ """Return the handler layers from the app down to the route handler.
+
+ ``app -> ... -> route handler``
+ """
+ layers = []
+
+ cur: Any = self
+ while cur:
+ layers.append(cur)
+ cur = cur.owner
+
+ return list(reversed(layers))
+
+ @property
+ def app(self) -> Litestar:
+ return cast("Litestar", self.ownership_layers[0])
+
+ def resolve_type_encoders(self) -> TypeEncodersMap:
+ """Return a merged type_encoders mapping.
+
+ This method is memoized so the computation occurs only once.
+
+ Returns:
+ A dict of type encoders
+ """
+ if self._resolved_type_encoders is Empty:
+ self._resolved_type_encoders = {}
+
+ for layer in self.ownership_layers:
+ if type_encoders := getattr(layer, "type_encoders", None):
+ self._resolved_type_encoders.update(type_encoders)
+ return cast("TypeEncodersMap", self._resolved_type_encoders)
+
+ def resolve_type_decoders(self) -> TypeDecodersSequence:
+ """Return a merged type_encoders mapping.
+
+ This method is memoized so the computation occurs only once.
+
+ Returns:
+ A dict of type encoders
+ """
+ if self._resolved_type_decoders is Empty:
+ self._resolved_type_decoders = []
+
+ for layer in self.ownership_layers:
+ if type_decoders := getattr(layer, "type_decoders", None):
+ self._resolved_type_decoders.extend(list(type_decoders))
+ return cast("TypeDecodersSequence", self._resolved_type_decoders)
+
+ def resolve_layered_parameters(self) -> dict[str, FieldDefinition]:
+ """Return all parameters declared above the handler."""
+ if self._resolved_layered_parameters is Empty:
+ parameter_kwargs: dict[str, ParameterKwarg] = {}
+
+ for layer in self.ownership_layers:
+ parameter_kwargs.update(getattr(layer, "parameters", {}) or {})
+
+ self._resolved_layered_parameters = {
+ key: FieldDefinition.from_kwarg(name=key, annotation=parameter.annotation, kwarg_definition=parameter)
+ for key, parameter in parameter_kwargs.items()
+ }
+
+ return self._resolved_layered_parameters
+
+ def resolve_guards(self) -> list[Guard]:
+ """Return all guards in the handlers scope, starting from highest to current layer."""
+ if self._resolved_guards is Empty:
+ self._resolved_guards = []
+
+ for layer in self.ownership_layers:
+ self._resolved_guards.extend(layer.guards or []) # pyright: ignore
+
+ self._resolved_guards = cast(
+ "list[Guard]", [ensure_async_callable(guard) for guard in self._resolved_guards]
+ )
+
+ return self._resolved_guards
+
+ def _get_plugin_registry(self) -> PluginRegistry | None:
+ from litestar.app import Litestar
+
+ root_owner = self.ownership_layers[0]
+ if isinstance(root_owner, Litestar):
+ return root_owner.plugins
+ return None
+
+ def resolve_dependencies(self) -> dict[str, Provide]:
+ """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope."""
+ plugin_registry = self._get_plugin_registry()
+ if self._resolved_dependencies is Empty:
+ self._resolved_dependencies = {}
+ for layer in self.ownership_layers:
+ for key, provider in (layer.dependencies or {}).items():
+ self._resolved_dependencies[key] = self._resolve_dependency(
+ key=key, provider=provider, plugin_registry=plugin_registry
+ )
+
+ return self._resolved_dependencies
+
+ def _resolve_dependency(
+ self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None
+ ) -> Provide:
+ if not isinstance(provider, Provide):
+ provider = Provide(provider)
+
+ if self._resolved_dependencies is not Empty: # pragma: no cover
+ self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider)
+
+ if not getattr(provider, "parsed_fn_signature", None):
+ dependency = unwrap_partial(provider.dependency)
+ plugin: DIPlugin | None = None
+ if plugin_registry:
+ plugin = next(
+ (p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)),
+ None,
+ )
+ if plugin:
+ signature, init_type_hints = plugin.get_typed_init(dependency)
+ provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints)
+ else:
+ provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace())
+
+ if not getattr(provider, "signature_model", None):
+ provider.signature_model = SignatureModel.create(
+ dependency_name_set=self.dependency_name_set,
+ fn=provider.dependency,
+ parsed_signature=provider.parsed_fn_signature,
+ data_dto=self.resolve_data_dto(),
+ type_decoders=self.resolve_type_decoders(),
+ )
+ return provider
+
+ def resolve_middleware(self) -> list[Middleware]:
+ """Build the middleware stack for the RouteHandler and return it.
+
+ The middlewares are added from top to bottom (``app -> router -> controller -> route handler``) and then
+ reversed.
+ """
+ resolved_middleware: list[Middleware] = []
+ for layer in self.ownership_layers:
+ resolved_middleware.extend(layer.middleware or []) # pyright: ignore
+ return list(reversed(resolved_middleware))
+
+ def resolve_exception_handlers(self) -> ExceptionHandlersMap:
+ """Resolve the exception_handlers by starting from the route handler and moving up.
+
+ This method is memoized so the computation occurs only once.
+ """
+ resolved_exception_handlers: dict[int | type[Exception], ExceptionHandler] = {}
+ for layer in self.ownership_layers:
+ resolved_exception_handlers.update(layer.exception_handlers or {}) # pyright: ignore
+ return resolved_exception_handlers
+
+ def resolve_opts(self) -> None:
+ """Build the route handler opt dictionary by going from top to bottom.
+
+ When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the
+ layer closest to the response handler will take precedence.
+ """
+
+ opt: dict[str, Any] = {}
+ for layer in self.ownership_layers:
+ opt.update(layer.opt or {}) # pyright: ignore
+
+ self.opt = opt
+
+ def resolve_signature_namespace(self) -> dict[str, Any]:
+ """Build the route handler signature namespace dictionary by going from top to bottom.
+
+ When merging keys from multiple layers, if the same key is defined by multiple layers, the value from the
+ layer closest to the response handler will take precedence.
+ """
+ if self._resolved_layered_parameters is Empty:
+ ns: dict[str, Any] = {}
+ for layer in self.ownership_layers:
+ ns.update(layer.signature_namespace)
+
+ self._resolved_signature_namespace = ns
+ return cast("dict[str, Any]", self._resolved_signature_namespace)
+
+ def _get_dto_backend_cls(self) -> type[DTOBackend] | None:
+ if ExperimentalFeatures.DTO_CODEGEN in self.app.experimental_features:
+ from litestar.dto._codegen_backend import DTOCodegenBackend
+
+ return DTOCodegenBackend
+ return None
+
+ def resolve_data_dto(self) -> type[AbstractDTO] | None:
+ """Resolve the data_dto by starting from the route handler and moving up.
+ If a handler is found it is returned, otherwise None is set.
+ This method is memoized so the computation occurs only once.
+
+ Returns:
+ An optional :class:`DTO type <.dto.base_dto.AbstractDTO>`
+ """
+ if self._resolved_data_dto is Empty:
+ if data_dtos := cast(
+ "list[type[AbstractDTO] | None]",
+ [layer.dto for layer in self.ownership_layers if layer.dto is not Empty],
+ ):
+ data_dto: type[AbstractDTO] | None = data_dtos[-1]
+ elif self.parsed_data_field and (
+ plugins_for_data_type := [
+ plugin
+ for plugin in self.app.plugins.serialization
+ if self.parsed_data_field.match_predicate_recursively(plugin.supports_type)
+ ]
+ ):
+ data_dto = plugins_for_data_type[0].create_dto_for_type(self.parsed_data_field)
+ else:
+ data_dto = None
+
+ if self.parsed_data_field and data_dto:
+ data_dto.create_for_field_definition(
+ field_definition=self.parsed_data_field,
+ handler_id=self.handler_id,
+ backend_cls=self._get_dto_backend_cls(),
+ )
+
+ self._resolved_data_dto = data_dto
+
+ return self._resolved_data_dto
+
+ def resolve_return_dto(self) -> type[AbstractDTO] | None:
+ """Resolve the return_dto by starting from the route handler and moving up.
+ If a handler is found it is returned, otherwise None is set.
+ This method is memoized so the computation occurs only once.
+
+ Returns:
+ An optional :class:`DTO type <.dto.base_dto.AbstractDTO>`
+ """
+ if self._resolved_return_dto is Empty:
+ if return_dtos := cast(
+ "list[type[AbstractDTO] | None]",
+ [layer.return_dto for layer in self.ownership_layers if layer.return_dto is not Empty],
+ ):
+ return_dto: type[AbstractDTO] | None = return_dtos[-1]
+ elif plugins_for_return_type := [
+ plugin
+ for plugin in self.app.plugins.serialization
+ if self.parsed_return_field.match_predicate_recursively(plugin.supports_type)
+ ]:
+ return_dto = plugins_for_return_type[0].create_dto_for_type(self.parsed_return_field)
+ else:
+ return_dto = self.resolve_data_dto()
+
+ if return_dto and return_dto.is_supported_model_type_field(self.parsed_return_field):
+ return_dto.create_for_field_definition(
+ field_definition=self.parsed_return_field,
+ handler_id=self.handler_id,
+ backend_cls=self._get_dto_backend_cls(),
+ )
+ self._resolved_return_dto = return_dto
+ else:
+ self._resolved_return_dto = None
+
+ return self._resolved_return_dto
+
+ async def authorize_connection(self, connection: ASGIConnection) -> None:
+ """Ensure the connection is authorized by running all the route guards in scope."""
+ for guard in self.resolve_guards():
+ await guard(connection, copy(self)) # type: ignore[misc]
+
+ @staticmethod
+ def _validate_dependency_is_unique(dependencies: dict[str, Provide], key: str, provider: Provide) -> None:
+ """Validate that a given provider has not been already defined under a different key."""
+ for dependency_key, value in dependencies.items():
+ if provider == value:
+ raise ImproperlyConfiguredException(
+ f"Provider for key {key} is already defined under the different key {dependency_key}. "
+ f"If you wish to override a provider, it must have the same key."
+ )
+
+ def on_registration(self, app: Litestar) -> None:
+ """Called once per handler when the app object is instantiated.
+
+ Args:
+ app: The :class:`Litestar<.app.Litestar>` app object.
+
+ Returns:
+ None
+ """
+ self._validate_handler_function()
+ self.resolve_dependencies()
+ self.resolve_guards()
+ self.resolve_middleware()
+ self.resolve_opts()
+ self.resolve_data_dto()
+ self.resolve_return_dto()
+
+ def _validate_handler_function(self) -> None:
+ """Validate the route handler function once set by inspecting its return annotations."""
+ if (
+ self.parsed_data_field is not None
+ and self.parsed_data_field.is_subclass_of(DTOData)
+ and not self.resolve_data_dto()
+ ):
+ raise ImproperlyConfiguredException(
+ f"Handler function {self.handler_name} has a data parameter that is a subclass of DTOData but no "
+ "DTO has been registered for it."
+ )
+
+ def __str__(self) -> str:
+ """Return a unique identifier for the route handler.
+
+ Returns:
+ A string
+ """
+ target: type[AsyncAnyCallable] | AsyncAnyCallable # pyright: ignore
+ target = unwrap_partial(self.fn)
+ if not hasattr(target, "__qualname__"):
+ target = type(target)
+ return f"{target.__module__}.{target.__qualname__}"