diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/typing.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/typing.py | 636 |
1 files changed, 636 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/typing.py b/venv/lib/python3.11/site-packages/litestar/typing.py new file mode 100644 index 0000000..3a27557 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/typing.py @@ -0,0 +1,636 @@ +from __future__ import annotations + +from collections import abc, deque +from copy import deepcopy +from dataclasses import dataclass, is_dataclass, replace +from inspect import Parameter, Signature +from typing import ( + Any, + AnyStr, + Callable, + Collection, + ForwardRef, + Literal, + Mapping, + Protocol, + Sequence, + TypeVar, + cast, +) + +from msgspec import UnsetType +from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict + +from litestar.exceptions import ImproperlyConfiguredException +from litestar.openapi.spec import Example +from litestar.params import BodyKwarg, DependencyKwarg, KwargDefinition, ParameterKwarg +from litestar.types import Empty +from litestar.types.builtin_types import NoneType, UnionTypes +from litestar.utils.predicates import ( + is_annotated_type, + is_any, + is_class_and_subclass, + is_generic, + is_non_string_iterable, + is_non_string_sequence, + is_union, +) +from litestar.utils.typing import ( + get_instantiable_origin, + get_safe_generic_origin, + get_type_hints_with_generics_resolved, + make_non_optional_union, + unwrap_annotation, +) + +__all__ = ("FieldDefinition",) + +T = TypeVar("T", bound=KwargDefinition) + + +class _KwargMetaExtractor(Protocol): + @staticmethod + def matches(annotation: Any, name: str | None, default: Any) -> bool: ... + + @staticmethod + def extract(annotation: Any, default: Any) -> Any: ... + + +_KWARG_META_EXTRACTORS: set[_KwargMetaExtractor] = set() + + +def _unpack_predicate(value: Any) -> dict[str, Any]: + try: + from annotated_types import Predicate + + if isinstance(value, Predicate): + if value.func == str.islower: + return {"lower_case": True} + if value.func == str.isupper: + return {"upper_case": True} + if value.func == str.isascii: + return {"pattern": "[[:ascii:]]"} + if value.func == str.isdigit: + return {"pattern": "[[:digit:]]"} + except ImportError: + pass + + return {} + + +def _parse_metadata(value: Any, is_sequence_container: bool, extra: dict[str, Any] | None) -> dict[str, Any]: + """Parse metadata from a value. + + Args: + value: A metadata value from annotation, namely anything stored under Annotated[x, metadata...] + is_sequence_container: Whether the type is a sequence container (list, tuple etc...) + extra: Extra key values to parse. + + Returns: + A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. + """ + extra = { + **cast("dict[str, Any]", extra or getattr(value, "extra", None) or {}), + **(getattr(value, "json_schema_extra", None) or {}), + } + example_list: list[Any] | None + if example := extra.pop("example", None): + example_list = [Example(value=example)] + elif examples := getattr(value, "examples", None): + example_list = [Example(value=example) for example in cast("list[str]", examples)] + else: + example_list = None + + return { + k: v + for k, v in { + "gt": getattr(value, "gt", None), + "ge": getattr(value, "ge", None), + "lt": getattr(value, "lt", None), + "le": getattr(value, "le", None), + "multiple_of": getattr(value, "multiple_of", None), + "min_length": None if is_sequence_container else getattr(value, "min_length", None), + "max_length": None if is_sequence_container else getattr(value, "max_length", None), + "description": getattr(value, "description", None), + "examples": example_list, + "title": getattr(value, "title", None), + "lower_case": getattr(value, "to_lower", None), + "upper_case": getattr(value, "to_upper", None), + "pattern": getattr(value, "regex", getattr(value, "pattern", None)), + "min_items": getattr(value, "min_items", getattr(value, "min_length", None)) + if is_sequence_container + else None, + "max_items": getattr(value, "max_items", getattr(value, "max_length", None)) + if is_sequence_container + else None, + "const": getattr(value, "const", None) is not None, + **extra, + }.items() + if v is not None + } + + +def _traverse_metadata( + metadata: Sequence[Any], is_sequence_container: bool, extra: dict[str, Any] | None +) -> dict[str, Any]: + """Recursively traverse metadata from a value. + + Args: + metadata: A list of metadata values from annotation, namely anything stored under Annotated[x, metadata...] + is_sequence_container: Whether the container is a sequence container (list, tuple etc...) + extra: Extra key values to parse. + + Returns: + A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. + """ + constraints: dict[str, Any] = {} + for value in metadata: + if isinstance(value, (list, set, frozenset, deque)): + constraints.update( + _traverse_metadata( + metadata=cast("Sequence[Any]", value), is_sequence_container=is_sequence_container, extra=extra + ) + ) + elif is_annotated_type(value) and (type_args := [v for v in get_args(value) if v is not None]): + # annotated values can be nested inside other annotated values + # this behaviour is buggy in python 3.8, hence we need to guard here. + if len(type_args) > 1: + constraints.update( + _traverse_metadata(metadata=type_args[1:], is_sequence_container=is_sequence_container, extra=extra) + ) + elif unpacked_predicate := _unpack_predicate(value): + constraints.update(unpacked_predicate) + else: + constraints.update(_parse_metadata(value=value, is_sequence_container=is_sequence_container, extra=extra)) + return constraints + + +def _create_metadata_from_type( + metadata: Sequence[Any], model: type[T], annotation: Any, extra: dict[str, Any] | None +) -> tuple[T | None, dict[str, Any]]: + is_sequence_container = is_non_string_sequence(annotation) + result = _traverse_metadata(metadata=metadata, is_sequence_container=is_sequence_container, extra=extra) + + constraints = {k: v for k, v in result.items() if k in dir(model)} + extra = {k: v for k, v in result.items() if k not in constraints} + return model(**constraints) if constraints else None, extra + + +@dataclass(frozen=True) +class FieldDefinition: + """Represents a function parameter or type annotation.""" + + __slots__ = ( + "annotation", + "args", + "default", + "extra", + "inner_types", + "instantiable_origin", + "kwarg_definition", + "metadata", + "name", + "origin", + "raw", + "safe_generic_origin", + "type_wrappers", + ) + + raw: Any + """The annotation exactly as received.""" + annotation: Any + """The annotation with any "wrapper" types removed, e.g. Annotated.""" + type_wrappers: tuple[type, ...] + """A set of all "wrapper" types, e.g. Annotated.""" + origin: Any + """The result of calling ``get_origin(annotation)`` after unwrapping Annotated, e.g. list.""" + args: tuple[Any, ...] + """The result of calling ``get_args(annotation)`` after unwrapping Annotated, e.g. (int,).""" + metadata: tuple[Any, ...] + """Any metadata associated with the annotation via ``Annotated``.""" + instantiable_origin: Any + """An equivalent type to ``origin`` that can be safely instantiated. E.g., ``Sequence`` -> ``list``.""" + safe_generic_origin: Any + """An equivalent type to ``origin`` that can be safely used as a generic type across all supported Python versions. + + This is to serve safely rebuilding a generic outer type with different args at runtime. + """ + inner_types: tuple[FieldDefinition, ...] + """The type's generic args parsed as ``FieldDefinition``, if applicable.""" + default: Any + """Default value of the field.""" + extra: dict[str, Any] + """A mapping of extra values.""" + kwarg_definition: KwargDefinition | DependencyKwarg | None + """Kwarg Parameter.""" + name: str + """Field name.""" + + def __deepcopy__(self, memo: dict[str, Any]) -> Self: + return type(self)(**{attr: deepcopy(getattr(self, attr)) for attr in self.__slots__}) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FieldDefinition): + return False + + if self.origin: + return self.origin == other.origin and self.inner_types == other.inner_types + + return self.annotation == other.annotation # type: ignore[no-any-return] + + def __hash__(self) -> int: + return hash((self.name, self.raw, self.annotation, self.origin, self.inner_types)) + + @classmethod + def _extract_metadata( + cls, annotation: Any, name: str | None, default: Any, metadata: tuple[Any, ...], extra: dict[str, Any] | None + ) -> tuple[KwargDefinition | None, dict[str, Any]]: + model = BodyKwarg if name == "data" else ParameterKwarg + + for extractor in _KWARG_META_EXTRACTORS: + if extractor.matches(annotation=annotation, name=name, default=default): + return _create_metadata_from_type( + extractor.extract(annotation=annotation, default=default), + model=model, + annotation=annotation, + extra=extra, + ) + + if any(isinstance(arg, KwargDefinition) for arg in get_args(annotation)): + return next(arg for arg in get_args(annotation) if isinstance(arg, KwargDefinition)), extra or {} + + if metadata: + return _create_metadata_from_type(metadata=metadata, model=model, annotation=annotation, extra=extra) + + return None, {} + + @property + def has_default(self) -> bool: + """Check if the field has a default value. + + Returns: + True if the default is not Empty or Ellipsis otherwise False. + """ + return self.default is not Empty and self.default is not Ellipsis + + @property + def is_non_string_iterable(self) -> bool: + """Check if the field type is an Iterable. + + If ``self.annotation`` is an optional union, only the non-optional members of the union are evaluated. + + See: https://github.com/litestar-org/litestar/issues/1106 + """ + annotation = self.annotation + if self.is_optional: + annotation = make_non_optional_union(annotation) + return is_non_string_iterable(annotation) + + @property + def is_non_string_sequence(self) -> bool: + """Check if the field type is a non-string Sequence. + + If ``self.annotation`` is an optional union, only the non-optional members of the union are evaluated. + + See: https://github.com/litestar-org/litestar/issues/1106 + """ + annotation = self.annotation + if self.is_optional: + annotation = make_non_optional_union(annotation) + return is_non_string_sequence(annotation) + + @property + def is_any(self) -> bool: + """Check if the field type is Any.""" + return is_any(self.annotation) + + @property + def is_generic(self) -> bool: + """Check if the field type is a custom class extending Generic.""" + return is_generic(self.annotation) + + @property + def is_simple_type(self) -> bool: + """Check if the field type is a singleton value (e.g. int, str etc.).""" + return not ( + self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable + ) + + @property + def is_parameter_field(self) -> bool: + """Check if the field type is a parameter kwarg value.""" + return isinstance(self.kwarg_definition, ParameterKwarg) + + @property + def is_const(self) -> bool: + """Check if the field is defined as constant value.""" + return bool(self.kwarg_definition and getattr(self.kwarg_definition, "const", False)) + + @property + def is_required(self) -> bool: + """Check if the field should be marked as a required parameter.""" + if Required in self.type_wrappers: # type: ignore[comparison-overlap] + return True + + if NotRequired in self.type_wrappers or UnsetType in self.args: # type: ignore[comparison-overlap] + return False + + if isinstance(self.kwarg_definition, ParameterKwarg) and self.kwarg_definition.required is not None: + return self.kwarg_definition.required + + return not self.is_optional and not self.is_any and (not self.has_default or self.default is None) + + @property + def is_annotated(self) -> bool: + """Check if the field type is Annotated.""" + return bool(self.metadata) + + @property + def is_literal(self) -> bool: + """Check if the field type is Literal.""" + return self.origin is Literal + + @property + def is_forward_ref(self) -> bool: + """Whether the annotation is a forward reference or not.""" + return isinstance(self.annotation, (str, ForwardRef)) + + @property + def is_mapping(self) -> bool: + """Whether the annotation is a mapping or not.""" + return self.is_subclass_of(Mapping) + + @property + def is_tuple(self) -> bool: + """Whether the annotation is a ``tuple`` or not.""" + return self.is_subclass_of(tuple) + + @property + def is_type_var(self) -> bool: + """Whether the annotation is a TypeVar or not.""" + return isinstance(self.annotation, TypeVar) + + @property + def is_union(self) -> bool: + """Whether the annotation is a union type or not.""" + return self.origin in UnionTypes + + @property + def is_optional(self) -> bool: + """Whether the annotation is Optional or not.""" + return bool(self.is_union and NoneType in self.args) + + @property + def is_none_type(self) -> bool: + """Whether the annotation is NoneType or not.""" + return self.annotation is NoneType + + @property + def is_collection(self) -> bool: + """Whether the annotation is a collection type or not.""" + return self.is_subclass_of(Collection) + + @property + def is_non_string_collection(self) -> bool: + """Whether the annotation is a non-string collection type or not.""" + return self.is_collection and not self.is_subclass_of((str, bytes)) + + @property + def bound_types(self) -> tuple[FieldDefinition, ...] | None: + """A tuple of bound types - if the annotation is a TypeVar with bound types, otherwise None.""" + if self.is_type_var and (bound := getattr(self.annotation, "__bound__", None)): + if is_union(bound): + return tuple(FieldDefinition.from_annotation(t) for t in get_args(bound)) + return (FieldDefinition.from_annotation(bound),) + return None + + @property + def generic_types(self) -> tuple[FieldDefinition, ...] | None: + """A tuple of generic types passed into the annotation - if its generic.""" + if not (bases := getattr(self.annotation, "__orig_bases__", None)): + return None + args: list[FieldDefinition] = [] + for base_args in [getattr(base, "__args__", ()) for base in bases]: + for arg in base_args: + field_definition = FieldDefinition.from_annotation(arg) + if field_definition.generic_types: + args.extend(field_definition.generic_types) + else: + args.append(field_definition) + return tuple(args) + + @property + def is_dataclass_type(self) -> bool: + """Whether the annotation is a dataclass type or not.""" + + return is_dataclass(cast("type", self.origin or self.annotation)) + + @property + def is_typeddict_type(self) -> bool: + """Whether the type is TypedDict or not.""" + + return is_typeddict(self.origin or self.annotation) + + @property + def type_(self) -> Any: + """The type of the annotation with all the wrappers removed, including the generic types.""" + + return self.origin or self.annotation + + def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: + """Whether the annotation is a subclass of the given type. + + Where ``self.annotation`` is a union type, this method will return ``True`` when all members of the union are + a subtype of ``cl``, otherwise, ``False``. + + Args: + cl: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. + + Returns: + Whether the annotation is a subtype of the given type(s). + """ + if self.origin: + if self.origin in UnionTypes: + return all(t.is_subclass_of(cl) for t in self.inner_types) + + return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl) + + if self.annotation is AnyStr: + return is_class_and_subclass(str, cl) or is_class_and_subclass(bytes, cl) + + return self.annotation is not Any and not self.is_type_var and is_class_and_subclass(self.annotation, cl) + + def has_inner_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: + """Whether any generic args are a subclass of the given type. + + Args: + cl: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. + + Returns: + Whether any of the type's generic args are a subclass of the given type. + """ + return any(t.is_subclass_of(cl) for t in self.inner_types) + + def get_type_hints(self, *, include_extras: bool = False, resolve_generics: bool = False) -> dict[str, Any]: + """Get the type hints for the annotation. + + Args: + include_extras: Flag to indicate whether to include ``Annotated[T, ...]`` or not. + resolve_generics: Flag to indicate whether to resolve the generic types in the type hints or not. + + Returns: + The type hints. + """ + + if self.origin is not None or self.is_generic: + if resolve_generics: + return get_type_hints_with_generics_resolved(self.annotation, include_extras=include_extras) + return get_type_hints(self.origin or self.annotation, include_extras=include_extras) + + return get_type_hints(self.annotation, include_extras=include_extras) + + @classmethod + def from_annotation(cls, annotation: Any, **kwargs: Any) -> FieldDefinition: + """Initialize FieldDefinition. + + Args: + annotation: The type annotation. This should be extracted from the return of + ``get_type_hints(..., include_extras=True)`` so that forward references are resolved and recursive + ``Annotated`` types are flattened. + **kwargs: Additional keyword arguments to pass to the ``FieldDefinition`` constructor. + + Returns: + FieldDefinition + """ + + unwrapped, metadata, wrappers = unwrap_annotation(annotation if annotation is not Empty else Any) + origin = get_origin(unwrapped) + + args = () if origin is abc.Callable else get_args(unwrapped) + + if not kwargs.get("kwarg_definition"): + if isinstance(kwargs.get("default"), (KwargDefinition, DependencyKwarg)): + kwargs["kwarg_definition"] = kwargs.pop("default") + elif any(isinstance(v, (KwargDefinition, DependencyKwarg)) for v in metadata): + kwargs["kwarg_definition"] = next( # pragma: no cover + # see https://github.com/nedbat/coveragepy/issues/475 + v + for v in metadata + if isinstance(v, (KwargDefinition, DependencyKwarg)) + ) + metadata = tuple(v for v in metadata if not isinstance(v, (KwargDefinition, DependencyKwarg))) + elif (extra := kwargs.get("extra", {})) and "kwarg_definition" in extra: + kwargs["kwarg_definition"] = extra.pop("kwarg_definition") + else: + kwargs["kwarg_definition"], kwargs["extra"] = cls._extract_metadata( + annotation=annotation, + name=kwargs.get("name", ""), + default=kwargs.get("default", Empty), + metadata=metadata, + extra=kwargs.get("extra"), + ) + + kwargs.setdefault("annotation", unwrapped) + kwargs.setdefault("args", args) + kwargs.setdefault("default", Empty) + kwargs.setdefault("extra", {}) + kwargs.setdefault("inner_types", tuple(FieldDefinition.from_annotation(arg) for arg in args)) + kwargs.setdefault("instantiable_origin", get_instantiable_origin(origin, unwrapped)) + kwargs.setdefault("kwarg_definition", None) + kwargs.setdefault("metadata", metadata) + kwargs.setdefault("name", "") + kwargs.setdefault("origin", origin) + kwargs.setdefault("raw", annotation) + kwargs.setdefault("safe_generic_origin", get_safe_generic_origin(origin, unwrapped)) + kwargs.setdefault("type_wrappers", wrappers) + + instance = FieldDefinition(**kwargs) + if not instance.has_default and instance.kwarg_definition: + return replace(instance, default=instance.kwarg_definition.default) + + return instance + + @classmethod + def from_kwarg( + cls, + annotation: Any, + name: str, + default: Any = Empty, + inner_types: tuple[FieldDefinition, ...] | None = None, + kwarg_definition: KwargDefinition | DependencyKwarg | None = None, + extra: dict[str, Any] | None = None, + ) -> FieldDefinition: + """Create a new FieldDefinition instance. + + Args: + annotation: The type of the kwarg. + name: Field name. + default: A default value. + inner_types: A tuple of FieldDefinition instances representing the inner types, if any. + kwarg_definition: Kwarg Parameter. + extra: A mapping of extra values. + + Returns: + FieldDefinition instance. + """ + + return cls.from_annotation( + annotation, + name=name, + default=default, + **{ + k: v + for k, v in { + "inner_types": inner_types, + "kwarg_definition": kwarg_definition, + "extra": extra, + }.items() + if v is not None + }, + ) + + @classmethod + def from_parameter(cls, parameter: Parameter, fn_type_hints: dict[str, Any]) -> FieldDefinition: + """Initialize ParsedSignatureParameter. + + Args: + parameter: inspect.Parameter + fn_type_hints: mapping of names to types. Should be result of ``get_type_hints()``, preferably via the + :attr:``get_fn_type_hints() <.utils.signature_parsing.get_fn_type_hints>`` helper. + + Returns: + ParsedSignatureParameter. + + """ + from litestar.datastructures import ImmutableState + + try: + annotation = fn_type_hints[parameter.name] + except KeyError as e: + raise ImproperlyConfiguredException( + f"'{parameter.name}' does not have a type annotation. If it should receive any value, use 'Any'." + ) from e + + if parameter.name == "state" and not issubclass(annotation, ImmutableState): + raise ImproperlyConfiguredException( + f"The type annotation `{annotation}` is an invalid type for the 'state' reserved kwarg. " + "It must be typed to a subclass of `litestar.datastructures.ImmutableState` or " + "`litestar.datastructures.State`." + ) + + return FieldDefinition.from_kwarg( + annotation=annotation, + name=parameter.name, + default=Empty if parameter.default is Signature.empty else parameter.default, + ) + + def match_predicate_recursively(self, predicate: Callable[[FieldDefinition], bool]) -> bool: + """Recursively test the passed in predicate against the field and any of its inner fields. + + Args: + predicate: A callable that receives a field definition instance as an arg and returns a boolean. + + Returns: + A boolean. + """ + return predicate(self) or any(t.match_predicate_recursively(predicate) for t in self.inner_types) |