From 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:10:44 -0400 Subject: venv --- .../site-packages/litestar/_kwargs/kwargs_model.py | 479 +++++++++++++++++++++ 1 file changed, 479 insertions(+) create mode 100644 venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py (limited to 'venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py') diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py new file mode 100644 index 0000000..01ed2e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from anyio import create_task_group + +from litestar._kwargs.cleanup import DependencyCleanupGroup +from litestar._kwargs.dependencies import ( + Dependency, + create_dependency_batches, + resolve_dependency, +) +from litestar._kwargs.extractors import ( + body_extractor, + cookies_extractor, + create_connection_value_extractor, + create_data_extractor, + headers_extractor, + parse_connection_headers, + parse_connection_query_params, + query_extractor, + request_extractor, + scope_extractor, + socket_extractor, + state_extractor, +) +from litestar._kwargs.parameter_definition import ( + ParameterDefinition, + create_parameter_definition, + merge_parameter_sets, +) +from litestar.constants import RESERVED_KWARGS +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.params import BodyKwarg, ParameterKwarg +from litestar.typing import FieldDefinition +from litestar.utils.helpers import get_exception_group + +__all__ = ("KwargsModel",) + + +if TYPE_CHECKING: + from litestar._signature import SignatureModel + from litestar.connection import ASGIConnection + from litestar.di import Provide + from litestar.dto import AbstractDTO + from litestar.utils.signature import ParsedSignature + +_ExceptionGroup = get_exception_group() + + +class KwargsModel: + """Model required kwargs for a given RouteHandler and its dependencies. + + This is done once and is memoized during application bootstrap, ensuring minimal runtime overhead. + """ + + __slots__ = ( + "dependency_batches", + "expected_cookie_params", + "expected_data_dto", + "expected_form_data", + "expected_header_params", + "expected_msgpack_data", + "expected_path_params", + "expected_query_params", + "expected_reserved_kwargs", + "extractors", + "has_kwargs", + "is_data_optional", + "sequence_query_parameter_names", + ) + + def __init__( + self, + *, + expected_cookie_params: set[ParameterDefinition], + expected_data_dto: type[AbstractDTO] | None, + expected_dependencies: set[Dependency], + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, + expected_header_params: set[ParameterDefinition], + expected_msgpack_data: FieldDefinition | None, + expected_path_params: set[ParameterDefinition], + expected_query_params: set[ParameterDefinition], + expected_reserved_kwargs: set[str], + is_data_optional: bool, + sequence_query_parameter_names: set[str], + ) -> None: + """Initialize ``KwargsModel``. + + Args: + expected_cookie_params: Any expected cookie parameter kwargs + expected_dependencies: Any expected dependency kwargs + expected_form_data: Any expected form data kwargs + expected_header_params: Any expected header parameter kwargs + expected_msgpack_data: Any expected MessagePack data kwargs + expected_path_params: Any expected path parameter kwargs + expected_query_params: Any expected query parameter kwargs + expected_reserved_kwargs: Any expected reserved kwargs, e.g. 'state' + expected_data_dto: A data DTO, if defined + is_data_optional: Treat data as optional + sequence_query_parameter_names: Any query parameters that are sequences + """ + self.expected_cookie_params = expected_cookie_params + self.expected_form_data = expected_form_data + self.expected_header_params = expected_header_params + self.expected_msgpack_data = expected_msgpack_data + self.expected_path_params = expected_path_params + self.expected_query_params = expected_query_params + self.expected_reserved_kwargs = expected_reserved_kwargs + self.expected_data_dto = expected_data_dto + self.sequence_query_parameter_names = tuple(sequence_query_parameter_names) + + self.has_kwargs = ( + expected_cookie_params + or expected_dependencies + or expected_form_data + or expected_msgpack_data + or expected_header_params + or expected_path_params + or expected_query_params + or expected_reserved_kwargs + or expected_data_dto + ) + + self.is_data_optional = is_data_optional + self.extractors = self._create_extractors() + self.dependency_batches = create_dependency_batches(expected_dependencies) + + def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]: + reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = { + "data": create_data_extractor(self), + "state": state_extractor, + "scope": scope_extractor, + "request": request_extractor, + "socket": socket_extractor, + "headers": headers_extractor, + "cookies": cookies_extractor, + "query": query_extractor, + "body": body_extractor, # type: ignore[dict-item] + } + + extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [ + reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs + ] + + if self.expected_header_params: + extractors.append( + create_connection_value_extractor( + connection_key="headers", + expected_params=self.expected_header_params, + kwargs_model=self, + parser=parse_connection_headers, + ), + ) + + if self.expected_path_params: + extractors.append( + create_connection_value_extractor( + connection_key="path_params", + expected_params=self.expected_path_params, + kwargs_model=self, + ), + ) + + if self.expected_cookie_params: + extractors.append( + create_connection_value_extractor( + connection_key="cookies", + expected_params=self.expected_cookie_params, + kwargs_model=self, + ), + ) + + if self.expected_query_params: + extractors.append( + create_connection_value_extractor( + connection_key="query_params", + expected_params=self.expected_query_params, + kwargs_model=self, + parser=parse_connection_query_params, + ), + ) + return extractors + + @classmethod + def _get_param_definitions( + cls, + path_parameters: set[str], + layered_parameters: dict[str, FieldDefinition], + dependencies: dict[str, Provide], + field_definitions: dict[str, FieldDefinition], + ) -> tuple[set[ParameterDefinition], set[Dependency]]: + """Get parameter_definitions for the construction of KwargsModel instance. + + Args: + path_parameters: Any expected path parameters. + layered_parameters: A string keyed dictionary of layered parameters. + dependencies: A string keyed dictionary mapping dependency providers. + field_definitions: The SignatureModel fields. + + Returns: + A Tuple of sets + """ + expected_dependencies = { + cls._create_dependency_graph(key=key, dependencies=dependencies) + for key in dependencies + if key in field_definitions + } + ignored_keys = {*RESERVED_KWARGS, *(dependency.key for dependency in expected_dependencies)} + + param_definitions = { + *( + create_parameter_definition( + field_definition=field_definition, + field_name=field_name, + path_parameters=path_parameters, + ) + for field_name, field_definition in layered_parameters.items() + if field_name not in ignored_keys and field_name not in field_definitions + ), + *( + create_parameter_definition( + field_definition=field_definition, + field_name=field_name, + path_parameters=path_parameters, + ) + for field_name, field_definition in field_definitions.items() + if field_name not in ignored_keys and field_name not in layered_parameters + ), + } + + for field_name, field_definition in ( + (k, v) for k, v in field_definitions.items() if k not in ignored_keys and k in layered_parameters + ): + layered_parameter = layered_parameters[field_name] + field = field_definition if field_definition.is_parameter_field else layered_parameter + default = field_definition.default if field_definition.has_default else layered_parameter.default + + param_definitions.add( + create_parameter_definition( + field_definition=FieldDefinition.from_kwarg( + name=field.name, + default=default, + inner_types=field.inner_types, + annotation=field.annotation, + kwarg_definition=field.kwarg_definition, + extra=field.extra, + ), + field_name=field_name, + path_parameters=path_parameters, + ) + ) + + return param_definitions, expected_dependencies + + @classmethod + def create_for_signature_model( + cls, + signature_model: type[SignatureModel], + parsed_signature: ParsedSignature, + dependencies: dict[str, Provide], + path_parameters: set[str], + layered_parameters: dict[str, FieldDefinition], + ) -> KwargsModel: + """Pre-determine what parameters are required for a given combination of route + route handler. It is executed + during the application bootstrap process. + + Args: + signature_model: A :class:`SignatureModel ` subclass. + parsed_signature: A :class:`ParsedSignature ` instance. + dependencies: A string keyed dictionary mapping dependency providers. + path_parameters: Any expected path parameters. + layered_parameters: A string keyed dictionary of layered parameters. + + Returns: + An instance of KwargsModel + """ + + field_definitions = signature_model._fields + + cls._validate_raw_kwargs( + path_parameters=path_parameters, + dependencies=dependencies, + field_definitions=field_definitions, + layered_parameters=layered_parameters, + ) + + param_definitions, expected_dependencies = cls._get_param_definitions( + path_parameters=path_parameters, + layered_parameters=layered_parameters, + dependencies=dependencies, + field_definitions=field_definitions, + ) + + expected_reserved_kwargs = {field_name for field_name in field_definitions if field_name in RESERVED_KWARGS} + expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH} + expected_header_parameters = {p for p in param_definitions if p.param_type == ParamType.HEADER} + expected_cookie_parameters = {p for p in param_definitions if p.param_type == ParamType.COOKIE} + expected_query_parameters = {p for p in param_definitions if p.param_type == ParamType.QUERY} + sequence_query_parameter_names = {p.field_alias for p in expected_query_parameters if p.is_sequence} + + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None = None + expected_msgpack_data: FieldDefinition | None = None + expected_data_dto: type[AbstractDTO] | None = None + data_field_definition = field_definitions.get("data") + + media_type: RequestEncodingType | str | None = None + if data_field_definition: + if isinstance(data_field_definition.kwarg_definition, BodyKwarg): + media_type = data_field_definition.kwarg_definition.media_type + + if media_type in (RequestEncodingType.MULTI_PART, RequestEncodingType.URL_ENCODED): + expected_form_data = (media_type, data_field_definition) + expected_data_dto = signature_model._data_dto + elif signature_model._data_dto: + expected_data_dto = signature_model._data_dto + elif media_type == RequestEncodingType.MESSAGEPACK: + expected_msgpack_data = data_field_definition + + for dependency in expected_dependencies: + dependency_kwargs_model = cls.create_for_signature_model( + signature_model=dependency.provide.signature_model, + parsed_signature=parsed_signature, + dependencies=dependencies, + path_parameters=path_parameters, + layered_parameters=layered_parameters, + ) + expected_path_parameters = merge_parameter_sets( + expected_path_parameters, dependency_kwargs_model.expected_path_params + ) + expected_query_parameters = merge_parameter_sets( + expected_query_parameters, dependency_kwargs_model.expected_query_params + ) + expected_cookie_parameters = merge_parameter_sets( + expected_cookie_parameters, dependency_kwargs_model.expected_cookie_params + ) + expected_header_parameters = merge_parameter_sets( + expected_header_parameters, dependency_kwargs_model.expected_header_params + ) + + if "data" in expected_reserved_kwargs and "data" in dependency_kwargs_model.expected_reserved_kwargs: + cls._validate_dependency_data( + expected_form_data=expected_form_data, + dependency_kwargs_model=dependency_kwargs_model, + ) + + expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs) + sequence_query_parameter_names.update(dependency_kwargs_model.sequence_query_parameter_names) + + return KwargsModel( + expected_cookie_params=expected_cookie_parameters, + expected_dependencies=expected_dependencies, + expected_data_dto=expected_data_dto, + expected_form_data=expected_form_data, + expected_header_params=expected_header_parameters, + expected_msgpack_data=expected_msgpack_data, + expected_path_params=expected_path_parameters, + expected_query_params=expected_query_parameters, + expected_reserved_kwargs=expected_reserved_kwargs, + is_data_optional=field_definitions["data"].is_optional if "data" in expected_reserved_kwargs else False, + sequence_query_parameter_names=sequence_query_parameter_names, + ) + + def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: + """Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is + sync. + + Args: + connection: An instance of :class:`Request ` or + :class:`WebSocket `. + + Returns: + A string keyed dictionary of kwargs expected by the handler function and its dependencies. + """ + output: dict[str, Any] = {} + + for extractor in self.extractors: + extractor(output, connection) + + return output + + async def resolve_dependencies(self, connection: ASGIConnection, kwargs: dict[str, Any]) -> DependencyCleanupGroup: + """Resolve all dependencies into the kwargs, recursively. + + Args: + connection: An instance of :class:`Request ` or + :class:`WebSocket `. + kwargs: Kwargs to pass to dependencies. + """ + cleanup_group = DependencyCleanupGroup() + for batch in self.dependency_batches: + if len(batch) == 1: + await resolve_dependency(next(iter(batch)), connection, kwargs, cleanup_group) + else: + try: + async with create_task_group() as task_group: + for dependency in batch: + task_group.start_soon(resolve_dependency, dependency, connection, kwargs, cleanup_group) + except _ExceptionGroup as excgroup: + raise excgroup.exceptions[0] from excgroup # type: ignore[attr-defined] + + return cleanup_group + + @classmethod + def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) -> Dependency: + """Create a graph like structure of dependencies, with each dependency including its own dependencies as a + list. + """ + provide = dependencies[key] + sub_dependency_keys = [k for k in provide.signature_model._fields if k in dependencies] + return Dependency( + key=key, + provide=provide, + dependencies=[cls._create_dependency_graph(key=k, dependencies=dependencies) for k in sub_dependency_keys], + ) + + @classmethod + def _validate_dependency_data( + cls, + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, + dependency_kwargs_model: KwargsModel, + ) -> None: + """Validate that the 'data' kwarg is compatible across dependencies.""" + if bool(expected_form_data) != bool(dependency_kwargs_model.expected_form_data): + raise ImproperlyConfiguredException( + "Dependencies have incompatible 'data' kwarg types: one expects JSON and the other expects form-data" + ) + if expected_form_data and dependency_kwargs_model.expected_form_data: + local_media_type = expected_form_data[0] + dependency_media_type = dependency_kwargs_model.expected_form_data[0] + if local_media_type != dependency_media_type: + raise ImproperlyConfiguredException( + "Dependencies have incompatible form-data encoding: one expects url-encoded and the other expects multi-part" + ) + + @classmethod + def _validate_raw_kwargs( + cls, + path_parameters: set[str], + dependencies: dict[str, Provide], + field_definitions: dict[str, FieldDefinition], + layered_parameters: dict[str, FieldDefinition], + ) -> None: + """Validate that there are no ambiguous kwargs, that is, kwargs declared using the same key in different + places. + """ + dependency_keys = set(dependencies.keys()) + + parameter_names = { + *( + k + for k, f in field_definitions.items() + if isinstance(f.kwarg_definition, ParameterKwarg) + and (f.kwarg_definition.header or f.kwarg_definition.query or f.kwarg_definition.cookie) + ), + *list(layered_parameters.keys()), + } + + intersection = ( + path_parameters.intersection(dependency_keys) + or path_parameters.intersection(parameter_names) + or dependency_keys.intersection(parameter_names) + ) + if intersection: + raise ImproperlyConfiguredException( + f"Kwarg resolution ambiguity detected for the following keys: {', '.join(intersection)}. " + f"Make sure to use distinct keys for your dependencies, path parameters, and aliased parameters." + ) + + if used_reserved_kwargs := { + *parameter_names, + *path_parameters, + *dependency_keys, + }.intersection(RESERVED_KWARGS): + raise ImproperlyConfiguredException( + f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. " + f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}" + ) -- cgit v1.2.3