summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/msgspec/_utils.py
blob: ddf6f27c0936b23e2296e2bbb882e7eb6cd5e68c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# type: ignore
import collections
import sys
import typing

# Optional typing features are resolved once, at import time. Each one is
# imported from `typing_extensions` first, falling back to the stdlib `typing`
# module. The `typing_extensions` imports are guarded by `except Exception`
# (not just `ImportError`), so any failure there triggers the fallback.

# `_AnnotatedAlias` (the private class behind `Annotated[...]` aliases); set
# to `None` when neither module provides it.
try:
    from typing_extensions import _AnnotatedAlias
except Exception:
    try:
        from typing import _AnnotatedAlias
    except Exception:
        _AnnotatedAlias = None

# `get_type_hints` has no `None` fallback: the stdlib version must import.
try:
    from typing_extensions import get_type_hints as _get_type_hints
except Exception:
    from typing import get_type_hints as _get_type_hints

# TypedDict key qualifiers; both names are `None` when unavailable.
try:
    from typing_extensions import NotRequired, Required
except Exception:
    try:
        from typing import NotRequired, Required
    except Exception:
        Required = NotRequired = None


if Required is None and _AnnotatedAlias is None:
    # No extras available, so no `include_extras`: expose the underlying
    # function unchanged (presumably it may not accept that keyword here).
    get_type_hints = _get_type_hints
else:

    def get_type_hints(obj):
        """Return the type hints of ``obj`` with extras preserved.

        Always forwards ``include_extras=True`` so wrappers such as
        ``Annotated``/``Required``/``NotRequired`` are kept in the result
        rather than stripped.
        """
        return _get_type_hints(obj, include_extras=True)


# `typing.ForwardRef` gained an `is_class` keyword in 3.11; it was also
# backported to 3.9 and 3.10, but not every patch release of those has it.
# Rather than compare version numbers, probe for the keyword once and remember
# which keyword arguments to use. Drop the probe once Python 3.10 support ends.
try:
    typing.ForwardRef("Foo", is_class=True)
except TypeError:
    _FORWARD_REF_KWARGS = {"is_argument": False}
else:
    _FORWARD_REF_KWARGS = {"is_argument": False, "is_class": True}


def _forward_ref(value):
    """Wrap the annotation string ``value`` in a ``typing.ForwardRef``.

    The reference is created with ``is_argument=False`` and, when the running
    ``typing`` module accepts it, ``is_class=True``.
    """
    return typing.ForwardRef(value, **_FORWARD_REF_KWARGS)


def _apply_params(obj, mapping):
    """Substitute ``TypeVar`` values in ``obj`` using ``mapping``.

    - If ``obj`` has non-empty ``__parameters__`` (e.g. ``List[T]``), it is
      re-subscripted with each parameter ``p`` replaced by ``mapping.get(p, p)``.
    - If ``obj`` is itself a ``TypeVar``, its mapped value is returned (or
      ``obj`` itself when it has no entry in ``mapping``).
    - Anything else is returned unchanged.
    """
    if params := getattr(obj, "__parameters__", None):
        args = tuple(mapping.get(p, p) for p in params)
        return obj[args]
    elif isinstance(obj, typing.TypeVar):
        return mapping.get(obj, obj)
    return obj


def _get_class_mro_and_typevar_mappings(obj):
    """Return ``(mro, mapping)`` for a class or parametrized generic ``obj``.

    ``mro`` is the ``__mro__`` of the underlying class. ``mapping`` maps each
    generic class reached through a parametrized base (or ``obj`` itself, if
    parametrized) to a ``{TypeVar: argument}`` dict. Arguments are expressed in
    terms of the enclosing scope's substitutions where those are known. The
    first parametrization found for a class wins.
    """
    mapping = {}
    # ids of plain (unsubscripted) classes already walked. A plain class always
    # recurses with an empty scope, so walking it twice can add nothing new.
    visited_classes = set()

    cls = obj if isinstance(obj, type) else obj.__origin__

    def inner(c, scope):
        if isinstance(c, type):
            klass = c
            if id(klass) in visited_classes:
                return
            visited_classes.add(id(klass))
            new_scope = {}
        else:
            klass = getattr(c, "__origin__", None)
            if klass in (None, object, typing.Generic) or klass in mapping:
                return
            params = klass.__parameters__
            args = tuple(_apply_params(a, scope) for a in c.__args__)
            assert len(params) == len(args)
            mapping[klass] = new_scope = dict(zip(params, args))

        if issubclass(klass, typing.Generic):
            # Read ``__orig_bases__`` from this class's *own* namespace (as
            # ``types.get_original_bases`` does). ``getattr`` would find an
            # inherited ``__orig_bases__`` when every base of ``klass`` is a
            # plain class. Those inherited bases belong to only the first
            # parent in the MRO, so generic bases reachable solely through
            # the other parents would be skipped.
            bases = klass.__dict__.get("__orig_bases__", klass.__bases__)
            for b in bases:
                inner(b, new_scope)

    inner(obj, {})
    return cls.__mro__, mapping


def get_class_annotations(obj):
    """Resolve the annotations declared on a class and all of its bases.

    Behaves much like ``typing.get_type_hints``, except that:

    - The implementation lives here, so its behavior is under our control.
    - Wrappers such as ``Annotated``/``ClassVar`` are returned untouched.
    - Parametrized generic bases found in the class hierarchy are substituted
      into the annotations. ``TypeVar`` objects may still appear in the
      result. Callers should treat them as the unparametrized type, which is
      ``Any`` in the common case.

    ``Generic`` usage is not validated, so some invalid definitions are
    accepted silently. Static checkers such as ``mypy``/``pyright`` are
    expected to catch that kind of misuse.

    ``obj`` may be a class or a parametrized generic alias of one. The result
    maps attribute names to resolved annotations. When a name is annotated on
    several classes, the entry from the class earliest in the MRO wins.
    """
    mro, typevar_mappings = _get_class_mro_and_typevar_mappings(obj)
    resolved = {}

    for klass in mro:
        if klass in (typing.Generic, object):
            continue

        substitutions = typevar_mappings.get(klass)
        # Passed to ``typing._eval_type`` as (globalns, localns) in swapped
        # order: the class namespace acts as globals and the module namespace
        # as locals. This matches what ``typing.get_type_hints`` does for
        # classes, so module-level names take precedence during evaluation.
        class_ns = dict(vars(klass))
        module_ns = getattr(sys.modules.get(klass.__module__, None), "__dict__", {})

        declared = klass.__dict__.get("__annotations__", {})
        # Annotation names are unique within one class, so it is safe to
        # filter out already-resolved names before evaluating this class.
        pending = [(n, a) for n, a in declared.items() if n not in resolved]
        for name, annotation in pending:
            if annotation is None:
                annotation = type(None)
            elif isinstance(annotation, str):
                annotation = _forward_ref(annotation)
            annotation = typing._eval_type(annotation, class_ns, module_ns)
            if substitutions is not None:
                annotation = _apply_params(annotation, substitutions)
            resolved[name] = annotation
    return resolved


# Private table: maps a type annotation (or an annotation's __origin__) to the
# concrete python type msgspec builds when decoding into it. THIS IS PRIVATE
# FOR A REASON. DON'T MUCK WITH THIS.
_CONCRETE_TYPES = {
    # Builtin containers decode to themselves.
    **{t: t for t in (list, tuple, set, frozenset, dict)},
    # Legacy ``typing`` aliases of the builtin containers.
    typing.List: list,
    typing.Tuple: tuple,
    typing.Set: set,
    typing.FrozenSet: frozenset,
    typing.Dict: dict,
    # Abstract ``typing`` container types, each backed by a default builtin.
    typing.Collection: list,
    typing.MutableSequence: list,
    typing.Sequence: list,
    typing.MutableMapping: dict,
    typing.Mapping: dict,
    typing.MutableSet: set,
    typing.AbstractSet: set,
    # The same abstract types as spelled in ``collections.abc``.
    collections.abc.Collection: list,
    collections.abc.MutableSequence: list,
    collections.abc.Sequence: list,
    collections.abc.MutableSet: set,
    collections.abc.Set: set,
    collections.abc.MutableMapping: dict,
    collections.abc.Mapping: dict,
}


def get_typeddict_info(obj):
    """Return ``(hints, required)`` describing a TypedDict.

    ``obj`` is a TypedDict class or a parametrized generic alias of one.
    ``hints`` maps every key to its resolved annotation, with any top-level
    ``Required[...]``/``NotRequired[...]`` wrapper removed. ``required`` is
    the set of keys that must be present.
    """
    cls = obj if isinstance(obj, type) else obj.__origin__
    raw_hints = get_class_annotations(obj)

    if hasattr(cls, "__required_keys__"):
        required = set(cls.__required_keys__)
    else:
        # Without ``__required_keys__``, fall back to the class-wide
        # ``__total__`` flag: either every key is required or none is.
        required = set(raw_hints) if cls.__total__ else set()

    # With `from __future__ import annotations` enabled, neither
    # `typing.TypedDict` nor `typing_extensions.TypedDict` notices
    # `Required`/`NotRequired` at runtime, so `__required_keys__` can be
    # wrong. Correct the required set from the resolved annotations here,
    # unwrapping the qualifiers along the way.
    hints = {}
    for key, annotation in raw_hints.items():
        qualifier = getattr(annotation, "__origin__", False)
        if qualifier is Required:
            required.add(key)
        elif qualifier is NotRequired:
            required.discard(key)
        else:
            hints[key] = annotation
            continue
        # Strip the ``Required``/``NotRequired`` wrapper.
        hints[key] = annotation.__args__[0]
    return hints, required


def get_dataclass_info(obj):
    """Describe the fields and init hooks of a dataclass or attrs class.

    ``obj`` is either the class itself or a parametrized generic alias of it.
    A class with a ``__dataclass_fields__`` attribute is handled as a
    dataclass. Any other class is handled as an attrs class, via
    ``__attrs_attrs__``.

    Returns a 5-tuple ``(cls, fields, defaults, pre_init, post_init)``:

    - ``cls``: the underlying class.
    - ``fields``: tuple of ``(name, type, is_factory)`` triples. Fields
      without a default come first, followed by fields that have one.
      ``is_factory`` is ``True`` when the default is a factory callable.
    - ``defaults``: default values/factories, in the same order as the
      trailing fields that have defaults.
    - ``pre_init``: ``__attrs_pre_init__`` for attrs classes (``None`` if
      absent). Always ``None`` for dataclasses.
    - ``post_init``: ``__post_init__`` or ``__attrs_post_init__`` (``None`` if
      absent). For attrs classes with field validators, it is wrapped in a
      hook that runs the validators first.

    Raises ``TypeError`` for dataclasses with ``InitVar`` fields and
    ``NotImplementedError`` for attrs factories declared with
    ``takes_self=True``. A field missing from the resolved annotations raises
    ``KeyError``.
    """
    cls = obj if isinstance(obj, type) else obj.__origin__
    hints = get_class_annotations(obj)
    no_default = []
    with_default = []
    defaults = []

    if not hasattr(cls, "__dataclass_fields__"):
        from attrs import NOTHING, Factory

        validated = []
        for attribute in cls.__attrs_attrs__:
            attr_name = attribute.name
            attr_type = hints[attr_name]
            default = attribute.default
            if default is NOTHING:
                no_default.append((attr_name, attr_type, False))
            elif not isinstance(default, Factory):
                defaults.append(default)
                with_default.append((attr_name, attr_type, False))
            elif default.takes_self:
                raise NotImplementedError(
                    "Support for default factories with `takes_self=True` "
                    "is not implemented. File a GitHub issue if you need "
                    "this feature!"
                )
            else:
                defaults.append(default.factory)
                with_default.append((attr_name, attr_type, True))

            if attribute.validator is not None:
                validated.append(attribute)

        pre_init = getattr(cls, "__attrs_pre_init__", None)
        post_init = getattr(cls, "__attrs_post_init__", None)
        if validated:
            post_init = _wrap_attrs_validators(validated, post_init)
    else:
        from dataclasses import _FIELD, _FIELD_INITVAR, MISSING

        for dc_field in cls.__dataclass_fields__.values():
            kind = dc_field._field_type
            if kind is _FIELD_INITVAR:
                raise TypeError(
                    "dataclasses with `InitVar` fields are not supported"
                )
            if kind is not _FIELD:
                # Pseudo-fields (e.g. ``ClassVar``) are not init fields.
                continue
            field_name = dc_field.name
            field_type = hints[field_name]
            if dc_field.default is not MISSING:
                defaults.append(dc_field.default)
                with_default.append((field_name, field_type, False))
            elif dc_field.default_factory is not MISSING:
                defaults.append(dc_field.default_factory)
                with_default.append((field_name, field_type, True))
            else:
                no_default.append((field_name, field_type, False))

        pre_init = None
        post_init = getattr(cls, "__post_init__", None)

    fields = tuple(no_default + with_default)
    return cls, fields, tuple(defaults), pre_init, post_init


def _wrap_attrs_validators(fields, post_init):
    """Return a post-init hook that runs attrs field validators.

    For each entry of ``fields``, in order, the hook calls
    ``field.validator(obj, field, value)``, where ``value`` is the object's
    current attribute of the same name. It then calls ``post_init(obj)``
    unless ``post_init`` is ``None``. The hook itself returns ``None``.
    """

    def run_validators(obj):
        for fld in fields:
            check = fld.validator
            check(obj, fld, getattr(obj, fld.name))
        if post_init is None:
            return None
        post_init(obj)

    return run_validators


def rebuild(cls, kwargs):
    """Recreate an instance of ``cls`` from a dict of keyword arguments.

    Serves as the reconstructor when unpickling Structs that have
    keyword-only fields.
    """
    instance = cls(**kwargs)
    return instance