summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/utils/signature.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/utils/signature.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/utils/signature.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/utils/signature.py267
1 files changed, 267 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/utils/signature.py b/venv/lib/python3.11/site-packages/litestar/utils/signature.py
new file mode 100644
index 0000000..eb58599
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/utils/signature.py
@@ -0,0 +1,267 @@
+from __future__ import annotations
+
+import sys
+import typing
+from copy import deepcopy
+from dataclasses import dataclass, replace
+from inspect import Signature, getmembers, isclass, ismethod
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Union
+
+from typing_extensions import Annotated, Self, get_args, get_origin, get_type_hints
+
+from litestar import connection, datastructures, types
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.types import Empty
+from litestar.typing import FieldDefinition
+from litestar.utils.typing import unwrap_annotation
+
+if TYPE_CHECKING:
+ from typing import Sequence
+
+ from litestar.types import AnyCallable
+
+if sys.version_info < (3, 11):
+ from typing import _get_defaults # type: ignore[attr-defined]
+else:
+
+ def _get_defaults(_: Any) -> Any: ...
+
+
+__all__ = (
+ "add_types_to_signature_namespace",
+ "get_fn_type_hints",
+ "ParsedSignature",
+)
+
+_GLOBAL_NAMES = {
+ namespace: export
+ for namespace, export in chain(
+ tuple(getmembers(types)), tuple(getmembers(connection)), tuple(getmembers(datastructures))
+ )
+ if namespace[0].isupper() and namespace in chain(types.__all__, connection.__all__, datastructures.__all__) # pyright: ignore
+}
+"""A mapping of names used for handler signature forward-ref resolution.
+
+This allows users to include these names within an `if TYPE_CHECKING:` block in their handler module.
+"""
+
+
+def _unwrap_implicit_optional_hints(defaults: dict[str, Any], hints: dict[str, Any]) -> dict[str, Any]:
+ """Unwrap implicit optional hints.
+
+ On Python<3.11, if a function parameter annotation has a ``None`` default, it is unconditionally wrapped in an
+ ``Optional`` type.
+
+ If the annotation is not annotated, then any nested unions are flattened, e.g.,:
+
+ .. code-block:: python
+
+ def foo(a: Optional[Union[str, int]] = None): ...
+
+ ...will become `Union[str, int, NoneType]`.
+
+ However, if the annotation is annotated, then we end up with an optional union around the annotated type, e.g.,:
+
+ .. code-block:: python
+
+ def foo(a: Annotated[Optional[Union[str, int]], ...] = None): ...
+
+ ... becomes `Union[Annotated[Union[str, int, NoneType], ...], NoneType]`
+
+ This function makes the latter case consistent with the former by either removing the outer union if it is redundant
+ or flattening the union if it is not. The latter case would become `Annotated[Union[str, int, NoneType], ...]`.
+
+ Args:
+ defaults: Mapping of names to default values.
+ hints: Mapping of names to types.
+
+ Returns:
+ Mapping of names to types.
+ """
+
+ def _is_two_arg_optional(origin_: Any, args_: Any) -> bool:
+ """Check if a type is a two-argument optional type.
+
+ If the type has been wrapped in `Optional` by `get_type_hints()` it will always be a union of a type and
+ `NoneType`.
+
+ See: https://github.com/litestar-org/litestar/pull/2516
+ """
+ return origin_ is Union and len(args_) == 2 and args_[1] is type(None)
+
+ def _is_any_optional(origin_: Any, args_: tuple[Any, ...]) -> bool:
+ """Detect if a type is a union with `NoneType`.
+
+ After detecting that a type is a two-argument optional type, this function can be used to detect if the
+ inner type is a union with `NoneType` at all.
+
+ We only want to perform the unwrapping of the optional union if the inner type is optional as well.
+ """
+ return origin_ is Union and any(arg is type(None) for arg in args_)
+
+ for name, default in defaults.items():
+ if default is not None:
+ continue
+
+ hint = hints[name]
+ origin = get_origin(hint)
+ args = get_args(hint)
+
+ if _is_two_arg_optional(origin, args):
+ unwrapped_inner, meta, wrappers = unwrap_annotation(args[0])
+
+ if Annotated not in wrappers:
+ continue
+
+ inner_args = get_args(unwrapped_inner)
+
+ if not _is_any_optional(get_origin(unwrapped_inner), inner_args):
+ # this is where hint is like `Union[Annotated[Union[str, int], ...], NoneType]`, we add the outer union
+ # into the inner one, and re-wrap with Annotated
+ union_args = (*(inner_args or (unwrapped_inner,)), type(None))
+ # calling `__class_getitem__` directly as in earlier py vers it is a syntax error to unpack into
+ # the getitem brackets, e.g., Annotated[T, *meta].
+ hints[name] = Annotated.__class_getitem__((Union[union_args], *meta)) # type: ignore[attr-defined]
+ continue
+
+ # this is where hint is like `Union[Annotated[Union[str, NoneType], ...], NoneType]`, we remove the
+ # redundant outer union
+ hints[name] = args[0]
+ return hints
+
+
+def get_fn_type_hints(fn: Any, namespace: dict[str, Any] | None = None) -> dict[str, Any]:
+ """Resolve type hints for ``fn``.
+
+ Args:
+ fn: Callable that is being inspected
+ namespace: Extra names for resolution of forward references.
+
+ Returns:
+ Mapping of names to types.
+ """
+ fn_to_inspect: Any = fn
+
+ module_name = fn_to_inspect.__module__
+
+ if isclass(fn_to_inspect):
+ fn_to_inspect = fn_to_inspect.__init__
+
+ # detect objects that are not functions and that have a `__call__` method
+ if callable(fn_to_inspect) and ismethod(fn_to_inspect.__call__):
+ fn_to_inspect = fn_to_inspect.__call__
+
+ # inspect the underlying function for methods
+ if hasattr(fn_to_inspect, "__func__"):
+ fn_to_inspect = fn_to_inspect.__func__
+
+ # Order important. If a litestar name has been overridden in the function module, we want
+ # to use that instead of the litestar one.
+ namespace = {
+ **_GLOBAL_NAMES,
+ **vars(typing),
+ **vars(sys.modules[module_name]),
+ **(namespace or {}),
+ }
+ hints = get_type_hints(fn_to_inspect, globalns=namespace, include_extras=True)
+
+ if sys.version_info < (3, 11):
+ # see https://github.com/litestar-org/litestar/pull/2516
+ defaults = _get_defaults(fn_to_inspect)
+ hints = _unwrap_implicit_optional_hints(defaults, hints)
+
+ return hints
+
+
+@dataclass(frozen=True)
+class ParsedSignature:
+ """Parsed signature.
+
+ This object is the primary source of handler/dependency signature information.
+
+ The only post-processing that occurs is the conversion of any forward referenced type annotations.
+ """
+
+ __slots__ = ("parameters", "return_type", "original_signature")
+
+ parameters: dict[str, FieldDefinition]
+ """A mapping of parameter names to ParsedSignatureParameter instances."""
+ return_type: FieldDefinition
+ """The return annotation of the callable."""
+ original_signature: Signature
+ """The raw signature as returned by :func:`inspect.signature`"""
+
+ def __deepcopy__(self, memo: dict[str, Any]) -> Self:
+ return type(self)(
+ parameters={k: deepcopy(v) for k, v in self.parameters.items()},
+ return_type=deepcopy(self.return_type),
+ original_signature=deepcopy(self.original_signature),
+ )
+
+ @classmethod
+ def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self:
+ """Parse a function signature.
+
+ Args:
+ fn: Any callable.
+ signature_namespace: mapping of names to types for forward reference resolution
+
+ Returns:
+ ParsedSignature
+ """
+ signature = Signature.from_callable(fn)
+ fn_type_hints = get_fn_type_hints(fn, namespace=signature_namespace)
+
+ return cls.from_signature(signature, fn_type_hints)
+
+ @classmethod
+ def from_signature(cls, signature: Signature, fn_type_hints: dict[str, type]) -> Self:
+ """Parse an :class:`inspect.Signature` instance.
+
+ Args:
+ signature: An :class:`inspect.Signature` instance.
+ fn_type_hints: mapping of types
+
+ Returns:
+ ParsedSignature
+ """
+
+ parameters = tuple(
+ FieldDefinition.from_parameter(parameter=parameter, fn_type_hints=fn_type_hints)
+ for name, parameter in signature.parameters.items()
+ if name not in ("self", "cls")
+ )
+
+ return_type = FieldDefinition.from_annotation(fn_type_hints.get("return", Any))
+
+ return cls(
+ parameters={p.name: p for p in parameters},
+ return_type=return_type if "return" in fn_type_hints else replace(return_type, annotation=Empty),
+ original_signature=signature,
+ )
+
+
+def add_types_to_signature_namespace(
+ signature_types: Sequence[Any], signature_namespace: dict[str, Any]
+) -> dict[str, Any]:
+ """Add types to ith signature namespace mapping.
+
+ Types are added mapped to their `__name__` attribute.
+
+ Args:
+ signature_types: A list of types to add to the signature namespace.
+ signature_namespace: The signature namespace to add types to.
+
+ Raises:
+ ImproperlyConfiguredException: If a type is already defined in the signature namespace.
+ AttributeError: If a type does not have a `__name__` attribute.
+
+ Returns:
+ The updated signature namespace.
+ """
+ for typ in signature_types:
+ if (name := typ.__name__) in signature_namespace:
+ raise ImproperlyConfiguredException(f"Type '{name}' is already defined in the signature namespace")
+ signature_namespace[name] = typ
+ return signature_namespace