summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_kwargs
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_kwargs')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py3
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pycbin0 -> 278 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pycbin0 -> 6544 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pycbin0 -> 5733 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pycbin0 -> 20001 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pycbin0 -> 22080 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pycbin0 -> 3656 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py117
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py119
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py492
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py479
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py81
12 files changed, 1291 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py
new file mode 100644
index 0000000..af8ad36
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__init__.py
@@ -0,0 +1,3 @@
+from .kwargs_model import KwargsModel
+
+__all__ = ("KwargsModel",)
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..b61cc1b
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc
new file mode 100644
index 0000000..a084eb5
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/cleanup.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc
new file mode 100644
index 0000000..b0f49a4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/dependencies.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc
new file mode 100644
index 0000000..35a6e40
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/extractors.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc
new file mode 100644
index 0000000..17029de
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/kwargs_model.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc
new file mode 100644
index 0000000..10753dc
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/__pycache__/parameter_definition.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py
new file mode 100644
index 0000000..8839d36
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+from inspect import Traceback, isasyncgen
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator
+
+from anyio import create_task_group
+
+from litestar.utils import ensure_async_callable
+from litestar.utils.compat import async_next
+
+__all__ = ("DependencyCleanupGroup",)
+
+
+if TYPE_CHECKING:
+ from litestar.types import AnyGenerator
+
+
+class DependencyCleanupGroup:
+ """Wrapper for generator based dependencies.
+
+ Simplify cleanup by wrapping :func:`next` / :func:`anext` calls and providing facilities to
+ :meth:`throw <generator.throw>` / :meth:`athrow <agen.athrow>` into all generators consecutively. An instance of
+ this class can be used as a contextmanager, which will automatically throw any exceptions into its generators. All
+ exceptions caught in this manner will be re-raised after they have been thrown in the generators.
+ """
+
+ __slots__ = ("_generators", "_closed")
+
+ def __init__(self, generators: list[AnyGenerator] | None = None) -> None:
+ """Initialize ``DependencyCleanupGroup``.
+
+ Args:
+ generators: An optional list of generators to be called at cleanup
+ """
+ self._generators = generators or []
+ self._closed = False
+
+ def add(self, generator: Generator[Any, None, None] | AsyncGenerator[Any, None]) -> None:
+ """Add a new generator to the group.
+
+ Args:
+ generator: The generator to add
+
+ Returns:
+ None
+ """
+ if self._closed:
+ raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup")
+ self._generators.append(generator)
+
+ @staticmethod
+ def _wrap_next(generator: AnyGenerator) -> Callable[[], Awaitable[None]]:
+ if isasyncgen(generator):
+
+ async def wrapped_async() -> None:
+ await async_next(generator, None)
+
+ return wrapped_async
+
+ def wrapped() -> None:
+ next(generator, None) # type: ignore[arg-type]
+
+ return ensure_async_callable(wrapped)
+
+ async def cleanup(self) -> None:
+ """Execute cleanup by calling :func:`next` / :func:`anext` on all generators.
+
+ If there are multiple generators to be called, they will be executed in a :class:`anyio.TaskGroup`.
+
+ Returns:
+ None
+ """
+ if self._closed:
+ raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup")
+
+ self._closed = True
+
+ if not self._generators:
+ return
+
+ if len(self._generators) == 1:
+ await self._wrap_next(self._generators[0])()
+ return
+
+ async with create_task_group() as task_group:
+ for generator in self._generators:
+ task_group.start_soon(self._wrap_next(generator))
+
+ async def __aenter__(self) -> None:
+ """Support the async contextmanager protocol to allow for easier catching and throwing of exceptions into the
+ generators.
+ """
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: Traceback | None,
+ ) -> None:
+ """If an exception was raised within the contextmanager block, throw it into all generators."""
+ if exc_val:
+ await self.throw(exc_val)
+
+ async def throw(self, exc: BaseException) -> None:
+ """Throw an exception in all generators sequentially.
+
+ Args:
+ exc: Exception to throw
+ """
+ for gen in self._generators:
+ try:
+ if isasyncgen(gen):
+ await gen.athrow(exc)
+ else:
+ gen.throw(exc) # type: ignore[union-attr]
+ except (StopIteration, StopAsyncIteration):
+ continue
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py
new file mode 100644
index 0000000..88ffb07
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from litestar.utils.compat import async_next
+
+__all__ = ("Dependency", "create_dependency_batches", "map_dependencies_recursively", "resolve_dependency")
+
+
+if TYPE_CHECKING:
+ from litestar._kwargs.cleanup import DependencyCleanupGroup
+ from litestar.connection import ASGIConnection
+ from litestar.di import Provide
+
+
+class Dependency:
+ """Dependency graph of a given combination of ``Route`` + ``RouteHandler``"""
+
+ __slots__ = ("key", "provide", "dependencies")
+
+ def __init__(self, key: str, provide: Provide, dependencies: list[Dependency]) -> None:
+ """Initialize a dependency.
+
+ Args:
+ key: The dependency key
+ provide: Provider
+ dependencies: List of child nodes
+ """
+ self.key = key
+ self.provide = provide
+ self.dependencies = dependencies
+
+ def __eq__(self, other: Any) -> bool:
+ # check if memory address is identical, otherwise compare attributes
+ return other is self or (isinstance(other, self.__class__) and other.key == self.key)
+
+ def __hash__(self) -> int:
+ return hash(self.key)
+
+
+async def resolve_dependency(
+ dependency: Dependency,
+ connection: ASGIConnection,
+ kwargs: dict[str, Any],
+ cleanup_group: DependencyCleanupGroup,
+) -> None:
+ """Resolve a given instance of :class:`Dependency <litestar._kwargs.Dependency>`.
+
+ All required sub dependencies must already
+ be resolved into the kwargs. The result of the dependency will be stored in the kwargs.
+
+ Args:
+ dependency: An instance of :class:`Dependency <litestar._kwargs.Dependency>`
+ connection: An instance of :class:`Request <litestar.connection.Request>` or
+ :class:`WebSocket <litestar.connection.WebSocket>`.
+ kwargs: Any kwargs to pass to the dependency, the result will be stored here as well.
+ cleanup_group: DependencyCleanupGroup to which generators returned by ``dependency`` will be added
+ """
+ signature_model = dependency.provide.signature_model
+ dependency_kwargs = (
+ signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs)
+ if signature_model._fields
+ else {}
+ )
+ value = await dependency.provide(**dependency_kwargs)
+
+ if dependency.provide.has_sync_generator_dependency:
+ cleanup_group.add(value)
+ value = next(value)
+ elif dependency.provide.has_async_generator_dependency:
+ cleanup_group.add(value)
+ value = await async_next(value)
+
+ kwargs[dependency.key] = value
+
+
+def create_dependency_batches(expected_dependencies: set[Dependency]) -> list[set[Dependency]]:
+ """Calculate batches for all dependencies, recursively.
+
+ Args:
+ expected_dependencies: A set of all direct :class:`Dependencies <litestar._kwargs.Dependency>`.
+
+ Returns:
+ A list of batches.
+ """
+ dependencies_to: dict[Dependency, set[Dependency]] = {}
+ for dependency in expected_dependencies:
+ if dependency not in dependencies_to:
+ map_dependencies_recursively(dependency, dependencies_to)
+
+ batches = []
+ while dependencies_to:
+ current_batch = {
+ dependency
+ for dependency, remaining_sub_dependencies in dependencies_to.items()
+ if not remaining_sub_dependencies
+ }
+
+ for dependency in current_batch:
+ del dependencies_to[dependency]
+ for others_dependencies in dependencies_to.values():
+ others_dependencies.discard(dependency)
+
+ batches.append(current_batch)
+
+ return batches
+
+
+def map_dependencies_recursively(dependency: Dependency, dependencies_to: dict[Dependency, set[Dependency]]) -> None:
+ """Recursively map dependencies to their sub dependencies.
+
+ Args:
+ dependency: The current dependency to map.
+ dependencies_to: A map of dependency to its sub dependencies.
+ """
+ dependencies_to[dependency] = set(dependency.dependencies)
+ for sub in dependency.dependencies:
+ if sub not in dependencies_to:
+ map_dependencies_recursively(sub, dependencies_to)
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py
new file mode 100644
index 0000000..e3b347e
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py
@@ -0,0 +1,492 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from functools import lru_cache, partial
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast
+
+from litestar._multipart import parse_multipart_form
+from litestar._parsers import (
+ parse_query_string,
+ parse_url_encoded_form_data,
+)
+from litestar.datastructures import Headers
+from litestar.datastructures.upload_file import UploadFile
+from litestar.datastructures.url import URL
+from litestar.enums import ParamType, RequestEncodingType
+from litestar.exceptions import ValidationException
+from litestar.params import BodyKwarg
+from litestar.types import Empty
+from litestar.utils.predicates import is_non_string_sequence
+from litestar.utils.scope.state import ScopeState
+
+if TYPE_CHECKING:
+ from litestar._kwargs import KwargsModel
+ from litestar._kwargs.parameter_definition import ParameterDefinition
+ from litestar.connection import ASGIConnection, Request
+ from litestar.dto import AbstractDTO
+ from litestar.typing import FieldDefinition
+
+
+__all__ = (
+ "body_extractor",
+ "cookies_extractor",
+ "create_connection_value_extractor",
+ "create_data_extractor",
+ "create_multipart_extractor",
+ "create_query_default_dict",
+ "create_url_encoded_data_extractor",
+ "headers_extractor",
+ "json_extractor",
+ "msgpack_extractor",
+ "parse_connection_headers",
+ "parse_connection_query_params",
+ "query_extractor",
+ "request_extractor",
+ "scope_extractor",
+ "socket_extractor",
+ "state_extractor",
+)
+
+
+class ParamMappings(NamedTuple):
+ alias_and_key_tuples: list[tuple[str, str]]
+ alias_defaults: dict[str, Any]
+ alias_to_param: dict[str, ParameterDefinition]
+
+
+def _create_param_mappings(expected_params: set[ParameterDefinition]) -> ParamMappings:
+ alias_and_key_tuples = []
+ alias_defaults = {}
+ alias_to_params: dict[str, ParameterDefinition] = {}
+ for param in expected_params:
+ alias = param.field_alias
+ if param.param_type == ParamType.HEADER:
+ alias = alias.lower()
+
+ alias_and_key_tuples.append((alias, param.field_name))
+
+ if not (param.is_required or param.default is Ellipsis):
+ alias_defaults[alias] = param.default
+
+ alias_to_params[alias] = param
+
+ return ParamMappings(
+ alias_and_key_tuples=alias_and_key_tuples,
+ alias_defaults=alias_defaults,
+ alias_to_param=alias_to_params,
+ )
+
+
+def create_connection_value_extractor(
+ kwargs_model: KwargsModel,
+ connection_key: str,
+ expected_params: set[ParameterDefinition],
+ parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None,
+) -> Callable[[dict[str, Any], ASGIConnection], None]:
+ """Create a kwargs extractor function.
+
+ Args:
+ kwargs_model: The KwargsModel instance.
+ connection_key: The attribute key to use.
+ expected_params: The set of expected params.
+ parser: An optional parser function.
+
+ Returns:
+ An extractor function.
+ """
+
+ alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params)
+
+ def extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {})
+
+ try:
+ connection_mapping: dict[str, Any] = {
+ key: data[alias] if alias in data else alias_defaults[alias] for alias, key in alias_and_key_tuples
+ }
+ values.update(connection_mapping)
+ except KeyError as e:
+ param = alias_to_params[e.args[0]]
+ path = URL.from_components(
+ path=connection.url.path,
+ query=connection.url.query,
+ )
+ raise ValidationException(
+ f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}"
+ ) from e
+
+ return extractor
+
+
+@lru_cache(1024)
+def create_query_default_dict(
+ parsed_query: tuple[tuple[str, str], ...], sequence_query_parameter_names: tuple[str, ...]
+) -> defaultdict[str, list[str] | str]:
+ """Transform a list of tuples into a default dict. Ensures non-list values are not wrapped in a list.
+
+ Args:
+ parsed_query: The parsed query list of tuples.
+ sequence_query_parameter_names: A set of query parameters that should be wrapped in list.
+
+ Returns:
+ A default dict
+ """
+ output: defaultdict[str, list[str] | str] = defaultdict(list)
+
+ for k, v in parsed_query:
+ if k in sequence_query_parameter_names:
+ output[k].append(v) # type: ignore[union-attr]
+ else:
+ output[k] = v
+
+ return output
+
+
+def parse_connection_query_params(connection: ASGIConnection, kwargs_model: KwargsModel) -> dict[str, Any]:
+ """Parse query params and cache the result in scope.
+
+ Args:
+ connection: The ASGI connection instance.
+ kwargs_model: The KwargsModel instance.
+
+ Returns:
+ A dictionary of parsed values.
+ """
+ parsed_query = (
+ connection._parsed_query
+ if connection._parsed_query is not Empty
+ else parse_query_string(connection.scope.get("query_string", b""))
+ )
+ ScopeState.from_scope(connection.scope).parsed_query = parsed_query
+ return create_query_default_dict(
+ parsed_query=parsed_query,
+ sequence_query_parameter_names=kwargs_model.sequence_query_parameter_names,
+ )
+
+
+def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers:
+ """Parse header parameters and cache the result in scope.
+
+ Args:
+ connection: The ASGI connection instance.
+ _: The KwargsModel instance.
+
+ Returns:
+ A Headers instance
+ """
+ return Headers.from_scope(connection.scope)
+
+
+def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the app state from the connection and insert it to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["state"] = connection.app.state._state
+
+
+def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the headers from the connection and insert them to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ # TODO: This should be removed in 3.0 and instead Headers should be injected
+ # directly. We are only keeping this one around to not break things
+ values["headers"] = dict(connection.headers.items())
+
+
+def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the cookies from the connection and insert them to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["cookies"] = connection.cookies
+
+
+def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the query params from the connection and insert them to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["query"] = connection.query_params
+
+
+def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the scope from the connection and insert it into the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["scope"] = connection.scope
+
+
+def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Set the connection instance as the 'request' value in the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["request"] = connection
+
+
+def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Set the connection instance as the 'socket' value in the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["socket"] = connection
+
+
+def body_extractor(
+ values: dict[str, Any],
+ connection: Request[Any, Any, Any],
+) -> None:
+ """Extract the body from the request instance.
+
+ Notes:
+ - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ The Body value.
+ """
+ values["body"] = connection.body()
+
+
+async def json_extractor(connection: Request[Any, Any, Any]) -> Any:
+ """Extract the data from request and insert it into the kwargs injected to the handler.
+
+ Notes:
+ - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage.
+
+ Args:
+ connection: The ASGI connection instance.
+
+ Returns:
+ The JSON value.
+ """
+ if not await connection.body():
+ return Empty
+ return await connection.json()
+
+
+async def msgpack_extractor(connection: Request[Any, Any, Any]) -> Any:
+ """Extract the data from request and insert it into the kwargs injected to the handler.
+
+ Notes:
+ - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage.
+
+ Args:
+ connection: The ASGI connection instance.
+
+ Returns:
+ The MessagePack value.
+ """
+ if not await connection.body():
+ return Empty
+ return await connection.msgpack()
+
+
+async def _extract_multipart(
+ connection: Request[Any, Any, Any],
+ body_kwarg_multipart_form_part_limit: int | None,
+ field_definition: FieldDefinition,
+ is_data_optional: bool,
+ data_dto: type[AbstractDTO] | None,
+) -> Any:
+ multipart_form_part_limit = (
+ body_kwarg_multipart_form_part_limit
+ if body_kwarg_multipart_form_part_limit is not None
+ else connection.app.multipart_form_part_limit
+ )
+ connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key]
+ connection.scope["_form"] # type: ignore[typeddict-item]
+ if "_form" in connection.scope
+ else parse_multipart_form(
+ body=await connection.body(),
+ boundary=connection.content_type[-1].get("boundary", "").encode(),
+ multipart_form_part_limit=multipart_form_part_limit,
+ type_decoders=connection.route_handler.resolve_type_decoders(),
+ )
+ )
+
+ if field_definition.is_non_string_sequence:
+ values = list(form_values.values())
+ if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list):
+ return values[0]
+
+ return values
+
+ if field_definition.is_simple_type and field_definition.annotation is UploadFile and form_values:
+ return next(v for v in form_values.values() if isinstance(v, UploadFile))
+
+ if not form_values and is_data_optional:
+ return None
+
+ if data_dto:
+ return data_dto(connection).decode_builtins(form_values)
+
+ for name, tp in field_definition.get_type_hints().items():
+ value = form_values.get(name)
+ if value is not None and is_non_string_sequence(tp) and not isinstance(value, list):
+ form_values[name] = [value]
+
+ return form_values
+
+
+def create_multipart_extractor(
+ field_definition: FieldDefinition, is_data_optional: bool, data_dto: type[AbstractDTO] | None
+) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]:
+ """Create a multipart form-data extractor.
+
+ Args:
+ field_definition: A FieldDefinition instance.
+ is_data_optional: Boolean dictating whether the field is optional.
+ data_dto: A data DTO type, if configured for handler.
+
+ Returns:
+ An extractor function.
+ """
+ body_kwarg_multipart_form_part_limit: int | None = None
+ if field_definition.kwarg_definition and isinstance(field_definition.kwarg_definition, BodyKwarg):
+ body_kwarg_multipart_form_part_limit = field_definition.kwarg_definition.multipart_form_part_limit
+
+ extract_multipart = partial(
+ _extract_multipart,
+ body_kwarg_multipart_form_part_limit=body_kwarg_multipart_form_part_limit,
+ is_data_optional=is_data_optional,
+ data_dto=data_dto,
+ field_definition=field_definition,
+ )
+
+ return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart)
+
+
+def create_url_encoded_data_extractor(
+ is_data_optional: bool, data_dto: type[AbstractDTO] | None
+) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]:
+ """Create extractor for url encoded form-data.
+
+ Args:
+ is_data_optional: Boolean dictating whether the field is optional.
+ data_dto: A data DTO type, if configured for handler.
+
+ Returns:
+ An extractor function.
+ """
+
+ async def extract_url_encoded_extractor(
+ connection: Request[Any, Any, Any],
+ ) -> Any:
+ connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key]
+ connection.scope["_form"] # type: ignore[typeddict-item]
+ if "_form" in connection.scope
+ else parse_url_encoded_form_data(await connection.body())
+ )
+
+ if not form_values and is_data_optional:
+ return None
+
+ return data_dto(connection).decode_builtins(form_values) if data_dto else form_values
+
+ return cast(
+ "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_url_encoded_extractor
+ )
+
+
+def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]:
+ """Create an extractor for a request's body.
+
+ Args:
+ kwargs_model: The KwargsModel instance.
+
+ Returns:
+ An extractor for the request's body.
+ """
+
+ if kwargs_model.expected_form_data:
+ media_type, field_definition = kwargs_model.expected_form_data
+
+ if media_type == RequestEncodingType.MULTI_PART:
+ data_extractor = create_multipart_extractor(
+ field_definition=field_definition,
+ is_data_optional=kwargs_model.is_data_optional,
+ data_dto=kwargs_model.expected_data_dto,
+ )
+ else:
+ data_extractor = create_url_encoded_data_extractor(
+ is_data_optional=kwargs_model.is_data_optional,
+ data_dto=kwargs_model.expected_data_dto,
+ )
+ elif kwargs_model.expected_msgpack_data:
+ data_extractor = cast(
+ "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", msgpack_extractor
+ )
+ elif kwargs_model.expected_data_dto:
+ data_extractor = create_dto_extractor(data_dto=kwargs_model.expected_data_dto)
+ else:
+ data_extractor = cast(
+ "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor
+ )
+
+ def extractor(
+ values: dict[str, Any],
+ connection: ASGIConnection[Any, Any, Any, Any],
+ ) -> None:
+ values["data"] = data_extractor(connection)
+
+ return extractor
+
+
+def create_dto_extractor(
+ data_dto: type[AbstractDTO],
+) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]:
+ """Create a DTO data extractor.
+
+
+ Returns:
+ An extractor function.
+ """
+
+ async def dto_extractor(connection: Request[Any, Any, Any]) -> Any:
+ if not (body := await connection.body()):
+ return Empty
+ return data_dto(connection).decode_bytes(body)
+
+ return dto_extractor # type:ignore[return-value]
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py
new file mode 100644
index 0000000..01ed2e5
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/kwargs_model.py
@@ -0,0 +1,479 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable
+
+from anyio import create_task_group
+
+from litestar._kwargs.cleanup import DependencyCleanupGroup
+from litestar._kwargs.dependencies import (
+ Dependency,
+ create_dependency_batches,
+ resolve_dependency,
+)
+from litestar._kwargs.extractors import (
+ body_extractor,
+ cookies_extractor,
+ create_connection_value_extractor,
+ create_data_extractor,
+ headers_extractor,
+ parse_connection_headers,
+ parse_connection_query_params,
+ query_extractor,
+ request_extractor,
+ scope_extractor,
+ socket_extractor,
+ state_extractor,
+)
+from litestar._kwargs.parameter_definition import (
+ ParameterDefinition,
+ create_parameter_definition,
+ merge_parameter_sets,
+)
+from litestar.constants import RESERVED_KWARGS
+from litestar.enums import ParamType, RequestEncodingType
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.params import BodyKwarg, ParameterKwarg
+from litestar.typing import FieldDefinition
+from litestar.utils.helpers import get_exception_group
+
+__all__ = ("KwargsModel",)
+
+
+if TYPE_CHECKING:
+ from litestar._signature import SignatureModel
+ from litestar.connection import ASGIConnection
+ from litestar.di import Provide
+ from litestar.dto import AbstractDTO
+ from litestar.utils.signature import ParsedSignature
+
+_ExceptionGroup = get_exception_group()
+
+
+class KwargsModel:
+ """Model required kwargs for a given RouteHandler and its dependencies.
+
+ This is done once and is memoized during application bootstrap, ensuring minimal runtime overhead.
+ """
+
+ __slots__ = (
+ "dependency_batches",
+ "expected_cookie_params",
+ "expected_data_dto",
+ "expected_form_data",
+ "expected_header_params",
+ "expected_msgpack_data",
+ "expected_path_params",
+ "expected_query_params",
+ "expected_reserved_kwargs",
+ "extractors",
+ "has_kwargs",
+ "is_data_optional",
+ "sequence_query_parameter_names",
+ )
+
+ def __init__(
+ self,
+ *,
+ expected_cookie_params: set[ParameterDefinition],
+ expected_data_dto: type[AbstractDTO] | None,
+ expected_dependencies: set[Dependency],
+ expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None,
+ expected_header_params: set[ParameterDefinition],
+ expected_msgpack_data: FieldDefinition | None,
+ expected_path_params: set[ParameterDefinition],
+ expected_query_params: set[ParameterDefinition],
+ expected_reserved_kwargs: set[str],
+ is_data_optional: bool,
+ sequence_query_parameter_names: set[str],
+ ) -> None:
+ """Initialize ``KwargsModel``.
+
+ Args:
+ expected_cookie_params: Any expected cookie parameter kwargs
+ expected_dependencies: Any expected dependency kwargs
+ expected_form_data: Any expected form data kwargs
+ expected_header_params: Any expected header parameter kwargs
+ expected_msgpack_data: Any expected MessagePack data kwargs
+ expected_path_params: Any expected path parameter kwargs
+ expected_query_params: Any expected query parameter kwargs
+ expected_reserved_kwargs: Any expected reserved kwargs, e.g. 'state'
+ expected_data_dto: A data DTO, if defined
+ is_data_optional: Treat data as optional
+ sequence_query_parameter_names: Any query parameters that are sequences
+ """
+ self.expected_cookie_params = expected_cookie_params
+ self.expected_form_data = expected_form_data
+ self.expected_header_params = expected_header_params
+ self.expected_msgpack_data = expected_msgpack_data
+ self.expected_path_params = expected_path_params
+ self.expected_query_params = expected_query_params
+ self.expected_reserved_kwargs = expected_reserved_kwargs
+ self.expected_data_dto = expected_data_dto
+ self.sequence_query_parameter_names = tuple(sequence_query_parameter_names)
+
+ self.has_kwargs = (
+ expected_cookie_params
+ or expected_dependencies
+ or expected_form_data
+ or expected_msgpack_data
+ or expected_header_params
+ or expected_path_params
+ or expected_query_params
+ or expected_reserved_kwargs
+ or expected_data_dto
+ )
+
+ self.is_data_optional = is_data_optional
+ self.extractors = self._create_extractors()
+ self.dependency_batches = create_dependency_batches(expected_dependencies)
+
+ def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]:
+ reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = {
+ "data": create_data_extractor(self),
+ "state": state_extractor,
+ "scope": scope_extractor,
+ "request": request_extractor,
+ "socket": socket_extractor,
+ "headers": headers_extractor,
+ "cookies": cookies_extractor,
+ "query": query_extractor,
+ "body": body_extractor, # type: ignore[dict-item]
+ }
+
+ extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [
+ reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs
+ ]
+
+ if self.expected_header_params:
+ extractors.append(
+ create_connection_value_extractor(
+ connection_key="headers",
+ expected_params=self.expected_header_params,
+ kwargs_model=self,
+ parser=parse_connection_headers,
+ ),
+ )
+
+ if self.expected_path_params:
+ extractors.append(
+ create_connection_value_extractor(
+ connection_key="path_params",
+ expected_params=self.expected_path_params,
+ kwargs_model=self,
+ ),
+ )
+
+ if self.expected_cookie_params:
+ extractors.append(
+ create_connection_value_extractor(
+ connection_key="cookies",
+ expected_params=self.expected_cookie_params,
+ kwargs_model=self,
+ ),
+ )
+
+ if self.expected_query_params:
+ extractors.append(
+ create_connection_value_extractor(
+ connection_key="query_params",
+ expected_params=self.expected_query_params,
+ kwargs_model=self,
+ parser=parse_connection_query_params,
+ ),
+ )
+ return extractors
+
+ @classmethod
+ def _get_param_definitions(
+ cls,
+ path_parameters: set[str],
+ layered_parameters: dict[str, FieldDefinition],
+ dependencies: dict[str, Provide],
+ field_definitions: dict[str, FieldDefinition],
+ ) -> tuple[set[ParameterDefinition], set[Dependency]]:
+ """Get parameter_definitions for the construction of KwargsModel instance.
+
+ Args:
+ path_parameters: Any expected path parameters.
+ layered_parameters: A string keyed dictionary of layered parameters.
+ dependencies: A string keyed dictionary mapping dependency providers.
+ field_definitions: The SignatureModel fields.
+
+ Returns:
+ A Tuple of sets
+ """
+ expected_dependencies = {
+ cls._create_dependency_graph(key=key, dependencies=dependencies)
+ for key in dependencies
+ if key in field_definitions
+ }
+ ignored_keys = {*RESERVED_KWARGS, *(dependency.key for dependency in expected_dependencies)}
+
+ param_definitions = {
+ *(
+ create_parameter_definition(
+ field_definition=field_definition,
+ field_name=field_name,
+ path_parameters=path_parameters,
+ )
+ for field_name, field_definition in layered_parameters.items()
+ if field_name not in ignored_keys and field_name not in field_definitions
+ ),
+ *(
+ create_parameter_definition(
+ field_definition=field_definition,
+ field_name=field_name,
+ path_parameters=path_parameters,
+ )
+ for field_name, field_definition in field_definitions.items()
+ if field_name not in ignored_keys and field_name not in layered_parameters
+ ),
+ }
+
+ for field_name, field_definition in (
+ (k, v) for k, v in field_definitions.items() if k not in ignored_keys and k in layered_parameters
+ ):
+ layered_parameter = layered_parameters[field_name]
+ field = field_definition if field_definition.is_parameter_field else layered_parameter
+ default = field_definition.default if field_definition.has_default else layered_parameter.default
+
+ param_definitions.add(
+ create_parameter_definition(
+ field_definition=FieldDefinition.from_kwarg(
+ name=field.name,
+ default=default,
+ inner_types=field.inner_types,
+ annotation=field.annotation,
+ kwarg_definition=field.kwarg_definition,
+ extra=field.extra,
+ ),
+ field_name=field_name,
+ path_parameters=path_parameters,
+ )
+ )
+
+ return param_definitions, expected_dependencies
+
+ @classmethod
+ def create_for_signature_model(
+ cls,
+ signature_model: type[SignatureModel],
+ parsed_signature: ParsedSignature,
+ dependencies: dict[str, Provide],
+ path_parameters: set[str],
+ layered_parameters: dict[str, FieldDefinition],
+ ) -> KwargsModel:
+ """Pre-determine what parameters are required for a given combination of route + route handler. It is executed
+ during the application bootstrap process.
+
+ Args:
+ signature_model: A :class:`SignatureModel <litestar._signature.SignatureModel>` subclass.
+ parsed_signature: A :class:`ParsedSignature <litestar._signature.ParsedSignature>` instance.
+ dependencies: A string keyed dictionary mapping dependency providers.
+ path_parameters: Any expected path parameters.
+ layered_parameters: A string keyed dictionary of layered parameters.
+
+ Returns:
+ An instance of KwargsModel
+ """
+
+ field_definitions = signature_model._fields
+
+ cls._validate_raw_kwargs(
+ path_parameters=path_parameters,
+ dependencies=dependencies,
+ field_definitions=field_definitions,
+ layered_parameters=layered_parameters,
+ )
+
+ param_definitions, expected_dependencies = cls._get_param_definitions(
+ path_parameters=path_parameters,
+ layered_parameters=layered_parameters,
+ dependencies=dependencies,
+ field_definitions=field_definitions,
+ )
+
+ expected_reserved_kwargs = {field_name for field_name in field_definitions if field_name in RESERVED_KWARGS}
+ expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH}
+ expected_header_parameters = {p for p in param_definitions if p.param_type == ParamType.HEADER}
+ expected_cookie_parameters = {p for p in param_definitions if p.param_type == ParamType.COOKIE}
+ expected_query_parameters = {p for p in param_definitions if p.param_type == ParamType.QUERY}
+ sequence_query_parameter_names = {p.field_alias for p in expected_query_parameters if p.is_sequence}
+
+ expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None = None
+ expected_msgpack_data: FieldDefinition | None = None
+ expected_data_dto: type[AbstractDTO] | None = None
+ data_field_definition = field_definitions.get("data")
+
+ media_type: RequestEncodingType | str | None = None
+ if data_field_definition:
+ if isinstance(data_field_definition.kwarg_definition, BodyKwarg):
+ media_type = data_field_definition.kwarg_definition.media_type
+
+ if media_type in (RequestEncodingType.MULTI_PART, RequestEncodingType.URL_ENCODED):
+ expected_form_data = (media_type, data_field_definition)
+ expected_data_dto = signature_model._data_dto
+ elif signature_model._data_dto:
+ expected_data_dto = signature_model._data_dto
+ elif media_type == RequestEncodingType.MESSAGEPACK:
+ expected_msgpack_data = data_field_definition
+
+ for dependency in expected_dependencies:
+ dependency_kwargs_model = cls.create_for_signature_model(
+ signature_model=dependency.provide.signature_model,
+ parsed_signature=parsed_signature,
+ dependencies=dependencies,
+ path_parameters=path_parameters,
+ layered_parameters=layered_parameters,
+ )
+ expected_path_parameters = merge_parameter_sets(
+ expected_path_parameters, dependency_kwargs_model.expected_path_params
+ )
+ expected_query_parameters = merge_parameter_sets(
+ expected_query_parameters, dependency_kwargs_model.expected_query_params
+ )
+ expected_cookie_parameters = merge_parameter_sets(
+ expected_cookie_parameters, dependency_kwargs_model.expected_cookie_params
+ )
+ expected_header_parameters = merge_parameter_sets(
+ expected_header_parameters, dependency_kwargs_model.expected_header_params
+ )
+
+ if "data" in expected_reserved_kwargs and "data" in dependency_kwargs_model.expected_reserved_kwargs:
+ cls._validate_dependency_data(
+ expected_form_data=expected_form_data,
+ dependency_kwargs_model=dependency_kwargs_model,
+ )
+
+ expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs)
+ sequence_query_parameter_names.update(dependency_kwargs_model.sequence_query_parameter_names)
+
+ return KwargsModel(
+ expected_cookie_params=expected_cookie_parameters,
+ expected_dependencies=expected_dependencies,
+ expected_data_dto=expected_data_dto,
+ expected_form_data=expected_form_data,
+ expected_header_params=expected_header_parameters,
+ expected_msgpack_data=expected_msgpack_data,
+ expected_path_params=expected_path_parameters,
+ expected_query_params=expected_query_parameters,
+ expected_reserved_kwargs=expected_reserved_kwargs,
+ is_data_optional=field_definitions["data"].is_optional if "data" in expected_reserved_kwargs else False,
+ sequence_query_parameter_names=sequence_query_parameter_names,
+ )
+
+ def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]:
+ """Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is
+ sync.
+
+ Args:
+ connection: An instance of :class:`Request <litestar.connection.Request>` or
+ :class:`WebSocket <litestar.connection.WebSocket>`.
+
+ Returns:
+ A string keyed dictionary of kwargs expected by the handler function and its dependencies.
+ """
+ output: dict[str, Any] = {}
+
+ for extractor in self.extractors:
+ extractor(output, connection)
+
+ return output
+
+ async def resolve_dependencies(self, connection: ASGIConnection, kwargs: dict[str, Any]) -> DependencyCleanupGroup:
+ """Resolve all dependencies into the kwargs, recursively.
+
+ Args:
+ connection: An instance of :class:`Request <litestar.connection.Request>` or
+ :class:`WebSocket <litestar.connection.WebSocket>`.
+ kwargs: Kwargs to pass to dependencies.
+ """
+ cleanup_group = DependencyCleanupGroup()
+ for batch in self.dependency_batches:
+ if len(batch) == 1:
+ await resolve_dependency(next(iter(batch)), connection, kwargs, cleanup_group)
+ else:
+ try:
+ async with create_task_group() as task_group:
+ for dependency in batch:
+ task_group.start_soon(resolve_dependency, dependency, connection, kwargs, cleanup_group)
+ except _ExceptionGroup as excgroup:
+ raise excgroup.exceptions[0] from excgroup # type: ignore[attr-defined]
+
+ return cleanup_group
+
+ @classmethod
+ def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) -> Dependency:
+ """Create a graph like structure of dependencies, with each dependency including its own dependencies as a
+ list.
+ """
+ provide = dependencies[key]
+ sub_dependency_keys = [k for k in provide.signature_model._fields if k in dependencies]
+ return Dependency(
+ key=key,
+ provide=provide,
+ dependencies=[cls._create_dependency_graph(key=k, dependencies=dependencies) for k in sub_dependency_keys],
+ )
+
+ @classmethod
+ def _validate_dependency_data(
+ cls,
+ expected_form_data: tuple[RequestEncodingType | str, FieldDefinition] | None,
+ dependency_kwargs_model: KwargsModel,
+ ) -> None:
+ """Validate that the 'data' kwarg is compatible across dependencies."""
+ if bool(expected_form_data) != bool(dependency_kwargs_model.expected_form_data):
+ raise ImproperlyConfiguredException(
+ "Dependencies have incompatible 'data' kwarg types: one expects JSON and the other expects form-data"
+ )
+ if expected_form_data and dependency_kwargs_model.expected_form_data:
+ local_media_type = expected_form_data[0]
+ dependency_media_type = dependency_kwargs_model.expected_form_data[0]
+ if local_media_type != dependency_media_type:
+ raise ImproperlyConfiguredException(
+ "Dependencies have incompatible form-data encoding: one expects url-encoded and the other expects multi-part"
+ )
+
+ @classmethod
+ def _validate_raw_kwargs(
+ cls,
+ path_parameters: set[str],
+ dependencies: dict[str, Provide],
+ field_definitions: dict[str, FieldDefinition],
+ layered_parameters: dict[str, FieldDefinition],
+ ) -> None:
+ """Validate that there are no ambiguous kwargs, that is, kwargs declared using the same key in different
+ places.
+ """
+ dependency_keys = set(dependencies.keys())
+
+ parameter_names = {
+ *(
+ k
+ for k, f in field_definitions.items()
+ if isinstance(f.kwarg_definition, ParameterKwarg)
+ and (f.kwarg_definition.header or f.kwarg_definition.query or f.kwarg_definition.cookie)
+ ),
+ *list(layered_parameters.keys()),
+ }
+
+ intersection = (
+ path_parameters.intersection(dependency_keys)
+ or path_parameters.intersection(parameter_names)
+ or dependency_keys.intersection(parameter_names)
+ )
+ if intersection:
+ raise ImproperlyConfiguredException(
+ f"Kwarg resolution ambiguity detected for the following keys: {', '.join(intersection)}. "
+ f"Make sure to use distinct keys for your dependencies, path parameters, and aliased parameters."
+ )
+
+ if used_reserved_kwargs := {
+ *parameter_names,
+ *path_parameters,
+ *dependency_keys,
+ }.intersection(RESERVED_KWARGS):
+ raise ImproperlyConfiguredException(
+ f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. "
+ f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}"
+ )
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py
new file mode 100644
index 0000000..02b09fc
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/parameter_definition.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, NamedTuple
+
+from litestar.enums import ParamType
+from litestar.params import ParameterKwarg
+
+if TYPE_CHECKING:
+ from litestar.typing import FieldDefinition
+
+__all__ = ("ParameterDefinition", "create_parameter_definition", "merge_parameter_sets")
+
+
+class ParameterDefinition(NamedTuple):
+ """Tuple defining a kwarg representing a request parameter."""
+
+ default: Any
+ field_alias: str
+ field_name: str
+ is_required: bool
+ is_sequence: bool
+ param_type: ParamType
+
+
+def create_parameter_definition(
+ field_definition: FieldDefinition,
+ field_name: str,
+ path_parameters: set[str],
+) -> ParameterDefinition:
+ """Create a ParameterDefinition for the given FieldDefinition.
+
+ Args:
+ field_definition: FieldDefinition instance.
+ field_name: The field's name.
+ path_parameters: A set of path parameter names.
+
+ Returns:
+ A ParameterDefinition tuple.
+ """
+ default = field_definition.default if field_definition.has_default else None
+ kwarg_definition = (
+ field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None
+ )
+
+ field_alias = kwarg_definition.query if kwarg_definition and kwarg_definition.query else field_name
+ param_type = ParamType.QUERY
+
+ if field_name in path_parameters:
+ field_alias = field_name
+ param_type = ParamType.PATH
+ elif kwarg_definition and kwarg_definition.header:
+ field_alias = kwarg_definition.header
+ param_type = ParamType.HEADER
+ elif kwarg_definition and kwarg_definition.cookie:
+ field_alias = kwarg_definition.cookie
+ param_type = ParamType.COOKIE
+
+ return ParameterDefinition(
+ param_type=param_type,
+ field_name=field_name,
+ field_alias=field_alias,
+ default=default,
+ is_required=field_definition.is_required
+ and default is None
+ and not field_definition.is_optional
+ and not field_definition.is_any,
+ is_sequence=field_definition.is_non_string_sequence,
+ )
+
+
+def merge_parameter_sets(first: set[ParameterDefinition], second: set[ParameterDefinition]) -> set[ParameterDefinition]:
+ """Given two sets of parameter definitions, coming from different dependencies for example, merge them into a single
+ set.
+ """
+ result: set[ParameterDefinition] = first.intersection(second)
+ difference = first.symmetric_difference(second)
+ for param in difference:
+ # add the param if it's either required or no-other param in difference is the same but required
+ if param.is_required or not any(p.field_alias == param.field_alias and p.is_required for p in difference):
+ result.add(param)
+ return result