summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/msgspec/_utils.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/msgspec/_utils.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/msgspec/_utils.py')
-rw-r--r--venv/lib/python3.11/site-packages/msgspec/_utils.py289
1 files changed, 289 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/msgspec/_utils.py b/venv/lib/python3.11/site-packages/msgspec/_utils.py
new file mode 100644
index 0000000..ddf6f27
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/msgspec/_utils.py
@@ -0,0 +1,289 @@
+# type: ignore
+import collections
+import sys
+import typing
+
+try:
+ from typing_extensions import _AnnotatedAlias
+except Exception:
+ try:
+ from typing import _AnnotatedAlias
+ except Exception:
+ _AnnotatedAlias = None
+
+try:
+ from typing_extensions import get_type_hints as _get_type_hints
+except Exception:
+ from typing import get_type_hints as _get_type_hints
+
+try:
+ from typing_extensions import NotRequired, Required
+except Exception:
+ try:
+ from typing import NotRequired, Required
+ except Exception:
+ Required = NotRequired = None
+
+
+if Required is None and _AnnotatedAlias is None:
+ # No extras available, so no `include_extras`
+ get_type_hints = _get_type_hints
+else:
+
+ def get_type_hints(obj):
+ return _get_type_hints(obj, include_extras=True)
+
+
+# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
+# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
+# check is to try it and see. This check can be removed when we drop support
+# for Python 3.10.
+try:
+ typing.ForwardRef("Foo", is_class=True)
+except TypeError:
+
+ def _forward_ref(value):
+ return typing.ForwardRef(value, is_argument=False)
+
+else:
+
+ def _forward_ref(value):
+ return typing.ForwardRef(value, is_argument=False, is_class=True)
+
+
+def _apply_params(obj, mapping):
+ if params := getattr(obj, "__parameters__", None):
+ args = tuple(mapping.get(p, p) for p in params)
+ return obj[args]
+ elif isinstance(obj, typing.TypeVar):
+ return mapping.get(obj, obj)
+ return obj
+
+
+def _get_class_mro_and_typevar_mappings(obj):
+ mapping = {}
+
+ if isinstance(obj, type):
+ cls = obj
+ else:
+ cls = obj.__origin__
+
+ def inner(c, scope):
+ if isinstance(c, type):
+ cls = c
+ new_scope = {}
+ else:
+ cls = getattr(c, "__origin__", None)
+ if cls in (None, object, typing.Generic) or cls in mapping:
+ return
+ params = cls.__parameters__
+ args = tuple(_apply_params(a, scope) for a in c.__args__)
+ assert len(params) == len(args)
+ mapping[cls] = new_scope = dict(zip(params, args))
+
+ if issubclass(cls, typing.Generic):
+ bases = getattr(cls, "__orig_bases__", cls.__bases__)
+ for b in bases:
+ inner(b, new_scope)
+
+ inner(obj, {})
+ return cls.__mro__, mapping
+
+
+def get_class_annotations(obj):
+ """Get the annotations for a class.
+
+ This is similar to ``typing.get_type_hints``, except:
+
+ - We maintain it
+ - It leaves extras like ``Annotated``/``ClassVar`` alone
+ - It resolves any parametrized generics in the class mro. The returned
+ mapping may still include ``TypeVar`` values, but those should be treated
+ as their unparametrized variants (i.e. equal to ``Any`` for the common case).
+
+ Note that this function doesn't check that Generic types are being used
+ properly - invalid uses of `Generic` may slip through without complaint.
+
+ The assumption here is that the user is making use of a static analysis
+ tool like ``mypy``/``pyright`` already, which would catch misuse of these
+ APIs.
+ """
+ hints = {}
+ mro, typevar_mappings = _get_class_mro_and_typevar_mappings(obj)
+
+ for cls in mro:
+ if cls in (typing.Generic, object):
+ continue
+
+ mapping = typevar_mappings.get(cls)
+ cls_locals = dict(vars(cls))
+ cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {})
+
+ ann = cls.__dict__.get("__annotations__", {})
+ for name, value in ann.items():
+ if name in hints:
+ continue
+ if value is None:
+ value = type(None)
+ elif isinstance(value, str):
+ value = _forward_ref(value)
+ value = typing._eval_type(value, cls_locals, cls_globals)
+ if mapping is not None:
+ value = _apply_params(value, mapping)
+ hints[name] = value
+ return hints
+
+
+# A mapping from a type annotation (or annotation __origin__) to the concrete
+# python type that msgspec will use when decoding. THIS IS PRIVATE FOR A
+# REASON. DON'T MUCK WITH THIS.
+_CONCRETE_TYPES = {
+ list: list,
+ tuple: tuple,
+ set: set,
+ frozenset: frozenset,
+ dict: dict,
+ typing.List: list,
+ typing.Tuple: tuple,
+ typing.Set: set,
+ typing.FrozenSet: frozenset,
+ typing.Dict: dict,
+ typing.Collection: list,
+ typing.MutableSequence: list,
+ typing.Sequence: list,
+ typing.MutableMapping: dict,
+ typing.Mapping: dict,
+ typing.MutableSet: set,
+ typing.AbstractSet: set,
+ collections.abc.Collection: list,
+ collections.abc.MutableSequence: list,
+ collections.abc.Sequence: list,
+ collections.abc.MutableSet: set,
+ collections.abc.Set: set,
+ collections.abc.MutableMapping: dict,
+ collections.abc.Mapping: dict,
+}
+
+
+def get_typeddict_info(obj):
+ if isinstance(obj, type):
+ cls = obj
+ else:
+ cls = obj.__origin__
+
+ raw_hints = get_class_annotations(obj)
+
+ if hasattr(cls, "__required_keys__"):
+ required = set(cls.__required_keys__)
+ elif cls.__total__:
+ required = set(raw_hints)
+ else:
+ required = set()
+
+ # Both `typing.TypedDict` and `typing_extensions.TypedDict` have a bug
+ # where `Required`/`NotRequired` aren't properly detected at runtime when
+ # `__future__.annotations` is enabled, meaning the `__required_keys__`
+ # isn't correct. This code block works around this issue by amending the
+ # set of required keys as needed, while also stripping off any
+ # `Required`/`NotRequired` wrappers.
+ hints = {}
+ for k, v in raw_hints.items():
+ origin = getattr(v, "__origin__", False)
+ if origin is Required:
+ required.add(k)
+ hints[k] = v.__args__[0]
+ elif origin is NotRequired:
+ required.discard(k)
+ hints[k] = v.__args__[0]
+ else:
+ hints[k] = v
+ return hints, required
+
+
+def get_dataclass_info(obj):
+ if isinstance(obj, type):
+ cls = obj
+ else:
+ cls = obj.__origin__
+ hints = get_class_annotations(obj)
+ required = []
+ optional = []
+ defaults = []
+
+ if hasattr(cls, "__dataclass_fields__"):
+ from dataclasses import _FIELD, _FIELD_INITVAR, MISSING
+
+ for field in cls.__dataclass_fields__.values():
+ if field._field_type is not _FIELD:
+ if field._field_type is _FIELD_INITVAR:
+ raise TypeError(
+ "dataclasses with `InitVar` fields are not supported"
+ )
+ continue
+ name = field.name
+ typ = hints[name]
+ if field.default is not MISSING:
+ defaults.append(field.default)
+ optional.append((name, typ, False))
+ elif field.default_factory is not MISSING:
+ defaults.append(field.default_factory)
+ optional.append((name, typ, True))
+ else:
+ required.append((name, typ, False))
+
+ required.extend(optional)
+
+ pre_init = None
+ post_init = getattr(cls, "__post_init__", None)
+ else:
+ from attrs import NOTHING, Factory
+
+ fields_with_validators = []
+
+ for field in cls.__attrs_attrs__:
+ name = field.name
+ typ = hints[name]
+ default = field.default
+ if default is not NOTHING:
+ if isinstance(default, Factory):
+ if default.takes_self:
+ raise NotImplementedError(
+ "Support for default factories with `takes_self=True` "
+ "is not implemented. File a GitHub issue if you need "
+ "this feature!"
+ )
+ defaults.append(default.factory)
+ optional.append((name, typ, True))
+ else:
+ defaults.append(default)
+ optional.append((name, typ, False))
+ else:
+ required.append((name, typ, False))
+
+ if field.validator is not None:
+ fields_with_validators.append(field)
+
+ required.extend(optional)
+
+ pre_init = getattr(cls, "__attrs_pre_init__", None)
+ post_init = getattr(cls, "__attrs_post_init__", None)
+
+ if fields_with_validators:
+ post_init = _wrap_attrs_validators(fields_with_validators, post_init)
+
+ return cls, tuple(required), tuple(defaults), pre_init, post_init
+
+
+def _wrap_attrs_validators(fields, post_init):
+ def inner(obj):
+ for field in fields:
+ field.validator(obj, field, getattr(obj, field.name))
+ if post_init is not None:
+ post_init(obj)
+
+ return inner
+
+
+def rebuild(cls, kwargs):
+ """Used to unpickle Structs with keyword-only fields"""
+ return cls(**kwargs)