diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_kwargs')
12 files changed, 1291 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py new file mode 100644 index 0000000..af8ad36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py @@ -0,0 +1,3 @@ +from .kwargs_model import KwargsModel + +__all__ = ("KwargsModel",) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b61cc1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a084eb5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b0f49a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..35a6e40 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..17029de --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..10753dc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py new file mode 100644 index 0000000..8839d36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from inspect import Traceback, isasyncgen +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator + +from anyio import create_task_group + +from litestar.utils import ensure_async_callable +from litestar.utils.compat import async_next + +__all__ = ("DependencyCleanupGroup",) + + +if TYPE_CHECKING: + from litestar.types import AnyGenerator + + +class DependencyCleanupGroup: + """Wrapper for generator based dependencies. + + Simplify cleanup by wrapping :func:`next` / :func:`anext` calls and providing facilities to + :meth:`throw <generator.throw>` / :meth:`athrow <agen.athrow>` into all generators consecutively. An instance of + this class can be used as a contextmanager, which will automatically throw any exceptions into its generators. All + exceptions caught in this manner will be re-raised after they have been thrown in the generators. + """ + + __slots__ = ("_generators", "_closed") + + def __init__(self, generators: list[AnyGenerator] | None = None) -> None: + """Initialize ``DependencyCleanupGroup``. + + Args: + generators: An optional list of generators to be called at cleanup + """ + self._generators = generators or [] + self._closed = False + + def add(self, generator: Generator[Any, None, None] | AsyncGenerator[Any, None]) -> None: + """Add a new generator to the group. + + Args: + generator: The generator to add + + Returns: + None + """ + if self._closed: + raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup") + self._generators.append(generator) + + @staticmethod + def _wrap_next(generator: AnyGenerator) -> Callable[[], Awaitable[None]]: + if isasyncgen(generator): + + async def wrapped_async() -> None: + await async_next(generator, None) + + return wrapped_async + + def wrapped() -> None: + next(generator, None) # type: ignore[arg-type] + + return ensure_async_callable(wrapped) + + async def cleanup(self) -> None: + """Execute cleanup by calling :func:`next` / :func:`anext` on all generators. + + If there are multiple generators to be called, they will be executed in a :class:`anyio.TaskGroup`. + + Returns: + None + """ + if self._closed: + raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup") + + self._closed = True + + if not self._generators: + return + + if len(self._generators) == 1: + await self._wrap_next(self._generators[0])() + return + + async with create_task_group() as task_group: + for generator in self._generators: + task_group.start_soon(self._wrap_next(generator)) + + async def __aenter__(self) -> None: + """Support the async contextmanager protocol to allow for easier catching and throwing of exceptions into the + generators. + """ + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Traceback | None, + ) -> None: + """If an exception was raised within the contextmanager block, throw it into all generators.""" + if exc_val: + await self.throw(exc_val) + + async def throw(self, exc: BaseException) -> None: + """Throw an exception in all generators sequentially. + + Args: + exc: Exception to throw + """ + for gen in self._generators: + try: + if isasyncgen(gen): + await gen.athrow(exc) + else: + gen.throw(exc) # type: ignore[union-attr] + except (StopIteration, StopAsyncIteration): + continue diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py new file mode 100644 index 0000000..88ffb07 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.utils.compat import async_next + +__all__ = ("Dependency", "create_dependency_batches", "map_dependencies_recursively", "resolve_dependency") + + +if TYPE_CHECKING: + from litestar._kwargs.cleanup import DependencyCleanupGroup + from litestar.connection import ASGIConnection + from litestar.di import Provide + + +class Dependency: + """Dependency graph of a given combination of ``Route`` + ``RouteHandler``""" + + __slots__ = ("key", "provide", "dependencies") + + def __init__(self, key: str, provide: Provide, dependencies: list[Dependency]) -> None: + """Initialize a dependency. + + Args: + key: The dependency key + provide: Provider + dependencies: List of child nodes + """ + self.key = key + self.provide = provide + self.dependencies = dependencies + + def __eq__(self, other: Any) -> bool: + # check if memory address is identical, otherwise compare attributes + return other is self or (isinstance(other, self.__class__) and other.key == self.key) + + def __hash__(self) -> int: + return hash(self.key) + + +async def resolve_dependency( + dependency: Dependency, + connection: ASGIConnection, + kwargs: dict[str, Any], + cleanup_group: DependencyCleanupGroup, +) -> None: + """Resolve a given instance of :class:`Dependency <litestar._kwargs.Dependency>`. + + All required sub dependencies must already + be resolved into the kwargs. The result of the dependency will be stored in the kwargs. + + Args: + dependency: An instance of :class:`Dependency <litestar._kwargs.Dependency>` + connection: An instance of :class:`Request <litestar.connection.Request>` or + :class:`WebSocket <litestar.connection.WebSocket>`. + kwargs: Any kwargs to pass to the dependency, the result will be stored here as well. + cleanup_group: DependencyCleanupGroup to which generators returned by ``dependency`` will be added + """ + signature_model = dependency.provide.signature_model + dependency_kwargs = ( + signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs) + if signature_model._fields + else {} + ) + value = await dependency.provide(**dependency_kwargs) + + if dependency.provide.has_sync_generator_dependency: + cleanup_group.add(value) + value = next(value) + elif dependency.provide.has_async_generator_dependency: + cleanup_group.add(value) + value = await async_next(value) + + kwargs[dependency.key] = value + + +def create_dependency_batches(expected_dependencies: set[Dependency]) -> list[set[Dependency]]: + """Calculate batches for all dependencies, recursively. + + Args: + expected_dependencies: A set of all direct :class:`Dependencies <litestar._kwargs.Dependency>`. + + Returns: + A list of batches. + """ + dependencies_to: dict[Dependency, set[Dependency]] = {} + for dependency in expected_dependencies: + if dependency not in dependencies_to: + map_dependencies_recursively(dependency, dependencies_to) + + batches = [] + while dependencies_to: + current_batch = { + dependency + for dependency, remaining_sub_dependencies in dependencies_to.items() + if not remaining_sub_dependencies + } + + for dependency in current_batch: + del dependencies_to[dependency] + for others_dependencies in dependencies_to.values(): + others_dependencies.discard(dependency) + + batches.append(current_batch) + + return batches + + +def map_dependencies_recursively(dependency: Dependency, dependencies_to: dict[Dependency, set[Dependency]]) -> None: + """Recursively map dependencies to their sub dependencies. + + Args: + dependency: The current dependency to map. + dependencies_to: A map of dependency to its sub dependencies. + """ + dependencies_to[dependency] = set(dependency.dependencies) + for sub in dependency.dependencies: + if sub not in dependencies_to: + map_dependencies_recursively(sub, dependencies_to) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py new file mode 100644 index 0000000..e3b347e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache, partial +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast + +from litestar._multipart import parse_multipart_form +from litestar._parsers import ( + parse_query_string, + parse_url_encoded_form_data, +) +from litestar.datastructures import Headers +from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.url import URL +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ValidationException +from litestar.params import BodyKwarg +from litestar.types import Empty +from litestar.utils.predicates import is_non_string_sequence +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar._kwargs import KwargsModel + from litestar._kwargs.parameter_definition import ParameterDefinition + from litestar.connection import ASGIConnection, Request + from litestar.dto import AbstractDTO + from litestar.typing import FieldDefinition + + +__all__ = ( + "body_extractor", + "cookies_extractor", + "create_connection_value_extractor", + "create_data_extractor", + "create_multipart_extractor", + "create_query_default_dict", + "create_url_encoded_data_extractor", + "headers_extractor", + "json_extractor", + "msgpack_extractor", + "parse_connection_headers", + "parse_connection_query_params", + "query_extractor", + "request_extractor", + "scope_extractor", + "socket_extractor", + "state_extractor", +) + + +class ParamMappings(NamedTuple): + alias_and_key_tuples: list[tuple[str, str]] + alias_defaults: dict[str, Any] + alias_to_param: dict[str, ParameterDefinition] + + +def _create_param_mappings(expected_params: set[ParameterDefinition]) -> ParamMappings: + alias_and_key_tuples = [] + alias_defaults = {} + alias_to_params: dict[str, ParameterDefinition] = {} + for param in expected_params: + alias = param.field_alias + if param.param_type == ParamType.HEADER: + alias = alias.lower() + + alias_and_key_tuples.append((alias, param.field_name)) + + if not (param.is_required or param.default is Ellipsis): + alias_defaults[alias] = param.default + + alias_to_params[alias] = param + + return ParamMappings( + alias_and_key_tuples=alias_and_key_tuples, + alias_defaults=alias_defaults, + alias_to_param=alias_to_params, + ) + + +def create_connection_value_extractor( + kwargs_model: KwargsModel, + connection_key: str, + expected_params: set[ParameterDefinition], + parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None, +) -> Callable[[dict[str, Any], ASGIConnection], None]: + """Create a kwargs extractor function. + + Args: + kwargs_model: The KwargsModel instance. + connection_key: The attribute key to use. + expected_params: The set of expected params. + parser: An optional parser function. + + Returns: + An extractor function. + """ + + alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params) + + def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {}) + + try: + connection_mapping: dict[str, Any] = { + key: data[alias] if alias in data else alias_defaults[alias] for alias, key in alias_and_key_tuples + } + values.update(connection_mapping) + except KeyError as e: + param = alias_to_params[e.args[0]] + path = URL.from_components( + path=connection.url.path, + query=connection.url.query, + ) + raise ValidationException( + f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}" + ) from e + + return extractor + + +@lru_cache(1024) +def create_query_default_dict( + parsed_query: tuple[tuple[str, str], ...], sequence_query_parameter_names: tuple[str, ...] +) -> defaultdict[str, list[str] | str]: + """Transform a list of tuples into a default dict. Ensures non-list values are not wrapped in a list. + + Args: + parsed_query: The parsed query list of tuples. + sequence_query_parameter_names: A set of query parameters that should be wrapped in list. + + Returns: + A default dict + """ + output: defaultdict[str, list[str] | str] = defaultdict(list) + + for k, v in parsed_query: + if k in sequence_query_parameter_names: + output[k].append(v) # type: ignore[union-attr] + else: + output[k] = v + + return output + + +def parse_connection_query_params(connection: ASGIConnection, kwargs_model: KwargsModel) -> dict[str, Any]: + """Parse query params and cache the result in scope. + + Args: + connection: The ASGI connection instance. + kwargs_model: The KwargsModel instance. + + Returns: + A dictionary of parsed values. + """ + parsed_query = ( + connection._parsed_query + if connection._parsed_query is not Empty + else parse_query_string(connection.scope.get("query_string", b"")) + ) + ScopeState.from_scope(connection.scope).parsed_query = parsed_query + return create_query_default_dict( + parsed_query=parsed_query, + sequence_query_parameter_names=kwargs_model.sequence_query_parameter_names, + ) + + +def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers: + """Parse header parameters and cache the result in scope. + + Args: + connection: The ASGI connection instance. + _: The KwargsModel instance. + + Returns: + A Headers instance + """ + return Headers.from_scope(connection.scope) + + +def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the app state from the connection and insert it to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["state"] = connection.app.state._state + + +def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the headers from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + # TODO: This should be removed in 3.0 and instead Headers should be injected + # directly. We are only keeping this one around to not break things + values["headers"] = dict(connection.headers.items()) + + +def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the cookies from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["cookies"] = connection.cookies + + +def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the query params from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["query"] = connection.query_params + + +def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the scope from the connection and insert it into the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["scope"] = connection.scope + + +def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Set the connection instance as the 'request' value in the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["request"] = connection + + +def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Set the connection instance as the 'socket' value in the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["socket"] = connection + + +def body_extractor( + values: dict[str, Any], + connection: Request[Any, Any, Any], +) -> None: + """Extract the body from the request instance. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + The Body value. + """ + values["body"] = connection.body() + + +async def json_extractor(connection: Request[Any, Any, Any]) -> Any: + """Extract the data from request and insert it into the kwargs injected to the handler. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + + Returns: + The JSON value. + """ + if not await connection.body(): + return Empty + return await connection.json() + + +async def msgpack_extractor(connection: Request[Any, Any, Any]) -> Any: + """Extract the data from request and insert it into the kwargs injected to the handler. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + + Returns: + The MessagePack value. + """ + if not await connection.body(): + return Empty + return await connection.msgpack() + + +async def _extract_multipart( + connection: Request[Any, Any, Any], + body_kwarg_multipart_form_part_limit: int | None, + field_definition: FieldDefinition, + is_data_optional: bool, + data_dto: type[AbstractDTO] | None, +) -> Any: + multipart_form_part_limit = ( + body_kwarg_multipart_form_part_limit + if body_kwarg_multipart_form_part_limit is not None + else connection.app.multipart_form_part_limit + ) + connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] + connection.scope["_form"] # type: ignore[typeddict-item] + if "_form" in connection.scope + else parse_multipart_form( + body=await connection.body(), + boundary=connection.content_type[-1].get("boundary", "").encode(), + multipart_form_part_limit=multipart_form_part_limit, + type_decoders=connection.route_handler.resolve_type_decoders(), + ) + ) + + if field_definition.is_non_string_sequence: + values = list(form_values.values()) + if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list): + return values[0] + + return values + + if field_definition.is_simple_type and field_definition.annotation is UploadFile and form_values: + return next(v for v in form_values.values() if isinstance(v, UploadFile)) + + if not form_values and is_data_optional: + return None + + if data_dto: + return data_dto(connection).decode_builtins(form_values) + + for name, tp in field_definition.get_type_hints().items(): + value = form_values.get(name) + if value is not None and is_non_string_sequence(tp) and not isinstance(value, list): + form_values[name] = [value] + + return form_values + + +def create_multipart_extractor( + field_definition: FieldDefinition, is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create a multipart form-data extractor. + + Args: + field_definition: A FieldDefinition instance. + is_data_optional: Boolean dictating whether the field is optional. + data_dto: A data DTO type, if configured for handler. + + Returns: + An extractor function. + """ + body_kwarg_multipart_form_part_limit: int | None = None + if field_definition.kwarg_definition and isinstance(field_definition.kwarg_definition, BodyKwarg): + body_kwarg_multipart_form_part_limit = field_definition.kwarg_definition.multipart_form_part_limit + + extract_multipart = partial( + _extract_multipart, + body_kwarg_multipart_form_part_limit=body_kwarg_multipart_form_part_limit, + is_data_optional=is_data_optional, + data_dto=data_dto, + field_definition=field_definition, + ) + + return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart) + + +def create_url_encoded_data_extractor( + is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create extractor for url encoded form-data. + + Args: + is_data_optional: Boolean dictating whether the field is optional. + data_dto: A data DTO type, if configured for handler. + + Returns: + An extractor function. + """ + + async def extract_url_encoded_extractor( + connection: Request[Any, Any, Any], + ) -> Any: + connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] + connection.scope["_form"] # type: ignore[typeddict-item] + if "_form" in connection.scope + else parse_url_encoded_form_data(await connection.body()) + ) + + if not form_values and is_data_optional: + return None + + return data_dto(connection).decode_builtins(form_values) if data_dto else form_values + + return cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_url_encoded_extractor + ) + + +def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]: + """Create an extractor for a request's body. + + Args: + kwargs_model: The KwargsModel instance. + + Returns: + An extractor for the request's body. + """ + + if kwargs_model.expected_form_data: + media_type, field_definition = kwargs_model.expected_form_data + + if media_type == RequestEncodingType.MULTI_PART: + data_extractor = create_multipart_extractor( + field_definition=field_definition, + is_data_optional=kwargs_model.is_data_optional, + data_dto=kwargs_model.expected_data_dto, + ) + else: + data_extractor = create_url_encoded_data_extractor( + is_data_optional=kwargs_model.is_data_optional, + data_dto=kwargs_model.expected_data_dto, + ) + elif kwargs_model.expected_msgpack_data: + data_extractor = cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", msgpack_extractor + ) + elif kwargs_model.expected_data_dto: + data_extractor = create_dto_extractor(data_dto=kwargs_model.expected_data_dto) + else: + data_extractor = cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor + ) + + def extractor( + values: dict[str, Any], + connection: ASGIConnection[Any, Any, Any, Any], + ) -> None: + values["data"] = data_extractor(connection) + + return extractor + + +def create_dto_extractor( + data_dto: type[AbstractDTO], +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create a DTO data extractor. + + + Returns: + An extractor function. + """ + + async def dto_extractor(connection: Request[Any, Any, Any]) -> Any: + if not (body := await connection.body()): + return Empty + return data_dto(connection).decode_bytes(body) + + return dto_extractor # type:ignore[return-value] diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py new file mode 100644 index 0000000..01ed2e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from anyio import create_task_group + +from litestar._kwargs.cleanup import DependencyCleanupGroup +from litestar._kwargs.dependencies import ( + Dependency, + create_dependency_batches, + resolve_dependency, +) +from litestar._kwargs.extractors import ( + body_extractor, + cookies_extractor, + create_connection_value_extractor, + create_data_extractor, + headers_extractor, + parse_connection_headers, + parse_connection_query_params, + query_extractor, + request_extractor, + scope_extractor, + socket_extractor, + state_extractor, +) +from litestar._kwargs.parameter_definition import ( + ParameterDefinition, + create_parameter_definition, + merge_parameter_sets, +) +from litestar.constants import RESERVED_KWARGS +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.params import BodyKwarg, ParameterKwarg +from litestar.typing import FieldDefinition +from litestar.utils.helpers import get_exception_group + +__all__ = ("KwargsModel",) + + +if TYPE_CHECKING: + from litestar._signature import SignatureModel + from litestar.connection import ASGIConnection + from litestar.di import Provide + from litestar.dto import AbstractDTO + from litestar.utils.signature import ParsedSignature + +_ExceptionGroup = get_exception_group() + + +class KwargsModel: + """Model required kwargs for a given RouteHandler and its dependencies. + + This is done once and is memoized during application bootstrap, ensuring minimal runtime overhead. + """ + + __slots__ = ( + "dependency_batches", + "expected_cookie_params", + "expected_data_dto", + "expected_form_data", + "expected_header_params", + "expected_msgpack_data", + "expected_path_params", + "expected_query_params", + "expected_reserved_kwargs", + "extractors", + "has_kwargs", + "is_data_optional", + "sequence_query_parameter_names", + ) + + def __init__( + self, + *, + expected_cookie_params: set[ParameterDefinition], + expected_data_dto: type[AbstractDTO] | None, + expected_dependencies: set[Dependency], + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, + expected_header_params: set[ParameterDefinition], + expected_msgpack_data: FieldDefinition | None, + expected_path_params: set[ParameterDefinition], + expected_query_params: set[ParameterDefinition], + expected_reserved_kwargs: set[str], + is_data_optional: bool, + sequence_query_parameter_names: set[str], + ) -> None: + """Initialize ``KwargsModel``. + + Args: + expected_cookie_params: Any expected cookie parameter kwargs + expected_dependencies: Any expected dependency kwargs + expected_form_data: Any expected form data kwargs + expected_header_params: Any expected header parameter kwargs + expected_msgpack_data: Any expected MessagePack data kwargs + expected_path_params: Any expected path parameter kwargs + expected_query_params: Any expected query parameter kwargs + expected_reserved_kwargs: Any expected reserved kwargs, e.g. 'state' + expected_data_dto: A data DTO, if defined + is_data_optional: Treat data as optional + sequence_query_parameter_names: Any query parameters that are sequences + """ + self.expected_cookie_params = expected_cookie_params + self.expected_form_data = expected_form_data + self.expected_header_params = expected_header_params + self.expected_msgpack_data = expected_msgpack_data + self.expected_path_params = expected_path_params + self.expected_query_params = expected_query_params + self.expected_reserved_kwargs = expected_reserved_kwargs + self.expected_data_dto = expected_data_dto + self.sequence_query_parameter_names = tuple(sequence_query_parameter_names) + + self.has_kwargs = ( + expected_cookie_params + or expected_dependencies + or expected_form_data + or expected_msgpack_data + or expected_header_params + or expected_path_params + or expected_query_params + or expected_reserved_kwargs + or expected_data_dto + ) + + self.is_data_optional = is_data_optional + self.extractors = self._create_extractors() + self.dependency_batches = create_dependency_batches(expected_dependencies) + + def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]: + reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = { + "data": create_data_extractor(self), + "state": state_extractor, + "scope": scope_extractor, + "request": request_extractor, + "socket": socket_extractor, + "headers": headers_extractor, + "cookies": cookies_extractor, + "query": query_extractor, + "body": body_extractor, # type: ignore[dict-item] + } + + extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [ + reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs + ] + + if self.expected_header_params: + extractors.append( + create_connection_value_extractor( + connection_key="headers", + expected_params=self.expected_header_params, + kwargs_model=self, + parser=parse_connection_headers, + ), + ) + + if self.expected_path_params: + extractors.append( + create_connection_value_extractor( + connection_key="path_params", + expected_params=self.expected_path_params, + kwargs_model=self, + ), + ) + + if self.expected_cookie_params: + extractors.append( + create_connection_value_extractor( + connection_key="cookies", + expected_params=self.expected_cookie_params, + kwargs_model=self, + ), + ) + + if self.expected_query_params: + extractors.append( + create_connection_value_extractor( + connection_key="query_params", + expected_params=self.expected_query_params, + kwargs_model=self, + parser=parse_connection_query_params, + ), + ) + return extractors + + @classmethod + def _get_param_definitions( + cls, + path_parameters: set[str], + layered_parameters: dict[str, FieldDefinition], + dependencies: dict[str, Provide], + field_definitions: dict[str, FieldDefinition], + ) -> tuple[set[ParameterDefinition], set[Dependency]]: + """Get parameter_definitions for the construction of KwargsModel instance. + + Args: + path_parameters: Any expected path parameters. + layered_parameters: A string keyed dictionary of layered parameters. + dependencies: A string keyed dictionary mapping dependency providers. + field_definitions: The SignatureModel fields. + + Returns: + A Tuple of sets + """ + expected_dependencies = { + cls._create_dependency_graph(key=key, dependencies=dependencies) + for key in dependencies + if key in field_definitions + } + ignored_keys = {*RESERVED_KWARGS, *(dependency.key for dependency in expected_dependencies)} + + param_definitions = { + *( + create_parameter_definition( + field_definition=field_definition, + field_name=field_name, + path_parameters=path_parameters, + ) + for field_name, field_definition in layered_parameters.items() + if field_name not in ignored_keys and field_name not in field_definitions + ), + *( + create_parameter_definition( + field_definition=field_definition, + field_name=field_name, + path_parameters=path_parameters, + ) + for field_name, field_definition in field_definitions.items() + if field_name not in ignored_keys and field_name not in layered_parameters + ), + } + + for field_name, field_definition in ( + (k, v) for k, v in field_definitions.items() if k not in ignored_keys and k in layered_parameters + ): + layered_parameter = layered_parameters[field_name] + field = field_definition if field_definition.is_parameter_field else layered_parameter + default = field_definition.default if field_definition.has_default else layered_parameter.default + + param_definitions.add( + create_parameter_definition( + field_definition=FieldDefinition.from_kwarg( + name=field.name, + default=default, + inner_types=field.inner_types, + annotation=field.annotation, + kwarg_definition=field.kwarg_definition, + extra=field.extra, + ), + field_name=field_name, + path_parameters=path_parameters, + ) + ) + + return param_definitions, expected_dependencies + + @classmethod + def create_for_signature_model( + cls, + signature_model: type[SignatureModel], + parsed_signature: ParsedSignature, + dependencies: dict[str, Provide], + path_parameters: set[str], + layered_parameters: dict[str, FieldDefinition], + ) -> KwargsModel: + """Pre-determine what parameters are required for a given combination of route + route handler. It is executed + during the application bootstrap process. + + Args: + signature_model: A :class:`SignatureModel <litestar._signature.SignatureModel>` subclass. + parsed_signature: A :class:`ParsedSignature <litestar._signature.ParsedSignature>` instance. + dependencies: A string keyed dictionary mapping dependency providers. + path_parameters: Any expected path parameters. + layered_parameters: A string keyed dictionary of layered parameters. + + Returns: + An instance of KwargsModel + """ + + field_definitions = signature_model._fields + + cls._validate_raw_kwargs( + path_parameters=path_parameters, + dependencies=dependencies, + field_definitions=field_definitions, + layered_parameters=layered_parameters, + ) + + param_definitions, expected_dependencies = cls._get_param_definitions( + path_parameters=path_parameters, + layered_parameters=layered_parameters, + dependencies=dependencies, + field_definitions=field_definitions, + ) + + expected_reserved_kwargs = {field_name for field_name in field_definitions if field_name in RESERVED_KWARGS} + expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH} + expected_header_parameters = {p for p in param_definitions if p.param_type == ParamType.HEADER} + expected_cookie_parameters = {p for p in param_definitions if p.param_type == ParamType.COOKIE} + expected_query_parameters = {p for p in param_definitions if p.param_type == ParamType.QUERY} + sequence_query_parameter_names = {p.field_alias for p in expected_query_parameters if p.is_sequence} + + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None = None + expected_msgpack_data: FieldDefinition | None = None + expected_data_dto: type[AbstractDTO] | None = None + data_field_definition = field_definitions.get("data") + + media_type: RequestEncodingType | str | None = None + if data_field_definition: + if isinstance(data_field_definition.kwarg_definition, BodyKwarg): + media_type = data_field_definition.kwarg_definition.media_type + + if media_type in (RequestEncodingType.MULTI_PART, RequestEncodingType.URL_ENCODED): + expected_form_data = (media_type, data_field_definition) + expected_data_dto = signature_model._data_dto + elif signature_model._data_dto: + expected_data_dto = signature_model._data_dto + elif media_type == RequestEncodingType.MESSAGEPACK: + expected_msgpack_data = data_field_definition + + for dependency in expected_dependencies: + dependency_kwargs_model = cls.create_for_signature_model( + signature_model=dependency.provide.signature_model, + parsed_signature=parsed_signature, + dependencies=dependencies, + path_parameters=path_parameters, + layered_parameters=layered_parameters, + ) + expected_path_parameters = merge_parameter_sets( + expected_path_parameters, dependency_kwargs_model.expected_path_params + ) + expected_query_parameters = merge_parameter_sets( + expected_query_parameters, dependency_kwargs_model.expected_query_params + ) + expected_cookie_parameters = merge_parameter_sets( + expected_cookie_parameters, dependency_kwargs_model.expected_cookie_params + ) + expected_header_parameters = merge_parameter_sets( + expected_header_parameters, dependency_kwargs_model.expected_header_params + ) + + if "data" in expected_reserved_kwargs and "data" in dependency_kwargs_model.expected_reserved_kwargs: + cls._validate_dependency_data( + expected_form_data=expected_form_data, + dependency_kwargs_model=dependency_kwargs_model, + ) + + expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs) + sequence_query_parameter_names.update(dependency_kwargs_model.sequence_query_parameter_names) + + return KwargsModel( + expected_cookie_params=expected_cookie_parameters, + expected_dependencies=expected_dependencies, + expected_data_dto=expected_data_dto, + expected_form_data=expected_form_data, + expected_header_params=expected_header_parameters, + expected_msgpack_data=expected_msgpack_data, + expected_path_params=expected_path_parameters, + expected_query_params=expected_query_parameters, + expected_reserved_kwargs=expected_reserved_kwargs, + is_data_optional=field_definitions["data"].is_optional if "data" in expected_reserved_kwargs else False, + sequence_query_parameter_names=sequence_query_parameter_names, + ) + + def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: + """Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is + sync. + + Args: + connection: An instance of :class:`Request <litestar.connection.Request>` or + :class:`WebSocket <litestar.connection.WebSocket>`. + + Returns: + A string keyed dictionary of kwargs expected by the handler function and its dependencies. + """ + output: dict[str, Any] = {} + + for extractor in self.extractors: + extractor(output, connection) + + return output + + async def resolve_dependencies(self, connection: ASGIConnection, kwargs: dict[str, Any]) -> DependencyCleanupGroup: + """Resolve all dependencies into the kwargs, recursively. + + Args: + connection: An instance of :class:`Request <litestar.connection.Request>` or + :class:`WebSocket <litestar.connection.WebSocket>`. + kwargs: Kwargs to pass to dependencies. + """ + cleanup_group = DependencyCleanupGroup() + for batch in self.dependency_batches: + if len(batch) == 1: + await resolve_dependency(next(iter(batch)), connection, kwargs, cleanup_group) + else: + try: + async with create_task_group() as task_group: + for dependency in batch: + task_group.start_soon(resolve_dependency, dependency, connection, kwargs, cleanup_group) + except _ExceptionGroup as excgroup: + raise excgroup.exceptions[0] from excgroup # type: ignore[attr-defined] + + return cleanup_group + + @classmethod + def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) -> Dependency: + """Create a graph like structure of dependencies, with each dependency including its own dependencies as a + list. + """ + provide = dependencies[key] + sub_dependency_keys = [k for k in provide.signature_model._fields if k in dependencies] + return Dependency( + key=key, + provide=provide, + dependencies=[cls._create_dependency_graph(key=k, dependencies=dependencies) for k in sub_dependency_keys], + ) + + @classmethod + def _validate_dependency_data( + cls, + expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, + dependency_kwargs_model: KwargsModel, + ) -> None: + """Validate that the 'data' kwarg is compatible across dependencies.""" + if bool(expected_form_data) != bool(dependency_kwargs_model.expected_form_data): + raise ImproperlyConfiguredException( + "Dependencies have incompatible 'data' kwarg types: one expects JSON and the other expects form-data" + ) + if expected_form_data and dependency_kwargs_model.expected_form_data: + local_media_type = expected_form_data[0] + dependency_media_type = dependency_kwargs_model.expected_form_data[0] + if local_media_type != dependency_media_type: + raise ImproperlyConfiguredException( + "Dependencies have incompatible form-data encoding: one expects url-encoded and the other expects multi-part" + ) + + @classmethod + def _validate_raw_kwargs( + cls, + path_parameters: set[str], + dependencies: dict[str, Provide], + field_definitions: dict[str, FieldDefinition], + layered_parameters: dict[str, FieldDefinition], + ) -> None: + """Validate that there are no ambiguous kwargs, that is, kwargs declared using the same key in different + places. + """ + dependency_keys = set(dependencies.keys()) + + parameter_names = { + *( + k + for k, f in field_definitions.items() + if isinstance(f.kwarg_definition, ParameterKwarg) + and (f.kwarg_definition.header or f.kwarg_definition.query or f.kwarg_definition.cookie) + ), + *list(layered_parameters.keys()), + } + + intersection = ( + path_parameters.intersection(dependency_keys) + or path_parameters.intersection(parameter_names) + or dependency_keys.intersection(parameter_names) + ) + if intersection: + raise ImproperlyConfiguredException( + f"Kwarg resolution ambiguity detected for the following keys: {', '.join(intersection)}. " + f"Make sure to use distinct keys for your dependencies, path parameters, and aliased parameters." + ) + + if used_reserved_kwargs := { + *parameter_names, + *path_parameters, + *dependency_keys, + }.intersection(RESERVED_KWARGS): + raise ImproperlyConfiguredException( + f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. " + f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}" + ) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py new file mode 100644 index 0000000..02b09fc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple + +from litestar.enums import ParamType +from litestar.params import ParameterKwarg + +if TYPE_CHECKING: + from litestar.typing import FieldDefinition + +__all__ = ("ParameterDefinition", "create_parameter_definition", "merge_parameter_sets") + + +class ParameterDefinition(NamedTuple): + """Tuple defining a kwarg representing a request parameter.""" + + default: Any + field_alias: str + field_name: str + is_required: bool + is_sequence: bool + param_type: ParamType + + +def create_parameter_definition( + field_definition: FieldDefinition, + field_name: str, + path_parameters: set[str], +) -> ParameterDefinition: + """Create a ParameterDefinition for the given FieldDefinition. + + Args: + field_definition: FieldDefinition instance. + field_name: The field's name. + path_parameters: A set of path parameter names. + + Returns: + A ParameterDefinition tuple. + """ + default = field_definition.default if field_definition.has_default else None + kwarg_definition = ( + field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None + ) + + field_alias = kwarg_definition.query if kwarg_definition and kwarg_definition.query else field_name + param_type = ParamType.QUERY + + if field_name in path_parameters: + field_alias = field_name + param_type = ParamType.PATH + elif kwarg_definition and kwarg_definition.header: + field_alias = kwarg_definition.header + param_type = ParamType.HEADER + elif kwarg_definition and kwarg_definition.cookie: + field_alias = kwarg_definition.cookie + param_type = ParamType.COOKIE + + return ParameterDefinition( + param_type=param_type, + field_name=field_name, + field_alias=field_alias, + default=default, + is_required=field_definition.is_required + and default is None + and not field_definition.is_optional + and not field_definition.is_any, + is_sequence=field_definition.is_non_string_sequence, + ) + + +def merge_parameter_sets(first: set[ParameterDefinition], second: set[ParameterDefinition]) -> set[ParameterDefinition]: + """Given two sets of parameter definitions, coming from different dependencies for example, merge them into a single + set. + """ + result: set[ParameterDefinition] = first.intersection(second) + difference = first.symmetric_difference(second) + for param in difference: + # add the param if it's either required or no-other param in difference is the same but required + if param.is_required or not any(p.field_alias == param.field_alias and p.is_required for p in difference): + result.add(param) + return result |