diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_kwargs')
12 files changed, 1291 insertions, 0 deletions
| diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py new file mode 100644 index 0000000..af8ad36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py @@ -0,0 +1,3 @@ +from .kwargs_model import KwargsModel + +__all__ = ("KwargsModel",) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pycBinary files differ new file mode 100644 index 0000000..b61cc1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pycBinary files differ new file mode 100644 index 0000000..a084eb5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pycBinary files differ new file mode 100644 index 0000000..b0f49a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pycBinary files differ new file mode 100644 index 0000000..35a6e40 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pycBinary files differ new file mode 100644 index 0000000..17029de --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pycBinary files differ new file mode 100644 index 0000000..10753dc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py new file mode 100644 index 0000000..8839d36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from inspect import Traceback, isasyncgen +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator + +from anyio import create_task_group + +from litestar.utils import ensure_async_callable +from litestar.utils.compat import async_next + +__all__ = ("DependencyCleanupGroup",) + + +if TYPE_CHECKING: +    from litestar.types import AnyGenerator + + +class DependencyCleanupGroup: +    """Wrapper for generator based dependencies. + +    Simplify cleanup by wrapping :func:`next` / :func:`anext` calls and providing facilities to +    :meth:`throw <generator.throw>` / :meth:`athrow <agen.athrow>` into all generators consecutively. An instance of +    this class can be used as a contextmanager, which will automatically throw any exceptions into its generators. All +    exceptions caught in this manner will be re-raised after they have been thrown in the generators. +    """ + +    __slots__ = ("_generators", "_closed") + +    def __init__(self, generators: list[AnyGenerator] | None = None) -> None: +        """Initialize ``DependencyCleanupGroup``. + +        Args: +            generators: An optional list of generators to be called at cleanup +        """ +        self._generators = generators or [] +        self._closed = False + +    def add(self, generator: Generator[Any, None, None] | AsyncGenerator[Any, None]) -> None: +        """Add a new generator to the group. + +        Args: +            generator: The generator to add + +        Returns: +            None +        """ +        if self._closed: +            raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup") +        self._generators.append(generator) + +    @staticmethod +    def _wrap_next(generator: AnyGenerator) -> Callable[[], Awaitable[None]]: +        if isasyncgen(generator): + +            async def wrapped_async() -> None: +                await async_next(generator, None) + +            return wrapped_async + +        def wrapped() -> None: +            next(generator, None)  # type: ignore[arg-type] + +        return ensure_async_callable(wrapped) + +    async def cleanup(self) -> None: +        """Execute cleanup by calling :func:`next` / :func:`anext` on all generators. + +        If there are multiple generators to be called, they will be executed in a :class:`anyio.TaskGroup`. + +        Returns: +            None +        """ +        if self._closed: +            raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup") + +        self._closed = True + +        if not self._generators: +            return + +        if len(self._generators) == 1: +            await self._wrap_next(self._generators[0])() +            return + +        async with create_task_group() as task_group: +            for generator in self._generators: +                task_group.start_soon(self._wrap_next(generator)) + +    async def __aenter__(self) -> None: +        """Support the async contextmanager protocol to allow for easier catching and throwing of exceptions into the +        generators. +        """ + +    async def __aexit__( +        self, +        exc_type: type[BaseException] | None, +        exc_val: BaseException | None, +        exc_tb: Traceback | None, +    ) -> None: +        """If an exception was raised within the contextmanager block, throw it into all generators.""" +        if exc_val: +            await self.throw(exc_val) + +    async def throw(self, exc: BaseException) -> None: +        """Throw an exception in all generators sequentially. + +        Args: +            exc: Exception to throw +        """ +        for gen in self._generators: +            try: +                if isasyncgen(gen): +                    await gen.athrow(exc) +                else: +                    gen.throw(exc)  # type: ignore[union-attr] +            except (StopIteration, StopAsyncIteration): +                continue diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py new file mode 100644 index 0000000..88ffb07 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.utils.compat import async_next + +__all__ = ("Dependency", "create_dependency_batches", "map_dependencies_recursively", "resolve_dependency") + + +if TYPE_CHECKING: +    from litestar._kwargs.cleanup import DependencyCleanupGroup +    from litestar.connection import ASGIConnection +    from litestar.di import Provide + + +class Dependency: +    """Dependency graph of a given combination of ``Route`` + ``RouteHandler``""" + +    __slots__ = ("key", "provide", "dependencies") + +    def __init__(self, key: str, provide: Provide, dependencies: list[Dependency]) -> None: +        """Initialize a dependency. + +        Args: +            key: The dependency key +            provide: Provider +            dependencies: List of child nodes +        """ +        self.key = key +        self.provide = provide +        self.dependencies = dependencies + +    def __eq__(self, other: Any) -> bool: +        # check if memory address is identical, otherwise compare attributes +        return other is self or (isinstance(other, self.__class__) and other.key == self.key) + +    def __hash__(self) -> int: +        return hash(self.key) + + +async def resolve_dependency( +    dependency: Dependency, +    connection: ASGIConnection, +    kwargs: dict[str, Any], +    cleanup_group: DependencyCleanupGroup, +) -> None: +    """Resolve a given instance of :class:`Dependency <litestar._kwargs.Dependency>`. + +    All required sub dependencies must already +    be resolved into the kwargs. The result of the dependency will be stored in the kwargs. + +    Args: +        dependency: An instance of :class:`Dependency <litestar._kwargs.Dependency>` +        connection: An instance of :class:`Request <litestar.connection.Request>` or +            :class:`WebSocket <litestar.connection.WebSocket>`. +        kwargs: Any kwargs to pass to the dependency, the result will be stored here as well. +        cleanup_group: DependencyCleanupGroup to which generators returned by ``dependency`` will be added +    """ +    signature_model = dependency.provide.signature_model +    dependency_kwargs = ( +        signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs) +        if signature_model._fields +        else {} +    ) +    value = await dependency.provide(**dependency_kwargs) + +    if dependency.provide.has_sync_generator_dependency: +        cleanup_group.add(value) +        value = next(value) +    elif dependency.provide.has_async_generator_dependency: +        cleanup_group.add(value) +        value = await async_next(value) + +    kwargs[dependency.key] = value + + +def create_dependency_batches(expected_dependencies: set[Dependency]) -> list[set[Dependency]]: +    """Calculate batches for all dependencies, recursively. + +    Args: +        expected_dependencies: A set of all direct :class:`Dependencies <litestar._kwargs.Dependency>`. + +    Returns: +        A list of batches. +    """ +    dependencies_to: dict[Dependency, set[Dependency]] = {} +    for dependency in expected_dependencies: +        if dependency not in dependencies_to: +            map_dependencies_recursively(dependency, dependencies_to) + +    batches = [] +    while dependencies_to: +        current_batch = { +            dependency +            for dependency, remaining_sub_dependencies in dependencies_to.items() +            if not remaining_sub_dependencies +        } + +        for dependency in current_batch: +            del dependencies_to[dependency] +            for others_dependencies in dependencies_to.values(): +                others_dependencies.discard(dependency) + +        batches.append(current_batch) + +    return batches + + +def map_dependencies_recursively(dependency: Dependency, dependencies_to: dict[Dependency, set[Dependency]]) -> None: +    """Recursively map dependencies to their sub dependencies. + +    Args: +        dependency: The current dependency to map. +        dependencies_to: A map of dependency to its sub dependencies. +    """ +    dependencies_to[dependency] = set(dependency.dependencies) +    for sub in dependency.dependencies: +        if sub not in dependencies_to: +            map_dependencies_recursively(sub, dependencies_to) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py new file mode 100644 index 0000000..e3b347e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache, partial +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast + +from litestar._multipart import parse_multipart_form +from litestar._parsers import ( +    parse_query_string, +    parse_url_encoded_form_data, +) +from litestar.datastructures import Headers +from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.url import URL +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ValidationException +from litestar.params import BodyKwarg +from litestar.types import Empty +from litestar.utils.predicates import is_non_string_sequence +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: +    from litestar._kwargs import KwargsModel +    from litestar._kwargs.parameter_definition import ParameterDefinition +    from litestar.connection import ASGIConnection, Request +    from litestar.dto import AbstractDTO +    from litestar.typing import FieldDefinition + + +__all__ = ( +    "body_extractor", +    "cookies_extractor", +    "create_connection_value_extractor", +    "create_data_extractor", +    "create_multipart_extractor", +    "create_query_default_dict", +    "create_url_encoded_data_extractor", +    "headers_extractor", +    "json_extractor", +    "msgpack_extractor", +    "parse_connection_headers", +    "parse_connection_query_params", +    "query_extractor", +    "request_extractor", +    "scope_extractor", +    "socket_extractor", +    "state_extractor", +) + + +class ParamMappings(NamedTuple): +    alias_and_key_tuples: list[tuple[str, str]] +    alias_defaults: dict[str, Any] +    alias_to_param: dict[str, ParameterDefinition] + + +def _create_param_mappings(expected_params: set[ParameterDefinition]) -> ParamMappings: +    alias_and_key_tuples = [] +    alias_defaults = {} +    alias_to_params: dict[str, ParameterDefinition] = {} +    for param in expected_params: +        alias = param.field_alias +        if param.param_type == ParamType.HEADER: +            alias = alias.lower() + +        alias_and_key_tuples.append((alias, param.field_name)) + +        if not (param.is_required or param.default is Ellipsis): +            alias_defaults[alias] = param.default + +        alias_to_params[alias] = param + +    return ParamMappings( +        alias_and_key_tuples=alias_and_key_tuples, +        alias_defaults=alias_defaults, +        alias_to_param=alias_to_params, +    ) + + +def create_connection_value_extractor( +    kwargs_model: KwargsModel, +    connection_key: str, +    expected_params: set[ParameterDefinition], +    parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None, +) -> Callable[[dict[str, Any], ASGIConnection], None]: +    """Create a kwargs extractor function. + +    Args: +        kwargs_model: The KwargsModel instance. +        connection_key: The attribute key to use. +        expected_params: The set of expected params. +        parser: An optional parser function. + +    Returns: +        An extractor function. +    """ + +    alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params) + +    def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +        data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {}) + +        try: +            connection_mapping: dict[str, Any] = { +                key: data[alias] if alias in data else alias_defaults[alias] for alias, key in alias_and_key_tuples +            } +            values.update(connection_mapping) +        except KeyError as e: +            param = alias_to_params[e.args[0]] +            path = URL.from_components( +                path=connection.url.path, +                query=connection.url.query, +            ) +            raise ValidationException( +                f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}" +            ) from e + +    return extractor + + +@lru_cache(1024) +def create_query_default_dict( +    parsed_query: tuple[tuple[str, str], ...], sequence_query_parameter_names: tuple[str, ...] +) -> defaultdict[str, list[str] | str]: +    """Transform a list of tuples into a default dict. Ensures non-list values are not wrapped in a list. + +    Args: +        parsed_query: The parsed query list of tuples. +        sequence_query_parameter_names: A set of query parameters that should be wrapped in list. + +    Returns: +        A default dict +    """ +    output: defaultdict[str, list[str] | str] = defaultdict(list) + +    for k, v in parsed_query: +        if k in sequence_query_parameter_names: +            output[k].append(v)  # type: ignore[union-attr] +        else: +            output[k] = v + +    return output + + +def parse_connection_query_params(connection: ASGIConnection, kwargs_model: KwargsModel) -> dict[str, Any]: +    """Parse query params and cache the result in scope. + +    Args: +        connection: The ASGI connection instance. +        kwargs_model: The KwargsModel instance. + +    Returns: +        A dictionary of parsed values. +    """ +    parsed_query = ( +        connection._parsed_query +        if connection._parsed_query is not Empty +        else parse_query_string(connection.scope.get("query_string", b"")) +    ) +    ScopeState.from_scope(connection.scope).parsed_query = parsed_query +    return create_query_default_dict( +        parsed_query=parsed_query, +        sequence_query_parameter_names=kwargs_model.sequence_query_parameter_names, +    ) + + +def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers: +    """Parse header parameters and cache the result in scope. + +    Args: +        connection: The ASGI connection instance. +        _: The KwargsModel instance. + +    Returns: +        A Headers instance +    """ +    return Headers.from_scope(connection.scope) + + +def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Extract the app state from the connection and insert it to the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    values["state"] = connection.app.state._state + + +def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Extract the headers from the connection and insert them to the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    # TODO: This should be removed in 3.0 and instead Headers should be injected +    # directly. We are only keeping this one around to not break things +    values["headers"] = dict(connection.headers.items()) + + +def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Extract the cookies from the connection and insert them to the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    values["cookies"] = connection.cookies + + +def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Extract the query params from the connection and insert them to the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    values["query"] = connection.query_params + + +def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Extract the scope from the connection and insert it into the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    values["scope"] = connection.scope + + +def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Set the connection instance as the 'request' value in the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    values["request"] = connection + + +def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +    """Set the connection instance as the 'socket' value in the kwargs injected to the handler. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        None +    """ +    values["socket"] = connection + + +def body_extractor( +    values: dict[str, Any], +    connection: Request[Any, Any, Any], +) -> None: +    """Extract the body from the request instance. + +    Notes: +        - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + +    Args: +        connection: The ASGI connection instance. +        values: The kwargs that are extracted from the connection and will be injected into the handler. + +    Returns: +        The Body value. +    """ +    values["body"] = connection.body() + + +async def json_extractor(connection: Request[Any, Any, Any]) -> Any: +    """Extract the data from request and insert it into the kwargs injected to the handler. + +    Notes: +        - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + +    Args: +        connection: The ASGI connection instance. + +    Returns: +        The JSON value. +    """ +    if not await connection.body(): +        return Empty +    return await connection.json() + + +async def msgpack_extractor(connection: Request[Any, Any, Any]) -> Any: +    """Extract the data from request and insert it into the kwargs injected to the handler. + +    Notes: +        - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + +    Args: +        connection: The ASGI connection instance. + +    Returns: +        The MessagePack value. +    """ +    if not await connection.body(): +        return Empty +    return await connection.msgpack() + + +async def _extract_multipart( +    connection: Request[Any, Any, Any], +    body_kwarg_multipart_form_part_limit: int | None, +    field_definition: FieldDefinition, +    is_data_optional: bool, +    data_dto: type[AbstractDTO] | None, +) -> Any: +    multipart_form_part_limit = ( +        body_kwarg_multipart_form_part_limit +        if body_kwarg_multipart_form_part_limit is not None +        else connection.app.multipart_form_part_limit +    ) +    connection.scope["_form"] = form_values = (  # type: ignore[typeddict-unknown-key] +        connection.scope["_form"]  # type: ignore[typeddict-item] +        if "_form" in connection.scope +        else parse_multipart_form( +            body=await connection.body(), +            boundary=connection.content_type[-1].get("boundary", "").encode(), +            multipart_form_part_limit=multipart_form_part_limit, +            type_decoders=connection.route_handler.resolve_type_decoders(), +        ) +    ) + +    if field_definition.is_non_string_sequence: +        values = list(form_values.values()) +        if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list): +            return values[0] + +        return values + +    if field_definition.is_simple_type and field_definition.annotation is UploadFile and form_values: +        return next(v for v in form_values.values() if isinstance(v, UploadFile)) + +    if not form_values and is_data_optional: +        return None + +    if data_dto: +        return data_dto(connection).decode_builtins(form_values) + +    for name, tp in field_definition.get_type_hints().items(): +        value = form_values.get(name) +        if value is not None and is_non_string_sequence(tp) and not isinstance(value, list): +            form_values[name] = [value] + +    return form_values + + +def create_multipart_extractor( +    field_definition: FieldDefinition, is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: +    """Create a multipart form-data extractor. + +    Args: +        field_definition: A FieldDefinition instance. +        is_data_optional: Boolean dictating whether the field is optional. +        data_dto: A data DTO type, if configured for handler. + +    Returns: +        An extractor function. +    """ +    body_kwarg_multipart_form_part_limit: int | None = None +    if field_definition.kwarg_definition and isinstance(field_definition.kwarg_definition, BodyKwarg): +        body_kwarg_multipart_form_part_limit = field_definition.kwarg_definition.multipart_form_part_limit + +    extract_multipart = partial( +        _extract_multipart, +        body_kwarg_multipart_form_part_limit=body_kwarg_multipart_form_part_limit, +        is_data_optional=is_data_optional, +        data_dto=data_dto, +        field_definition=field_definition, +    ) + +    return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart) + + +def create_url_encoded_data_extractor( +    is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: +    """Create extractor for url encoded form-data. + +    Args: +        is_data_optional: Boolean dictating whether the field is optional. +        data_dto: A data DTO type, if configured for handler. + +    Returns: +        An extractor function. +    """ + +    async def extract_url_encoded_extractor( +        connection: Request[Any, Any, Any], +    ) -> Any: +        connection.scope["_form"] = form_values = (  # type: ignore[typeddict-unknown-key] +            connection.scope["_form"]  # type: ignore[typeddict-item] +            if "_form" in connection.scope +            else parse_url_encoded_form_data(await connection.body()) +        ) + +        if not form_values and is_data_optional: +            return None + +        return data_dto(connection).decode_builtins(form_values) if data_dto else form_values + +    return cast( +        "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_url_encoded_extractor +    ) + + +def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]: +    """Create an extractor for a request's body. + +    Args: +        kwargs_model: The KwargsModel instance. + +    Returns: +        An extractor for the request's body. +    """ + +    if kwargs_model.expected_form_data: +        media_type, field_definition = kwargs_model.expected_form_data + +        if media_type == RequestEncodingType.MULTI_PART: +            data_extractor = create_multipart_extractor( +                field_definition=field_definition, +                is_data_optional=kwargs_model.is_data_optional, +                data_dto=kwargs_model.expected_data_dto, +            ) +        else: +            data_extractor = create_url_encoded_data_extractor( +                is_data_optional=kwargs_model.is_data_optional, +                data_dto=kwargs_model.expected_data_dto, +            ) +    elif kwargs_model.expected_msgpack_data: +        data_extractor = cast( +            "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", msgpack_extractor +        ) +    elif kwargs_model.expected_data_dto: +        data_extractor = create_dto_extractor(data_dto=kwargs_model.expected_data_dto) +    else: +        data_extractor = cast( +            "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor +        ) + +    def extractor( +        values: dict[str, Any], +        connection: ASGIConnection[Any, Any, Any, Any], +    ) -> None: +        values["data"] = data_extractor(connection) + +    return extractor + + +def create_dto_extractor( +    data_dto: type[AbstractDTO], +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: +    """Create a DTO data extractor. + + +    Returns: +        An extractor function. +    """ + +    async def dto_extractor(connection: Request[Any, Any, Any]) -> Any: +        if not (body := await connection.body()): +            return Empty +        return data_dto(connection).decode_bytes(body) + +    return dto_extractor  # type:ignore[return-value] diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py new file mode 100644 index 0000000..01ed2e5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from anyio import create_task_group + +from litestar._kwargs.cleanup import DependencyCleanupGroup +from litestar._kwargs.dependencies import ( +    Dependency, +    create_dependency_batches, +    resolve_dependency, +) +from litestar._kwargs.extractors import ( +    body_extractor, +    cookies_extractor, +    create_connection_value_extractor, +    create_data_extractor, +    headers_extractor, +    parse_connection_headers, +    parse_connection_query_params, +    query_extractor, +    request_extractor, +    scope_extractor, +    socket_extractor, +    state_extractor, +) +from litestar._kwargs.parameter_definition import ( +    ParameterDefinition, +    create_parameter_definition, +    merge_parameter_sets, +) +from litestar.constants import RESERVED_KWARGS +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.params import BodyKwarg, ParameterKwarg +from litestar.typing import FieldDefinition +from litestar.utils.helpers import get_exception_group + +__all__ = ("KwargsModel",) + + +if TYPE_CHECKING: +    from litestar._signature import SignatureModel +    from litestar.connection import ASGIConnection +    from litestar.di import Provide +    from litestar.dto import AbstractDTO +    from litestar.utils.signature import ParsedSignature + +_ExceptionGroup = get_exception_group() + + +class KwargsModel: +    """Model required kwargs for a given RouteHandler and its dependencies. + +    This is done once and is memoized during application bootstrap, ensuring minimal runtime overhead. +    """ + +    __slots__ = ( +        "dependency_batches", +        "expected_cookie_params", +        "expected_data_dto", +        "expected_form_data", +        "expected_header_params", +        "expected_msgpack_data", +        "expected_path_params", +        "expected_query_params", +        "expected_reserved_kwargs", +        "extractors", +        "has_kwargs", +        "is_data_optional", +        "sequence_query_parameter_names", +    ) + +    def __init__( +        self, +        *, +        expected_cookie_params: set[ParameterDefinition], +        expected_data_dto: type[AbstractDTO] | None, +        expected_dependencies: set[Dependency], +        expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, +        expected_header_params: set[ParameterDefinition], +        expected_msgpack_data: FieldDefinition | None, +        expected_path_params: set[ParameterDefinition], +        expected_query_params: set[ParameterDefinition], +        expected_reserved_kwargs: set[str], +        is_data_optional: bool, +        sequence_query_parameter_names: set[str], +    ) -> None: +        """Initialize ``KwargsModel``. + +        Args: +            expected_cookie_params: Any expected cookie parameter kwargs +            expected_dependencies: Any expected dependency kwargs +            expected_form_data: Any expected form data kwargs +            expected_header_params: Any expected header parameter kwargs +            expected_msgpack_data: Any expected MessagePack data kwargs +            expected_path_params: Any expected path parameter kwargs +            expected_query_params: Any expected query parameter kwargs +            expected_reserved_kwargs: Any expected reserved kwargs, e.g. 'state' +            expected_data_dto: A data DTO, if defined +            is_data_optional: Treat data as optional +            sequence_query_parameter_names: Any query parameters that are sequences +        """ +        self.expected_cookie_params = expected_cookie_params +        self.expected_form_data = expected_form_data +        self.expected_header_params = expected_header_params +        self.expected_msgpack_data = expected_msgpack_data +        self.expected_path_params = expected_path_params +        self.expected_query_params = expected_query_params +        self.expected_reserved_kwargs = expected_reserved_kwargs +        self.expected_data_dto = expected_data_dto +        self.sequence_query_parameter_names = tuple(sequence_query_parameter_names) + +        self.has_kwargs = ( +            expected_cookie_params +            or expected_dependencies +            or expected_form_data +            or expected_msgpack_data +            or expected_header_params +            or expected_path_params +            or expected_query_params +            or expected_reserved_kwargs +            or expected_data_dto +        ) + +        self.is_data_optional = is_data_optional +        self.extractors = self._create_extractors() +        self.dependency_batches = create_dependency_batches(expected_dependencies) + +    def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]: +        reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = { +            "data": create_data_extractor(self), +            "state": state_extractor, +            "scope": scope_extractor, +            "request": request_extractor, +            "socket": socket_extractor, +            "headers": headers_extractor, +            "cookies": cookies_extractor, +            "query": query_extractor, +            "body": body_extractor,  # type: ignore[dict-item] +        } + +        extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [ +            reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs +        ] + +        if self.expected_header_params: +            extractors.append( +                create_connection_value_extractor( +                    connection_key="headers", +                    expected_params=self.expected_header_params, +                    kwargs_model=self, +                    parser=parse_connection_headers, +                ), +            ) + +        if self.expected_path_params: +            extractors.append( +                create_connection_value_extractor( +                    connection_key="path_params", +                    expected_params=self.expected_path_params, +                    kwargs_model=self, +                ), +            ) + +        if self.expected_cookie_params: +            extractors.append( +                create_connection_value_extractor( +                    connection_key="cookies", +                    expected_params=self.expected_cookie_params, +                    kwargs_model=self, +                ), +            ) + +        if self.expected_query_params: +            extractors.append( +                create_connection_value_extractor( +                    connection_key="query_params", +                    expected_params=self.expected_query_params, +                    kwargs_model=self, +                    parser=parse_connection_query_params, +                ), +            ) +        return extractors + +    @classmethod +    def _get_param_definitions( +        cls, +        path_parameters: set[str], +        layered_parameters: dict[str, FieldDefinition], +        dependencies: dict[str, Provide], +        field_definitions: dict[str, FieldDefinition], +    ) -> tuple[set[ParameterDefinition], set[Dependency]]: +        """Get parameter_definitions for the construction of KwargsModel instance. + +        Args: +            path_parameters: Any expected path parameters. +            layered_parameters: A string keyed dictionary of layered parameters. +            dependencies: A string keyed dictionary mapping dependency providers. +            field_definitions: The SignatureModel fields. + +        Returns: +            A Tuple of sets +        """ +        expected_dependencies = { +            cls._create_dependency_graph(key=key, dependencies=dependencies) +            for key in dependencies +            if key in field_definitions +        } +        ignored_keys = {*RESERVED_KWARGS, *(dependency.key for dependency in expected_dependencies)} + +        param_definitions = { +            *( +                create_parameter_definition( +                    field_definition=field_definition, +                    field_name=field_name, +                    path_parameters=path_parameters, +                ) +                for field_name, field_definition in layered_parameters.items() +                if field_name not in ignored_keys and field_name not in field_definitions +            ), +            *( +                create_parameter_definition( +                    field_definition=field_definition, +                    field_name=field_name, +                    path_parameters=path_parameters, +                ) +                for field_name, field_definition in field_definitions.items() +                if field_name not in ignored_keys and field_name not in layered_parameters +            ), +        } + +        for field_name, field_definition in ( +            (k, v) for k, v in field_definitions.items() if k not in ignored_keys and k in layered_parameters +        ): +            layered_parameter = layered_parameters[field_name] +            field = field_definition if field_definition.is_parameter_field else layered_parameter +            default = field_definition.default if field_definition.has_default else layered_parameter.default + +            param_definitions.add( +                create_parameter_definition( +                    field_definition=FieldDefinition.from_kwarg( +                        name=field.name, +                        default=default, +                        inner_types=field.inner_types, +                        annotation=field.annotation, +                        kwarg_definition=field.kwarg_definition, +                        extra=field.extra, +                    ), +                    field_name=field_name, +                    path_parameters=path_parameters, +                ) +            ) + +        return param_definitions, expected_dependencies + +    @classmethod +    def create_for_signature_model( +        cls, +        signature_model: type[SignatureModel], +        parsed_signature: ParsedSignature, +        dependencies: dict[str, Provide], +        path_parameters: set[str], +        layered_parameters: dict[str, FieldDefinition], +    ) -> KwargsModel: +        """Pre-determine what parameters are required for a given combination of route + route handler. It is executed +        during the application bootstrap process. + +        Args: +            signature_model: A :class:`SignatureModel <litestar._signature.SignatureModel>` subclass. +            parsed_signature: A :class:`ParsedSignature <litestar._signature.ParsedSignature>` instance. +            dependencies: A string keyed dictionary mapping dependency providers. +            path_parameters: Any expected path parameters. +            layered_parameters: A string keyed dictionary of layered parameters. + +        Returns: +            An instance of KwargsModel +        """ + +        field_definitions = signature_model._fields + +        cls._validate_raw_kwargs( +            path_parameters=path_parameters, +            dependencies=dependencies, +            field_definitions=field_definitions, +            layered_parameters=layered_parameters, +        ) + +        param_definitions, expected_dependencies = cls._get_param_definitions( +            path_parameters=path_parameters, +            layered_parameters=layered_parameters, +            dependencies=dependencies, +            field_definitions=field_definitions, +        ) + +        expected_reserved_kwargs = {field_name for field_name in field_definitions if field_name in RESERVED_KWARGS} +        expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH} +        expected_header_parameters = {p for p in param_definitions if p.param_type == ParamType.HEADER} +        expected_cookie_parameters = {p for p in param_definitions if p.param_type == ParamType.COOKIE} +        expected_query_parameters = {p for p in param_definitions if p.param_type == ParamType.QUERY} +        sequence_query_parameter_names = {p.field_alias for p in expected_query_parameters if p.is_sequence} + +        expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None = None +        expected_msgpack_data: FieldDefinition | None = None +        expected_data_dto: type[AbstractDTO] | None = None +        data_field_definition = field_definitions.get("data") + +        media_type: RequestEncodingType | str | None = None +        if data_field_definition: +            if isinstance(data_field_definition.kwarg_definition, BodyKwarg): +                media_type = data_field_definition.kwarg_definition.media_type + +            if media_type in (RequestEncodingType.MULTI_PART, RequestEncodingType.URL_ENCODED): +                expected_form_data = (media_type, data_field_definition) +                expected_data_dto = signature_model._data_dto +            elif signature_model._data_dto: +                expected_data_dto = signature_model._data_dto +            elif media_type == RequestEncodingType.MESSAGEPACK: +                expected_msgpack_data = data_field_definition + +        for dependency in expected_dependencies: +            dependency_kwargs_model = cls.create_for_signature_model( +                signature_model=dependency.provide.signature_model, +                parsed_signature=parsed_signature, +                dependencies=dependencies, +                path_parameters=path_parameters, +                layered_parameters=layered_parameters, +            ) +            expected_path_parameters = merge_parameter_sets( +                expected_path_parameters, dependency_kwargs_model.expected_path_params +            ) +            expected_query_parameters = merge_parameter_sets( +                expected_query_parameters, dependency_kwargs_model.expected_query_params +            ) +            expected_cookie_parameters = merge_parameter_sets( +                expected_cookie_parameters, dependency_kwargs_model.expected_cookie_params +            ) +            expected_header_parameters = merge_parameter_sets( +                expected_header_parameters, dependency_kwargs_model.expected_header_params +            ) + +            if "data" in expected_reserved_kwargs and "data" in dependency_kwargs_model.expected_reserved_kwargs: +                cls._validate_dependency_data( +                    expected_form_data=expected_form_data, +                    dependency_kwargs_model=dependency_kwargs_model, +                ) + +            expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs) +            sequence_query_parameter_names.update(dependency_kwargs_model.sequence_query_parameter_names) + +        return KwargsModel( +            expected_cookie_params=expected_cookie_parameters, +            expected_dependencies=expected_dependencies, +            expected_data_dto=expected_data_dto, +            expected_form_data=expected_form_data, +            expected_header_params=expected_header_parameters, +            expected_msgpack_data=expected_msgpack_data, +            expected_path_params=expected_path_parameters, +            expected_query_params=expected_query_parameters, +            expected_reserved_kwargs=expected_reserved_kwargs, +            is_data_optional=field_definitions["data"].is_optional if "data" in expected_reserved_kwargs else False, +            sequence_query_parameter_names=sequence_query_parameter_names, +        ) + +    def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: +        """Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is +        sync. + +        Args: +            connection: An instance of :class:`Request <litestar.connection.Request>` or +                :class:`WebSocket <litestar.connection.WebSocket>`. + +        Returns: +            A string keyed dictionary of kwargs expected by the handler function and its dependencies. +        """ +        output: dict[str, Any] = {} + +        for extractor in self.extractors: +            extractor(output, connection) + +        return output + +    async def resolve_dependencies(self, connection: ASGIConnection, kwargs: dict[str, Any]) -> DependencyCleanupGroup: +        """Resolve all dependencies into the kwargs, recursively. + +        Args: +            connection: An instance of :class:`Request <litestar.connection.Request>` or +                :class:`WebSocket <litestar.connection.WebSocket>`. +            kwargs: Kwargs to pass to dependencies. +        """ +        cleanup_group = DependencyCleanupGroup() +        for batch in self.dependency_batches: +            if len(batch) == 1: +                await resolve_dependency(next(iter(batch)), connection, kwargs, cleanup_group) +            else: +                try: +                    async with create_task_group() as task_group: +                        for dependency in batch: +                            task_group.start_soon(resolve_dependency, dependency, connection, kwargs, cleanup_group) +                except _ExceptionGroup as excgroup: +                    raise excgroup.exceptions[0] from excgroup  # type: ignore[attr-defined] + +        return cleanup_group + +    @classmethod +    def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) -> Dependency: +        """Create a graph like structure of dependencies, with each dependency including its own dependencies as a +        list. +        """ +        provide = dependencies[key] +        sub_dependency_keys = [k for k in provide.signature_model._fields if k in dependencies] +        return Dependency( +            key=key, +            provide=provide, +            dependencies=[cls._create_dependency_graph(key=k, dependencies=dependencies) for k in sub_dependency_keys], +        ) + +    @classmethod +    def _validate_dependency_data( +        cls, +        expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None, +        dependency_kwargs_model: KwargsModel, +    ) -> None: +        """Validate that the 'data' kwarg is compatible across dependencies.""" +        if bool(expected_form_data) != bool(dependency_kwargs_model.expected_form_data): +            raise ImproperlyConfiguredException( +                "Dependencies have incompatible 'data' kwarg types: one expects JSON and the other expects form-data" +            ) +        if expected_form_data and dependency_kwargs_model.expected_form_data: +            local_media_type = expected_form_data[0] +            dependency_media_type = dependency_kwargs_model.expected_form_data[0] +            if local_media_type != dependency_media_type: +                raise ImproperlyConfiguredException( +                    "Dependencies have incompatible form-data encoding: one expects url-encoded and the other expects multi-part" +                ) + +    @classmethod +    def _validate_raw_kwargs( +        cls, +        path_parameters: set[str], +        dependencies: dict[str, Provide], +        field_definitions: dict[str, FieldDefinition], +        layered_parameters: dict[str, FieldDefinition], +    ) -> None: +        """Validate that there are no ambiguous kwargs, that is, kwargs declared using the same key in different +        places. +        """ +        dependency_keys = set(dependencies.keys()) + +        parameter_names = { +            *( +                k +                for k, f in field_definitions.items() +                if isinstance(f.kwarg_definition, ParameterKwarg) +                and (f.kwarg_definition.header or f.kwarg_definition.query or f.kwarg_definition.cookie) +            ), +            *list(layered_parameters.keys()), +        } + +        intersection = ( +            path_parameters.intersection(dependency_keys) +            or path_parameters.intersection(parameter_names) +            or dependency_keys.intersection(parameter_names) +        ) +        if intersection: +            raise ImproperlyConfiguredException( +                f"Kwarg resolution ambiguity detected for the following keys: {', '.join(intersection)}. " +                f"Make sure to use distinct keys for your dependencies, path parameters, and aliased parameters." +            ) + +        if used_reserved_kwargs := { +            *parameter_names, +            *path_parameters, +            *dependency_keys, +        }.intersection(RESERVED_KWARGS): +            raise ImproperlyConfiguredException( +                f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. " +                f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}" +            ) diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py new file mode 100644 index 0000000..02b09fc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple + +from litestar.enums import ParamType +from litestar.params import ParameterKwarg + +if TYPE_CHECKING: +    from litestar.typing import FieldDefinition + +__all__ = ("ParameterDefinition", "create_parameter_definition", "merge_parameter_sets") + + +class ParameterDefinition(NamedTuple): +    """Tuple defining a kwarg representing a request parameter.""" + +    default: Any +    field_alias: str +    field_name: str +    is_required: bool +    is_sequence: bool +    param_type: ParamType + + +def create_parameter_definition( +    field_definition: FieldDefinition, +    field_name: str, +    path_parameters: set[str], +) -> ParameterDefinition: +    """Create a ParameterDefinition for the given FieldDefinition. + +    Args: +        field_definition: FieldDefinition instance. +        field_name: The field's name. +        path_parameters: A set of path parameter names. + +    Returns: +        A ParameterDefinition tuple. +    """ +    default = field_definition.default if field_definition.has_default else None +    kwarg_definition = ( +        field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None +    ) + +    field_alias = kwarg_definition.query if kwarg_definition and kwarg_definition.query else field_name +    param_type = ParamType.QUERY + +    if field_name in path_parameters: +        field_alias = field_name +        param_type = ParamType.PATH +    elif kwarg_definition and kwarg_definition.header: +        field_alias = kwarg_definition.header +        param_type = ParamType.HEADER +    elif kwarg_definition and kwarg_definition.cookie: +        field_alias = kwarg_definition.cookie +        param_type = ParamType.COOKIE + +    return ParameterDefinition( +        param_type=param_type, +        field_name=field_name, +        field_alias=field_alias, +        default=default, +        is_required=field_definition.is_required +        and default is None +        and not field_definition.is_optional +        and not field_definition.is_any, +        is_sequence=field_definition.is_non_string_sequence, +    ) + + +def merge_parameter_sets(first: set[ParameterDefinition], second: set[ParameterDefinition]) -> set[ParameterDefinition]: +    """Given two sets of parameter definitions, coming from different dependencies for example, merge them into a single +    set. +    """ +    result: set[ParameterDefinition] = first.intersection(second) +    difference = first.symmetric_difference(second) +    for param in difference: +        # add the param if it's either required or no-other param in difference is the same but required +        if param.is_required or not any(p.field_alias == param.field_alias and p.is_required for p in difference): +            result.add(param) +    return result | 
