summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py492
1 files changed, 492 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py
new file mode 100644
index 0000000..e3b347e
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py
@@ -0,0 +1,492 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from functools import lru_cache, partial
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast
+
+from litestar._multipart import parse_multipart_form
+from litestar._parsers import (
+ parse_query_string,
+ parse_url_encoded_form_data,
+)
+from litestar.datastructures import Headers
+from litestar.datastructures.upload_file import UploadFile
+from litestar.datastructures.url import URL
+from litestar.enums import ParamType, RequestEncodingType
+from litestar.exceptions import ValidationException
+from litestar.params import BodyKwarg
+from litestar.types import Empty
+from litestar.utils.predicates import is_non_string_sequence
+from litestar.utils.scope.state import ScopeState
+
+if TYPE_CHECKING:
+ from litestar._kwargs import KwargsModel
+ from litestar._kwargs.parameter_definition import ParameterDefinition
+ from litestar.connection import ASGIConnection, Request
+ from litestar.dto import AbstractDTO
+ from litestar.typing import FieldDefinition
+
+
+__all__ = (
+ "body_extractor",
+ "cookies_extractor",
+ "create_connection_value_extractor",
+ "create_data_extractor",
+ "create_multipart_extractor",
+ "create_query_default_dict",
+ "create_url_encoded_data_extractor",
+ "headers_extractor",
+ "json_extractor",
+ "msgpack_extractor",
+ "parse_connection_headers",
+ "parse_connection_query_params",
+ "query_extractor",
+ "request_extractor",
+ "scope_extractor",
+ "socket_extractor",
+ "state_extractor",
+)
+
+
+class ParamMappings(NamedTuple):
+ alias_and_key_tuples: list[tuple[str, str]]
+ alias_defaults: dict[str, Any]
+ alias_to_param: dict[str, ParameterDefinition]
+
+
+def _create_param_mappings(expected_params: set[ParameterDefinition]) -> ParamMappings:
+ alias_and_key_tuples = []
+ alias_defaults = {}
+ alias_to_params: dict[str, ParameterDefinition] = {}
+ for param in expected_params:
+ alias = param.field_alias
+ if param.param_type == ParamType.HEADER:
+ alias = alias.lower()
+
+ alias_and_key_tuples.append((alias, param.field_name))
+
+ if not (param.is_required or param.default is Ellipsis):
+ alias_defaults[alias] = param.default
+
+ alias_to_params[alias] = param
+
+ return ParamMappings(
+ alias_and_key_tuples=alias_and_key_tuples,
+ alias_defaults=alias_defaults,
+ alias_to_param=alias_to_params,
+ )
+
+
+def create_connection_value_extractor(
+ kwargs_model: KwargsModel,
+ connection_key: str,
+ expected_params: set[ParameterDefinition],
+ parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None,
+) -> Callable[[dict[str, Any], ASGIConnection], None]:
+ """Create a kwargs extractor function.
+
+ Args:
+ kwargs_model: The KwargsModel instance.
+ connection_key: The attribute key to use.
+ expected_params: The set of expected params.
+ parser: An optional parser function.
+
+ Returns:
+ An extractor function.
+ """
+
+ alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params)
+
+ def extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {})
+
+ try:
+ connection_mapping: dict[str, Any] = {
+ key: data[alias] if alias in data else alias_defaults[alias] for alias, key in alias_and_key_tuples
+ }
+ values.update(connection_mapping)
+ except KeyError as e:
+ param = alias_to_params[e.args[0]]
+ path = URL.from_components(
+ path=connection.url.path,
+ query=connection.url.query,
+ )
+ raise ValidationException(
+ f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}"
+ ) from e
+
+ return extractor
+
+
+@lru_cache(1024)
+def create_query_default_dict(
+ parsed_query: tuple[tuple[str, str], ...], sequence_query_parameter_names: tuple[str, ...]
+) -> defaultdict[str, list[str] | str]:
+ """Transform a list of tuples into a default dict. Ensures non-list values are not wrapped in a list.
+
+ Args:
+ parsed_query: The parsed query list of tuples.
+ sequence_query_parameter_names: A set of query parameters that should be wrapped in list.
+
+ Returns:
+ A default dict
+ """
+ output: defaultdict[str, list[str] | str] = defaultdict(list)
+
+ for k, v in parsed_query:
+ if k in sequence_query_parameter_names:
+ output[k].append(v) # type: ignore[union-attr]
+ else:
+ output[k] = v
+
+ return output
+
+
+def parse_connection_query_params(connection: ASGIConnection, kwargs_model: KwargsModel) -> dict[str, Any]:
+ """Parse query params and cache the result in scope.
+
+ Args:
+ connection: The ASGI connection instance.
+ kwargs_model: The KwargsModel instance.
+
+ Returns:
+ A dictionary of parsed values.
+ """
+ parsed_query = (
+ connection._parsed_query
+ if connection._parsed_query is not Empty
+ else parse_query_string(connection.scope.get("query_string", b""))
+ )
+ ScopeState.from_scope(connection.scope).parsed_query = parsed_query
+ return create_query_default_dict(
+ parsed_query=parsed_query,
+ sequence_query_parameter_names=kwargs_model.sequence_query_parameter_names,
+ )
+
+
+def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers:
+ """Parse header parameters and cache the result in scope.
+
+ Args:
+ connection: The ASGI connection instance.
+ _: The KwargsModel instance.
+
+ Returns:
+ A Headers instance
+ """
+ return Headers.from_scope(connection.scope)
+
+
+def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the app state from the connection and insert it to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["state"] = connection.app.state._state
+
+
+def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the headers from the connection and insert them to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ # TODO: This should be removed in 3.0 and instead Headers should be injected
+ # directly. We are only keeping this one around to not break things
+ values["headers"] = dict(connection.headers.items())
+
+
+def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the cookies from the connection and insert them to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["cookies"] = connection.cookies
+
+
+def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the query params from the connection and insert them to the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["query"] = connection.query_params
+
+
+def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Extract the scope from the connection and insert it into the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["scope"] = connection.scope
+
+
+def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Set the connection instance as the 'request' value in the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["request"] = connection
+
+
+def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
+ """Set the connection instance as the 'socket' value in the kwargs injected to the handler.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ None
+ """
+ values["socket"] = connection
+
+
+def body_extractor(
+ values: dict[str, Any],
+ connection: Request[Any, Any, Any],
+) -> None:
+ """Extract the body from the request instance.
+
+ Notes:
+ - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage.
+
+ Args:
+ connection: The ASGI connection instance.
+ values: The kwargs that are extracted from the connection and will be injected into the handler.
+
+ Returns:
+ The Body value.
+ """
+ values["body"] = connection.body()
+
+
+async def json_extractor(connection: Request[Any, Any, Any]) -> Any:
+ """Extract the data from request and insert it into the kwargs injected to the handler.
+
+ Notes:
+ - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage.
+
+ Args:
+ connection: The ASGI connection instance.
+
+ Returns:
+ The JSON value.
+ """
+ if not await connection.body():
+ return Empty
+ return await connection.json()
+
+
+async def msgpack_extractor(connection: Request[Any, Any, Any]) -> Any:
+ """Extract the data from request and insert it into the kwargs injected to the handler.
+
+ Notes:
+ - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage.
+
+ Args:
+ connection: The ASGI connection instance.
+
+ Returns:
+ The MessagePack value.
+ """
+ if not await connection.body():
+ return Empty
+ return await connection.msgpack()
+
+
+async def _extract_multipart(
+ connection: Request[Any, Any, Any],
+ body_kwarg_multipart_form_part_limit: int | None,
+ field_definition: FieldDefinition,
+ is_data_optional: bool,
+ data_dto: type[AbstractDTO] | None,
+) -> Any:
+ multipart_form_part_limit = (
+ body_kwarg_multipart_form_part_limit
+ if body_kwarg_multipart_form_part_limit is not None
+ else connection.app.multipart_form_part_limit
+ )
+ connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key]
+ connection.scope["_form"] # type: ignore[typeddict-item]
+ if "_form" in connection.scope
+ else parse_multipart_form(
+ body=await connection.body(),
+ boundary=connection.content_type[-1].get("boundary", "").encode(),
+ multipart_form_part_limit=multipart_form_part_limit,
+ type_decoders=connection.route_handler.resolve_type_decoders(),
+ )
+ )
+
+ if field_definition.is_non_string_sequence:
+ values = list(form_values.values())
+ if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list):
+ return values[0]
+
+ return values
+
+ if field_definition.is_simple_type and field_definition.annotation is UploadFile and form_values:
+ return next(v for v in form_values.values() if isinstance(v, UploadFile))
+
+ if not form_values and is_data_optional:
+ return None
+
+ if data_dto:
+ return data_dto(connection).decode_builtins(form_values)
+
+ for name, tp in field_definition.get_type_hints().items():
+ value = form_values.get(name)
+ if value is not None and is_non_string_sequence(tp) and not isinstance(value, list):
+ form_values[name] = [value]
+
+ return form_values
+
+
+def create_multipart_extractor(
+ field_definition: FieldDefinition, is_data_optional: bool, data_dto: type[AbstractDTO] | None
+) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]:
+ """Create a multipart form-data extractor.
+
+ Args:
+ field_definition: A FieldDefinition instance.
+ is_data_optional: Boolean dictating whether the field is optional.
+ data_dto: A data DTO type, if configured for handler.
+
+ Returns:
+ An extractor function.
+ """
+ body_kwarg_multipart_form_part_limit: int | None = None
+ if field_definition.kwarg_definition and isinstance(field_definition.kwarg_definition, BodyKwarg):
+ body_kwarg_multipart_form_part_limit = field_definition.kwarg_definition.multipart_form_part_limit
+
+ extract_multipart = partial(
+ _extract_multipart,
+ body_kwarg_multipart_form_part_limit=body_kwarg_multipart_form_part_limit,
+ is_data_optional=is_data_optional,
+ data_dto=data_dto,
+ field_definition=field_definition,
+ )
+
+ return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart)
+
+
+def create_url_encoded_data_extractor(
+ is_data_optional: bool, data_dto: type[AbstractDTO] | None
+) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]:
+ """Create extractor for url encoded form-data.
+
+ Args:
+ is_data_optional: Boolean dictating whether the field is optional.
+ data_dto: A data DTO type, if configured for handler.
+
+ Returns:
+ An extractor function.
+ """
+
+ async def extract_url_encoded_extractor(
+ connection: Request[Any, Any, Any],
+ ) -> Any:
+ connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key]
+ connection.scope["_form"] # type: ignore[typeddict-item]
+ if "_form" in connection.scope
+ else parse_url_encoded_form_data(await connection.body())
+ )
+
+ if not form_values and is_data_optional:
+ return None
+
+ return data_dto(connection).decode_builtins(form_values) if data_dto else form_values
+
+ return cast(
+ "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_url_encoded_extractor
+ )
+
+
+def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]:
+ """Create an extractor for a request's body.
+
+ Args:
+ kwargs_model: The KwargsModel instance.
+
+ Returns:
+ An extractor for the request's body.
+ """
+
+ if kwargs_model.expected_form_data:
+ media_type, field_definition = kwargs_model.expected_form_data
+
+ if media_type == RequestEncodingType.MULTI_PART:
+ data_extractor = create_multipart_extractor(
+ field_definition=field_definition,
+ is_data_optional=kwargs_model.is_data_optional,
+ data_dto=kwargs_model.expected_data_dto,
+ )
+ else:
+ data_extractor = create_url_encoded_data_extractor(
+ is_data_optional=kwargs_model.is_data_optional,
+ data_dto=kwargs_model.expected_data_dto,
+ )
+ elif kwargs_model.expected_msgpack_data:
+ data_extractor = cast(
+ "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", msgpack_extractor
+ )
+ elif kwargs_model.expected_data_dto:
+ data_extractor = create_dto_extractor(data_dto=kwargs_model.expected_data_dto)
+ else:
+ data_extractor = cast(
+ "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor
+ )
+
+ def extractor(
+ values: dict[str, Any],
+ connection: ASGIConnection[Any, Any, Any, Any],
+ ) -> None:
+ values["data"] = data_extractor(connection)
+
+ return extractor
+
+
+def create_dto_extractor(
+ data_dto: type[AbstractDTO],
+) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]:
+ """Create a DTO data extractor.
+
+
+ Returns:
+ An extractor function.
+ """
+
+ async def dto_extractor(connection: Request[Any, Any, Any]) -> Any:
+ if not (body := await connection.body()):
+ return Empty
+ return data_dto(connection).decode_bytes(body)
+
+ return dto_extractor # type:ignore[return-value]