diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py | 492 |
1 files changed, 492 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py new file mode 100644 index 0000000..e3b347e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_kwargs/extractors.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache, partial +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast + +from litestar._multipart import parse_multipart_form +from litestar._parsers import ( + parse_query_string, + parse_url_encoded_form_data, +) +from litestar.datastructures import Headers +from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.url import URL +from litestar.enums import ParamType, RequestEncodingType +from litestar.exceptions import ValidationException +from litestar.params import BodyKwarg +from litestar.types import Empty +from litestar.utils.predicates import is_non_string_sequence +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar._kwargs import KwargsModel + from litestar._kwargs.parameter_definition import ParameterDefinition + from litestar.connection import ASGIConnection, Request + from litestar.dto import AbstractDTO + from litestar.typing import FieldDefinition + + +__all__ = ( + "body_extractor", + "cookies_extractor", + "create_connection_value_extractor", + "create_data_extractor", + "create_multipart_extractor", + "create_query_default_dict", + "create_url_encoded_data_extractor", + "headers_extractor", + "json_extractor", + "msgpack_extractor", + "parse_connection_headers", + "parse_connection_query_params", + "query_extractor", + "request_extractor", + "scope_extractor", + "socket_extractor", + "state_extractor", +) + + +class ParamMappings(NamedTuple): + alias_and_key_tuples: list[tuple[str, str]] + alias_defaults: dict[str, Any] + alias_to_param: dict[str, ParameterDefinition] + + +def _create_param_mappings(expected_params: set[ParameterDefinition]) -> ParamMappings: + alias_and_key_tuples = [] + alias_defaults = {} + alias_to_params: dict[str, ParameterDefinition] = {} + for param in expected_params: + alias = param.field_alias + if param.param_type == ParamType.HEADER: + alias = alias.lower() + + alias_and_key_tuples.append((alias, param.field_name)) + + if not (param.is_required or param.default is Ellipsis): + alias_defaults[alias] = param.default + + alias_to_params[alias] = param + + return ParamMappings( + alias_and_key_tuples=alias_and_key_tuples, + alias_defaults=alias_defaults, + alias_to_param=alias_to_params, + ) + + +def create_connection_value_extractor( + kwargs_model: KwargsModel, + connection_key: str, + expected_params: set[ParameterDefinition], + parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None, +) -> Callable[[dict[str, Any], ASGIConnection], None]: + """Create a kwargs extractor function. + + Args: + kwargs_model: The KwargsModel instance. + connection_key: The attribute key to use. + expected_params: The set of expected params. + parser: An optional parser function. + + Returns: + An extractor function. + """ + + alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params) + + def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {}) + + try: + connection_mapping: dict[str, Any] = { + key: data[alias] if alias in data else alias_defaults[alias] for alias, key in alias_and_key_tuples + } + values.update(connection_mapping) + except KeyError as e: + param = alias_to_params[e.args[0]] + path = URL.from_components( + path=connection.url.path, + query=connection.url.query, + ) + raise ValidationException( + f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}" + ) from e + + return extractor + + +@lru_cache(1024) +def create_query_default_dict( + parsed_query: tuple[tuple[str, str], ...], sequence_query_parameter_names: tuple[str, ...] +) -> defaultdict[str, list[str] | str]: + """Transform a list of tuples into a default dict. Ensures non-list values are not wrapped in a list. + + Args: + parsed_query: The parsed query list of tuples. + sequence_query_parameter_names: A set of query parameters that should be wrapped in list. + + Returns: + A default dict + """ + output: defaultdict[str, list[str] | str] = defaultdict(list) + + for k, v in parsed_query: + if k in sequence_query_parameter_names: + output[k].append(v) # type: ignore[union-attr] + else: + output[k] = v + + return output + + +def parse_connection_query_params(connection: ASGIConnection, kwargs_model: KwargsModel) -> dict[str, Any]: + """Parse query params and cache the result in scope. + + Args: + connection: The ASGI connection instance. + kwargs_model: The KwargsModel instance. + + Returns: + A dictionary of parsed values. + """ + parsed_query = ( + connection._parsed_query + if connection._parsed_query is not Empty + else parse_query_string(connection.scope.get("query_string", b"")) + ) + ScopeState.from_scope(connection.scope).parsed_query = parsed_query + return create_query_default_dict( + parsed_query=parsed_query, + sequence_query_parameter_names=kwargs_model.sequence_query_parameter_names, + ) + + +def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers: + """Parse header parameters and cache the result in scope. + + Args: + connection: The ASGI connection instance. + _: The KwargsModel instance. + + Returns: + A Headers instance + """ + return Headers.from_scope(connection.scope) + + +def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the app state from the connection and insert it to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["state"] = connection.app.state._state + + +def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the headers from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + # TODO: This should be removed in 3.0 and instead Headers should be injected + # directly. We are only keeping this one around to not break things + values["headers"] = dict(connection.headers.items()) + + +def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the cookies from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["cookies"] = connection.cookies + + +def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the query params from the connection and insert them to the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["query"] = connection.query_params + + +def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Extract the scope from the connection and insert it into the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["scope"] = connection.scope + + +def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Set the connection instance as the 'request' value in the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["request"] = connection + + +def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + """Set the connection instance as the 'socket' value in the kwargs injected to the handler. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + None + """ + values["socket"] = connection + + +def body_extractor( + values: dict[str, Any], + connection: Request[Any, Any, Any], +) -> None: + """Extract the body from the request instance. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + values: The kwargs that are extracted from the connection and will be injected into the handler. + + Returns: + The Body value. + """ + values["body"] = connection.body() + + +async def json_extractor(connection: Request[Any, Any, Any]) -> Any: + """Extract the data from request and insert it into the kwargs injected to the handler. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + + Returns: + The JSON value. + """ + if not await connection.body(): + return Empty + return await connection.json() + + +async def msgpack_extractor(connection: Request[Any, Any, Any]) -> Any: + """Extract the data from request and insert it into the kwargs injected to the handler. + + Notes: + - this extractor sets a Coroutine as the value in the kwargs. These are resolved at a later stage. + + Args: + connection: The ASGI connection instance. + + Returns: + The MessagePack value. + """ + if not await connection.body(): + return Empty + return await connection.msgpack() + + +async def _extract_multipart( + connection: Request[Any, Any, Any], + body_kwarg_multipart_form_part_limit: int | None, + field_definition: FieldDefinition, + is_data_optional: bool, + data_dto: type[AbstractDTO] | None, +) -> Any: + multipart_form_part_limit = ( + body_kwarg_multipart_form_part_limit + if body_kwarg_multipart_form_part_limit is not None + else connection.app.multipart_form_part_limit + ) + connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] + connection.scope["_form"] # type: ignore[typeddict-item] + if "_form" in connection.scope + else parse_multipart_form( + body=await connection.body(), + boundary=connection.content_type[-1].get("boundary", "").encode(), + multipart_form_part_limit=multipart_form_part_limit, + type_decoders=connection.route_handler.resolve_type_decoders(), + ) + ) + + if field_definition.is_non_string_sequence: + values = list(form_values.values()) + if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list): + return values[0] + + return values + + if field_definition.is_simple_type and field_definition.annotation is UploadFile and form_values: + return next(v for v in form_values.values() if isinstance(v, UploadFile)) + + if not form_values and is_data_optional: + return None + + if data_dto: + return data_dto(connection).decode_builtins(form_values) + + for name, tp in field_definition.get_type_hints().items(): + value = form_values.get(name) + if value is not None and is_non_string_sequence(tp) and not isinstance(value, list): + form_values[name] = [value] + + return form_values + + +def create_multipart_extractor( + field_definition: FieldDefinition, is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create a multipart form-data extractor. + + Args: + field_definition: A FieldDefinition instance. + is_data_optional: Boolean dictating whether the field is optional. + data_dto: A data DTO type, if configured for handler. + + Returns: + An extractor function. + """ + body_kwarg_multipart_form_part_limit: int | None = None + if field_definition.kwarg_definition and isinstance(field_definition.kwarg_definition, BodyKwarg): + body_kwarg_multipart_form_part_limit = field_definition.kwarg_definition.multipart_form_part_limit + + extract_multipart = partial( + _extract_multipart, + body_kwarg_multipart_form_part_limit=body_kwarg_multipart_form_part_limit, + is_data_optional=is_data_optional, + data_dto=data_dto, + field_definition=field_definition, + ) + + return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart) + + +def create_url_encoded_data_extractor( + is_data_optional: bool, data_dto: type[AbstractDTO] | None +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create extractor for url encoded form-data. + + Args: + is_data_optional: Boolean dictating whether the field is optional. + data_dto: A data DTO type, if configured for handler. + + Returns: + An extractor function. + """ + + async def extract_url_encoded_extractor( + connection: Request[Any, Any, Any], + ) -> Any: + connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] + connection.scope["_form"] # type: ignore[typeddict-item] + if "_form" in connection.scope + else parse_url_encoded_form_data(await connection.body()) + ) + + if not form_values and is_data_optional: + return None + + return data_dto(connection).decode_builtins(form_values) if data_dto else form_values + + return cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_url_encoded_extractor + ) + + +def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]: + """Create an extractor for a request's body. + + Args: + kwargs_model: The KwargsModel instance. + + Returns: + An extractor for the request's body. + """ + + if kwargs_model.expected_form_data: + media_type, field_definition = kwargs_model.expected_form_data + + if media_type == RequestEncodingType.MULTI_PART: + data_extractor = create_multipart_extractor( + field_definition=field_definition, + is_data_optional=kwargs_model.is_data_optional, + data_dto=kwargs_model.expected_data_dto, + ) + else: + data_extractor = create_url_encoded_data_extractor( + is_data_optional=kwargs_model.is_data_optional, + data_dto=kwargs_model.expected_data_dto, + ) + elif kwargs_model.expected_msgpack_data: + data_extractor = cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", msgpack_extractor + ) + elif kwargs_model.expected_data_dto: + data_extractor = create_dto_extractor(data_dto=kwargs_model.expected_data_dto) + else: + data_extractor = cast( + "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor + ) + + def extractor( + values: dict[str, Any], + connection: ASGIConnection[Any, Any, Any, Any], + ) -> None: + values["data"] = data_extractor(connection) + + return extractor + + +def create_dto_extractor( + data_dto: type[AbstractDTO], +) -> Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]: + """Create a DTO data extractor. + + + Returns: + An extractor function. + """ + + async def dto_extractor(connection: Request[Any, Any, Any]) -> Any: + if not (body := await connection.body()): + return Empty + return data_dto(connection).decode_bytes(body) + + return dto_extractor # type:ignore[return-value] |