diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/msgspec/_utils.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/msgspec/_utils.py | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/msgspec/_utils.py b/venv/lib/python3.11/site-packages/msgspec/_utils.py new file mode 100644 index 0000000..ddf6f27 --- /dev/null +++ b/venv/lib/python3.11/site-packages/msgspec/_utils.py @@ -0,0 +1,289 @@ +# type: ignore +import collections +import sys +import typing + +try: + from typing_extensions import _AnnotatedAlias +except Exception: + try: + from typing import _AnnotatedAlias + except Exception: + _AnnotatedAlias = None + +try: + from typing_extensions import get_type_hints as _get_type_hints +except Exception: + from typing import get_type_hints as _get_type_hints + +try: + from typing_extensions import NotRequired, Required +except Exception: + try: + from typing import NotRequired, Required + except Exception: + Required = NotRequired = None + + +if Required is None and _AnnotatedAlias is None: + # No extras available, so no `include_extras` + get_type_hints = _get_type_hints +else: + + def get_type_hints(obj): + return _get_type_hints(obj, include_extras=True) + + +# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10. +# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to +# check is to try it and see. This check can be removed when we drop support +# for Python 3.10. +try: + typing.ForwardRef("Foo", is_class=True) +except TypeError: + + def _forward_ref(value): + return typing.ForwardRef(value, is_argument=False) + +else: + + def _forward_ref(value): + return typing.ForwardRef(value, is_argument=False, is_class=True) + + +def _apply_params(obj, mapping): + if params := getattr(obj, "__parameters__", None): + args = tuple(mapping.get(p, p) for p in params) + return obj[args] + elif isinstance(obj, typing.TypeVar): + return mapping.get(obj, obj) + return obj + + +def _get_class_mro_and_typevar_mappings(obj): + mapping = {} + + if isinstance(obj, type): + cls = obj + else: + cls = obj.__origin__ + + def inner(c, scope): + if isinstance(c, type): + cls = c + new_scope = {} + else: + cls = getattr(c, "__origin__", None) + if cls in (None, object, typing.Generic) or cls in mapping: + return + params = cls.__parameters__ + args = tuple(_apply_params(a, scope) for a in c.__args__) + assert len(params) == len(args) + mapping[cls] = new_scope = dict(zip(params, args)) + + if issubclass(cls, typing.Generic): + bases = getattr(cls, "__orig_bases__", cls.__bases__) + for b in bases: + inner(b, new_scope) + + inner(obj, {}) + return cls.__mro__, mapping + + +def get_class_annotations(obj): + """Get the annotations for a class. + + This is similar to ``typing.get_type_hints``, except: + + - We maintain it + - It leaves extras like ``Annotated``/``ClassVar`` alone + - It resolves any parametrized generics in the class mro. The returned + mapping may still include ``TypeVar`` values, but those should be treated + as their unparametrized variants (i.e. equal to ``Any`` for the common case). + + Note that this function doesn't check that Generic types are being used + properly - invalid uses of `Generic` may slip through without complaint. + + The assumption here is that the user is making use of a static analysis + tool like ``mypy``/``pyright`` already, which would catch misuse of these + APIs. + """ + hints = {} + mro, typevar_mappings = _get_class_mro_and_typevar_mappings(obj) + + for cls in mro: + if cls in (typing.Generic, object): + continue + + mapping = typevar_mappings.get(cls) + cls_locals = dict(vars(cls)) + cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {}) + + ann = cls.__dict__.get("__annotations__", {}) + for name, value in ann.items(): + if name in hints: + continue + if value is None: + value = type(None) + elif isinstance(value, str): + value = _forward_ref(value) + value = typing._eval_type(value, cls_locals, cls_globals) + if mapping is not None: + value = _apply_params(value, mapping) + hints[name] = value + return hints + + +# A mapping from a type annotation (or annotation __origin__) to the concrete +# python type that msgspec will use when decoding. THIS IS PRIVATE FOR A +# REASON. DON'T MUCK WITH THIS. +_CONCRETE_TYPES = { + list: list, + tuple: tuple, + set: set, + frozenset: frozenset, + dict: dict, + typing.List: list, + typing.Tuple: tuple, + typing.Set: set, + typing.FrozenSet: frozenset, + typing.Dict: dict, + typing.Collection: list, + typing.MutableSequence: list, + typing.Sequence: list, + typing.MutableMapping: dict, + typing.Mapping: dict, + typing.MutableSet: set, + typing.AbstractSet: set, + collections.abc.Collection: list, + collections.abc.MutableSequence: list, + collections.abc.Sequence: list, + collections.abc.MutableSet: set, + collections.abc.Set: set, + collections.abc.MutableMapping: dict, + collections.abc.Mapping: dict, +} + + +def get_typeddict_info(obj): + if isinstance(obj, type): + cls = obj + else: + cls = obj.__origin__ + + raw_hints = get_class_annotations(obj) + + if hasattr(cls, "__required_keys__"): + required = set(cls.__required_keys__) + elif cls.__total__: + required = set(raw_hints) + else: + required = set() + + # Both `typing.TypedDict` and `typing_extensions.TypedDict` have a bug + # where `Required`/`NotRequired` aren't properly detected at runtime when + # `__future__.annotations` is enabled, meaning the `__required_keys__` + # isn't correct. This code block works around this issue by amending the + # set of required keys as needed, while also stripping off any + # `Required`/`NotRequired` wrappers. + hints = {} + for k, v in raw_hints.items(): + origin = getattr(v, "__origin__", False) + if origin is Required: + required.add(k) + hints[k] = v.__args__[0] + elif origin is NotRequired: + required.discard(k) + hints[k] = v.__args__[0] + else: + hints[k] = v + return hints, required + + +def get_dataclass_info(obj): + if isinstance(obj, type): + cls = obj + else: + cls = obj.__origin__ + hints = get_class_annotations(obj) + required = [] + optional = [] + defaults = [] + + if hasattr(cls, "__dataclass_fields__"): + from dataclasses import _FIELD, _FIELD_INITVAR, MISSING + + for field in cls.__dataclass_fields__.values(): + if field._field_type is not _FIELD: + if field._field_type is _FIELD_INITVAR: + raise TypeError( + "dataclasses with `InitVar` fields are not supported" + ) + continue + name = field.name + typ = hints[name] + if field.default is not MISSING: + defaults.append(field.default) + optional.append((name, typ, False)) + elif field.default_factory is not MISSING: + defaults.append(field.default_factory) + optional.append((name, typ, True)) + else: + required.append((name, typ, False)) + + required.extend(optional) + + pre_init = None + post_init = getattr(cls, "__post_init__", None) + else: + from attrs import NOTHING, Factory + + fields_with_validators = [] + + for field in cls.__attrs_attrs__: + name = field.name + typ = hints[name] + default = field.default + if default is not NOTHING: + if isinstance(default, Factory): + if default.takes_self: + raise NotImplementedError( + "Support for default factories with `takes_self=True` " + "is not implemented. File a GitHub issue if you need " + "this feature!" + ) + defaults.append(default.factory) + optional.append((name, typ, True)) + else: + defaults.append(default) + optional.append((name, typ, False)) + else: + required.append((name, typ, False)) + + if field.validator is not None: + fields_with_validators.append(field) + + required.extend(optional) + + pre_init = getattr(cls, "__attrs_pre_init__", None) + post_init = getattr(cls, "__attrs_post_init__", None) + + if fields_with_validators: + post_init = _wrap_attrs_validators(fields_with_validators, post_init) + + return cls, tuple(required), tuple(defaults), pre_init, post_init + + +def _wrap_attrs_validators(fields, post_init): + def inner(obj): + for field in fields: + field.validator(obj, field, getattr(obj, field.name)) + if post_init is not None: + post_init(obj) + + return inner + + +def rebuild(cls, kwargs): + """Used to unpickle Structs with keyword-only fields""" + return cls(**kwargs) |