summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/utils/signature.py
blob: eb585990e0df549c617963a925d419f4f63216e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
from __future__ import annotations

import sys
import typing
from copy import deepcopy
from dataclasses import dataclass, replace
from inspect import Signature, getmembers, isclass, ismethod
from itertools import chain
from typing import TYPE_CHECKING, Any, Union

from typing_extensions import Annotated, Self, get_args, get_origin, get_type_hints

from litestar import connection, datastructures, types
from litestar.exceptions import ImproperlyConfiguredException
from litestar.types import Empty
from litestar.typing import FieldDefinition
from litestar.utils.typing import unwrap_annotation

if TYPE_CHECKING:
    from typing import Sequence

    from litestar.types import AnyCallable

# On Python < 3.11, ``get_type_hints`` implicitly wraps an annotation in ``Optional`` when the parameter's default is
# ``None``. ``get_fn_type_hints`` uses the private ``typing._get_defaults`` helper on those versions to find such
# parameters so the wrapping can be normalised. On 3.11+ the helper is not imported; the stub below only keeps the name
# defined, and ``get_fn_type_hints`` calls ``_get_defaults`` only behind the same version check.
if sys.version_info < (3, 11):
    from typing import _get_defaults  # type: ignore[attr-defined]
else:

    def _get_defaults(_: Any) -> Any: ...


__all__ = (
    "add_types_to_signature_namespace",
    "get_fn_type_hints",
    "ParsedSignature",
)


def _collect_global_names() -> dict[str, Any]:
    """Collect the capitalised names exported by litestar's ``types``, ``connection`` and ``datastructures`` modules.

    A module member is kept when its name starts with an uppercase letter and appears in the ``__all__`` of any of the
    three modules. Modules are walked in the order ``types``, ``connection``, ``datastructures``, so if the same name
    is found in more than one of them, the member from the later module wins.

    Returns:
        Mapping of names to the exported objects.
    """
    # Build the set of exported names once. Testing membership against a fresh ``chain`` iterator would re-scan every
    # ``__all__`` tuple for each candidate member.
    exported_names = frozenset(chain(types.__all__, connection.__all__, datastructures.__all__))  # pyright: ignore
    return {
        name: member
        for module in (types, connection, datastructures)
        for name, member in getmembers(module)
        if name[0].isupper() and name in exported_names
    }


_GLOBAL_NAMES = _collect_global_names()
"""A mapping of names used for handler signature forward-ref resolution.

This allows users to include these names within an `if TYPE_CHECKING:` block in their handler module.
"""


def _unwrap_implicit_optional_hints(defaults: dict[str, Any], hints: dict[str, Any]) -> dict[str, Any]:
    """Unwrap implicit optional hints.

    On Python<3.11, if a function parameter annotation has a ``None`` default, it is unconditionally wrapped in an
    ``Optional`` type.

    If the annotation is not annotated, then any nested unions are flattened, e.g.,:

    .. code-block:: python

        def foo(a: Optional[Union[str, int]] = None): ...

    ...will become `Union[str, int, NoneType]`.

    However, if the annotation is annotated, then we end up with an optional union around the annotated type, e.g.,:

    .. code-block:: python

        def foo(a: Annotated[Optional[Union[str, int]], ...] = None): ...

    ... becomes `Union[Annotated[Union[str, int, NoneType], ...], NoneType]`

    This function makes the latter case consistent with the former by either removing the outer union if it is redundant
    or flattening the union if it is not. The latter case would become `Annotated[Union[str, int, NoneType], ...]`.

    Args:
        defaults: Mapping of names to default values.
        hints: Mapping of names to types.

    Returns:
        Mapping of names to types.
    """

    def _is_two_arg_optional(origin_: Any, args_: Any) -> bool:
        """Check if a type is a two-argument optional type.

        If the type has been wrapped in `Optional` by `get_type_hints()` it will always be a union of a type and
        `NoneType`.

        See: https://github.com/litestar-org/litestar/pull/2516
        """
        return origin_ is Union and len(args_) == 2 and args_[1] is type(None)

    def _is_any_optional(origin_: Any, args_: tuple[Any, ...]) -> bool:
        """Detect if a type is a union with `NoneType`.

        After detecting that a type is a two-argument optional type, this function can be used to detect if the
        inner type is a union with `NoneType` at all.

        We only want to perform the unwrapping of the optional union if the inner type is optional as well.
        """
        return origin_ is Union and any(arg is type(None) for arg in args_)

    for name, default in defaults.items():
        if default is not None:
            continue

        hint = hints[name]
        origin = get_origin(hint)
        args = get_args(hint)

        if _is_two_arg_optional(origin, args):
            unwrapped_inner, meta, wrappers = unwrap_annotation(args[0])

            if Annotated not in wrappers:
                continue

            inner_args = get_args(unwrapped_inner)

            if not _is_any_optional(get_origin(unwrapped_inner), inner_args):
                # this is where hint is like `Union[Annotated[Union[str, int], ...], NoneType]`, we add the outer union
                # into the inner one, and re-wrap with Annotated
                union_args = (*(inner_args or (unwrapped_inner,)), type(None))
                # calling `__class_getitem__` directly as in earlier py vers it is a syntax error to unpack into
                # the getitem brackets, e.g., Annotated[T, *meta].
                hints[name] = Annotated.__class_getitem__((Union[union_args], *meta))  # type: ignore[attr-defined]
                continue

            # this is where hint is like `Union[Annotated[Union[str, NoneType], ...], NoneType]`, we remove the
            # redundant outer union
            hints[name] = args[0]
    return hints


def get_fn_type_hints(fn: Any, namespace: dict[str, Any] | None = None) -> dict[str, Any]:
    """Resolve type hints for ``fn``.

    Classes are inspected through their ``__init__``. Objects whose ``__call__`` is a bound method are inspected
    through that ``__call__``, and methods through their underlying ``__func__``. Forward references are resolved
    against a merged namespace. In increasing order of precedence, it holds litestar's global names, the ``typing``
    module, the module ``fn`` was defined in, and ``namespace``.

    Args:
        fn: Callable that is being inspected
        namespace: Extra names for resolution of forward references.

    Returns:
        Mapping of names to types.
    """
    fn_to_inspect: Any = fn

    module_name = getattr(fn_to_inspect, "__module__", None)

    if isclass(fn_to_inspect):
        fn_to_inspect = fn_to_inspect.__init__

    # detect objects that are not functions and that have a `__call__` method
    if callable(fn_to_inspect) and ismethod(fn_to_inspect.__call__):
        fn_to_inspect = fn_to_inspect.__call__

    # inspect the underlying function for methods
    if hasattr(fn_to_inspect, "__func__"):
        fn_to_inspect = fn_to_inspect.__func__

    # A callable created outside an imported module (e.g. via ``exec`` with a bare globals dict) can have a
    # ``__module__`` of ``None`` or one that is absent from ``sys.modules``. Contribute no module names in that case
    # instead of failing with ``KeyError``/``TypeError``.
    module = sys.modules.get(module_name) if module_name is not None else None
    module_namespace = vars(module) if module is not None else {}

    # Order important. If a litestar name has been overridden in the function module, we want
    # to use that instead of the litestar one.
    namespace = {
        **_GLOBAL_NAMES,
        **vars(typing),
        **module_namespace,
        **(namespace or {}),
    }
    hints = get_type_hints(fn_to_inspect, globalns=namespace, include_extras=True)

    if sys.version_info < (3, 11):
        # see https://github.com/litestar-org/litestar/pull/2516
        defaults = _get_defaults(fn_to_inspect)
        hints = _unwrap_implicit_optional_hints(defaults, hints)

    return hints


@dataclass(frozen=True)
class ParsedSignature:
    """Parsed signature.

    This object is the primary source of handler/dependency signature information.

    The only post-processing that occurs is the conversion of any forward referenced type annotations.
    """

    __slots__ = ("parameters", "return_type", "original_signature")

    parameters: dict[str, FieldDefinition]
    """A mapping of parameter names to :class:`FieldDefinition` instances."""
    return_type: FieldDefinition
    """The return annotation of the callable."""
    original_signature: Signature
    """The raw signature as returned by :func:`inspect.signature`"""

    def __deepcopy__(self, memo: dict[int, Any]) -> Self:
        """Return a deep copy of this parsed signature.

        ``memo`` is forwarded to every nested :func:`copy.deepcopy` call, as the copy protocol requires. An object
        reachable from several fields is therefore copied once and stays shared in the copy, and recursion
        tracking is preserved.
        """
        return type(self)(
            parameters={name: deepcopy(field, memo) for name, field in self.parameters.items()},
            return_type=deepcopy(self.return_type, memo),
            original_signature=deepcopy(self.original_signature, memo),
        )

    @classmethod
    def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self:
        """Parse a function signature.

        Args:
            fn: Any callable.
            signature_namespace: mapping of names to types for forward reference resolution

        Returns:
            ParsedSignature
        """
        signature = Signature.from_callable(fn)
        fn_type_hints = get_fn_type_hints(fn, namespace=signature_namespace)

        return cls.from_signature(signature, fn_type_hints)

    @classmethod
    def from_signature(cls, signature: Signature, fn_type_hints: dict[str, type]) -> Self:
        """Parse an :class:`inspect.Signature` instance.

        Parameters named ``self`` or ``cls`` are skipped. When ``fn_type_hints`` has no ``"return"`` entry, the
        resulting ``return_type`` has its annotation set to ``Empty``.

        Args:
            signature: An :class:`inspect.Signature` instance.
            fn_type_hints: Mapping of parameter names (and ``"return"``) to resolved types.

        Returns:
            ParsedSignature
        """

        parameters = tuple(
            FieldDefinition.from_parameter(parameter=parameter, fn_type_hints=fn_type_hints)
            for name, parameter in signature.parameters.items()
            if name not in ("self", "cls")
        )

        return_type = FieldDefinition.from_annotation(fn_type_hints.get("return", Any))

        return cls(
            parameters={p.name: p for p in parameters},
            return_type=return_type if "return" in fn_type_hints else replace(return_type, annotation=Empty),
            original_signature=signature,
        )


def add_types_to_signature_namespace(
    signature_types: Sequence[Any], signature_namespace: dict[str, Any]
) -> dict[str, Any]:
    """Register types in a signature namespace, keyed by their ``__name__``.

    The namespace is mutated in place, one type at a time, and is also returned. If a conflict is detected, the
    types processed before it remain registered.

    Args:
        signature_types: Types to register in the signature namespace.
        signature_namespace: Namespace mapping that receives the types.

    Raises:
        ImproperlyConfiguredException: If a type's name is already a key of the signature namespace.
        AttributeError: If a type does not have a ``__name__`` attribute.

    Returns:
        The same ``signature_namespace`` mapping, now containing the given types.
    """
    for candidate in signature_types:
        name = candidate.__name__
        if name not in signature_namespace:
            signature_namespace[name] = candidate
            continue
        raise ImproperlyConfiguredException(f"Type '{name}' is already defined in the signature namespace")
    return signature_namespace