diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/util/_collections.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/sqlalchemy/util/_collections.py | 715 |
1 files changed, 715 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/util/_collections.py b/venv/lib/python3.11/site-packages/sqlalchemy/util/_collections.py new file mode 100644 index 0000000..e3a8ad8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/util/_collections.py @@ -0,0 +1,715 @@ +# util/_collections.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Collection classes and helpers.""" +from __future__ import annotations + +import operator +import threading +import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TypeVar +from typing import Union +from typing import ValuesView +import weakref + +from ._has_cy import HAS_CYEXTENSION +from .typing import is_non_string_iterable +from .typing import Literal +from .typing import Protocol + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_collections import immutabledict as immutabledict + from ._py_collections import IdentitySet as IdentitySet + from ._py_collections import ReadOnlyContainer as ReadOnlyContainer + from ._py_collections import ImmutableDictBase as ImmutableDictBase + from ._py_collections import OrderedSet as OrderedSet + from ._py_collections import unique_list as unique_list +else: + from sqlalchemy.cyextension.immutabledict import ( + ReadOnlyContainer as ReadOnlyContainer, + ) + from sqlalchemy.cyextension.immutabledict import ( + ImmutableDictBase as ImmutableDictBase, + ) + from sqlalchemy.cyextension.immutabledict import ( + immutabledict as immutabledict, + ) + from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet + from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet + from sqlalchemy.cyextension.collections import ( # noqa + unique_list as unique_list, + ) + + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) +_T_co = TypeVar("_T_co", covariant=True) + +EMPTY_SET: FrozenSet[Any] = frozenset() +NONE_SET: FrozenSet[Any] = frozenset([None]) + + +def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: + """merge two lists, maintaining ordering as much as possible. + + this is to reconcile vars(cls) with cls.__annotations__. + + Example:: + + >>> a = ['__tablename__', 'id', 'x', 'created_at'] + >>> b = ['id', 'name', 'data', 'y', 'created_at'] + >>> merge_lists_w_ordering(a, b) + ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] + + This is not necessarily the ordering that things had on the class, + in this case the class is:: + + class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[Optional[str]] + x = Column(Integer) + y: Mapped[int] + created_at: Mapped[datetime.datetime] = mapped_column() + + But things are *mostly* ordered. + + The algorithm could also be done by creating a partial ordering for + all items in both lists and then using topological_sort(), but that + is too much overhead. + + Background on how I came up with this is at: + https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae + + """ + overlap = set(a).intersection(b) + + result = [] + + current, other = iter(a), iter(b) + + while True: + for element in current: + if element in overlap: + overlap.discard(element) + other, current = current, other + break + + result.append(element) + else: + result.extend(other) + break + + return result + + +def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]: + if not d: + return EMPTY_DICT + elif isinstance(d, immutabledict): + return d + else: + return immutabledict(d) + + +EMPTY_DICT: immutabledict[Any, Any] = immutabledict() + + +class FacadeDict(ImmutableDictBase[_KT, _VT]): + """A dictionary that is not publicly mutable.""" + + def __new__(cls, *args: Any) -> FacadeDict[Any, Any]: + new = ImmutableDictBase.__new__(cls) + return new + + def copy(self) -> NoReturn: + raise NotImplementedError( + "an immutabledict shouldn't need to be copied. use dict(d) " + "if you need a mutable dictionary." + ) + + def __reduce__(self) -> Any: + return FacadeDict, (dict(self),) + + def _insert_item(self, key: _KT, value: _VT) -> None: + """insert an item into the dictionary directly.""" + dict.__setitem__(self, key, value) + + def __repr__(self) -> str: + return "FacadeDict(%s)" % dict.__repr__(self) + + +_DT = TypeVar("_DT", bound=Any) + +_F = TypeVar("_F", bound=Any) + + +class Properties(Generic[_T]): + """Provide a __getattr__/__setattr__ interface over a dict.""" + + __slots__ = ("_data",) + + _data: Dict[str, _T] + + def __init__(self, data: Dict[str, _T]): + object.__setattr__(self, "_data", data) + + def __len__(self) -> int: + return len(self._data) + + def __iter__(self) -> Iterator[_T]: + return iter(list(self._data.values())) + + def __dir__(self) -> List[str]: + return dir(super()) + [str(k) for k in self._data.keys()] + + def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: + return list(self) + list(other) + + def __setitem__(self, key: str, obj: _T) -> None: + self._data[key] = obj + + def __getitem__(self, key: str) -> _T: + return self._data[key] + + def __delitem__(self, key: str) -> None: + del self._data[key] + + def __setattr__(self, key: str, obj: _T) -> None: + self._data[key] = obj + + def __getstate__(self) -> Dict[str, Any]: + return {"_data": self._data} + + def __setstate__(self, state: Dict[str, Any]) -> None: + object.__setattr__(self, "_data", state["_data"]) + + def __getattr__(self, key: str) -> _T: + try: + return self._data[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key: str) -> bool: + return key in self._data + + def as_readonly(self) -> ReadOnlyProperties[_T]: + """Return an immutable proxy for this :class:`.Properties`.""" + + return ReadOnlyProperties(self._data) + + def update(self, value: Dict[str, _T]) -> None: + self._data.update(value) + + @overload + def get(self, key: str) -> Optional[_T]: ... + + @overload + def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ... + + def get( + self, key: str, default: Optional[Union[_DT, _T]] = None + ) -> Optional[Union[_T, _DT]]: + if key in self: + return self[key] + else: + return default + + def keys(self) -> List[str]: + return list(self._data) + + def values(self) -> List[_T]: + return list(self._data.values()) + + def items(self) -> List[Tuple[str, _T]]: + return list(self._data.items()) + + def has_key(self, key: str) -> bool: + return key in self._data + + def clear(self) -> None: + self._data.clear() + + +class OrderedProperties(Properties[_T]): + """Provide a __getattr__/__setattr__ interface with an OrderedDict + as backing store.""" + + __slots__ = () + + def __init__(self): + Properties.__init__(self, OrderedDict()) + + +class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]): + """Provide immutable dict/object attribute to an underlying dictionary.""" + + __slots__ = () + + +def _ordered_dictionary_sort(d, key=None): + """Sort an OrderedDict in-place.""" + + items = [(k, d[k]) for k in sorted(d, key=key)] + + d.clear() + + d.update(items) + + +OrderedDict = dict +sort_dictionary = _ordered_dictionary_sort + + +class WeakSequence(Sequence[_T]): + def __init__(self, __elements: Sequence[_T] = ()): + # adapted from weakref.WeakKeyDictionary, prevent reference + # cycles in the collection itself + def _remove(item, selfref=weakref.ref(self)): + self = selfref() + if self is not None: + self._storage.remove(item) + + self._remove = _remove + self._storage = [ + weakref.ref(element, _remove) for element in __elements + ] + + def append(self, item): + self._storage.append(weakref.ref(item, self._remove)) + + def __len__(self): + return len(self._storage) + + def __iter__(self): + return ( + obj for obj in (ref() for ref in self._storage) if obj is not None + ) + + def __getitem__(self, index): + try: + obj = self._storage[index] + except KeyError: + raise IndexError("Index %s out of range" % index) + else: + return obj() + + +class OrderedIdentitySet(IdentitySet): + def __init__(self, iterable: Optional[Iterable[Any]] = None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + + +class PopulateDict(Dict[_KT, _VT]): + """A dict which populates missing values via a creation function. + + Note the creation function takes a key, unlike + collections.defaultdict. + + """ + + def __init__(self, creator: Callable[[_KT], _VT]): + self.creator = creator + + def __missing__(self, key: Any) -> Any: + self[key] = val = self.creator(key) + return val + + +class WeakPopulateDict(Dict[_KT, _VT]): + """Like PopulateDict, but assumes a self + a method and does not create + a reference cycle. + + """ + + def __init__(self, creator_method: types.MethodType): + self.creator = creator_method.__func__ + weakself = creator_method.__self__ + self.weakself = weakref.ref(weakself) + + def __missing__(self, key: Any) -> Any: + self[key] = val = self.creator(self.weakself(), key) + return val + + +# Define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +# At this point, these are mostly historical, things +# used to be more complicated. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet + + +class UniqueAppender(Generic[_T]): + """Appends items to a collection ensuring uniqueness. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ + + __slots__ = "data", "_data_appender", "_unique" + + data: Union[Iterable[_T], Set[_T], List[_T]] + _data_appender: Callable[[_T], None] + _unique: Dict[int, Literal[True]] + + def __init__( + self, + data: Union[Iterable[_T], Set[_T], List[_T]], + via: Optional[str] = None, + ): + self.data = data + self._unique = {} + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, "append"): + self._data_appender = cast("List[_T]", data).append + elif hasattr(data, "add"): + self._data_appender = cast("Set[_T]", data).add + + def append(self, item: _T) -> None: + id_ = id(item) + if id_ not in self._unique: + self._data_appender(item) + self._unique[id_] = True + + def __iter__(self) -> Iterator[_T]: + return iter(self.data) + + +def coerce_generator_arg(arg: Any) -> List[Any]: + if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): + return list(arg[0]) + else: + return cast("List[Any]", arg) + + +def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: + if x is None: + return default # type: ignore + if not is_non_string_iterable(x): + return [x] + elif isinstance(x, list): + return x + else: + return list(x) + + +def has_intersection(set_, iterable): + r"""return True if any items of set\_ are present in iterable. + + Goes through special effort to ensure __hash__ is not called + on items in iterable that don't support it. + + """ + # TODO: optimize, write in C, etc. + return bool(set_.intersection([i for i in iterable if i.__hash__])) + + +def to_set(x): + if x is None: + return set() + if not isinstance(x, set): + return set(to_list(x)) + else: + return x + + +def to_column_set(x: Any) -> Set[Any]: + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + + +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + + +def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: + """Given an iterator of which further sub-elements may also be + iterators, flatten the sub-elements into a single iterator. + + """ + elem: _T + for elem in x: + if not isinstance(elem, str) and hasattr(elem, "__iter__"): + yield from flatten_iterator(elem) + else: + yield elem + + +class LRUCache(typing.MutableMapping[_KT, _VT]): + """Dictionary with 'squishy' removal of least + recently used items. + + Note that either get() or [] should be used here, but + generally its not safe to do an "in" check first as the dictionary + can change subsequent to that call. + + """ + + __slots__ = ( + "capacity", + "threshold", + "size_alert", + "_data", + "_counter", + "_mutex", + ) + + capacity: int + threshold: float + size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]] + + def __init__( + self, + capacity: int = 100, + threshold: float = 0.5, + size_alert: Optional[Callable[..., None]] = None, + ): + self.capacity = capacity + self.threshold = threshold + self.size_alert = size_alert + self._counter = 0 + self._mutex = threading.Lock() + self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {} + + def _inc_counter(self): + self._counter += 1 + return self._counter + + @overload + def get(self, key: _KT) -> Optional[_VT]: ... + + @overload + def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... + + def get( + self, key: _KT, default: Optional[Union[_VT, _T]] = None + ) -> Optional[Union[_VT, _T]]: + item = self._data.get(key) + if item is not None: + item[2][0] = self._inc_counter() + return item[1] + else: + return default + + def __getitem__(self, key: _KT) -> _VT: + item = self._data[key] + item[2][0] = self._inc_counter() + return item[1] + + def __iter__(self) -> Iterator[_KT]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def values(self) -> ValuesView[_VT]: + return typing.ValuesView({k: i[1] for k, i in self._data.items()}) + + def __setitem__(self, key: _KT, value: _VT) -> None: + self._data[key] = (key, value, [self._inc_counter()]) + self._manage_size() + + def __delitem__(self, __v: _KT) -> None: + del self._data[__v] + + @property + def size_threshold(self) -> float: + return self.capacity + self.capacity * self.threshold + + def _manage_size(self) -> None: + if not self._mutex.acquire(False): + return + try: + size_alert = bool(self.size_alert) + while len(self) > self.capacity + self.capacity * self.threshold: + if size_alert: + size_alert = False + self.size_alert(self) # type: ignore + by_counter = sorted( + self._data.values(), + key=operator.itemgetter(2), + reverse=True, + ) + for item in by_counter[self.capacity :]: + try: + del self._data[item[0]] + except KeyError: + # deleted elsewhere; skip + continue + finally: + self._mutex.release() + + +class _CreateFuncType(Protocol[_T_co]): + def __call__(self) -> _T_co: ... + + +class _ScopeFuncType(Protocol): + def __call__(self) -> Any: ... + + +class ScopedRegistry(Generic[_T]): + """A Registry that can store one or multiple instances of a single + class on the basis of a "scope" function. + + The object implements ``__call__`` as the "getter", so by + calling ``myregistry()`` the contained object is returned + for the current scope. + + :param createfunc: + a callable that returns a new object to be placed in the registry + + :param scopefunc: + a callable that will return a key to store/retrieve an object. + """ + + __slots__ = "createfunc", "scopefunc", "registry" + + createfunc: _CreateFuncType[_T] + scopefunc: _ScopeFuncType + registry: Any + + def __init__( + self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any] + ): + """Construct a new :class:`.ScopedRegistry`. + + :param createfunc: A creation function that will generate + a new value for the current scope, if none is present. + + :param scopefunc: A function that returns a hashable + token representing the current scope (such as, current + thread identifier). + + """ + self.createfunc = createfunc + self.scopefunc = scopefunc + self.registry = {} + + def __call__(self) -> _T: + key = self.scopefunc() + try: + return self.registry[key] # type: ignore[no-any-return] + except KeyError: + return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501 + + def has(self) -> bool: + """Return True if an object is present in the current scope.""" + + return self.scopefunc() in self.registry + + def set(self, obj: _T) -> None: + """Set the value for the current scope.""" + + self.registry[self.scopefunc()] = obj + + def clear(self) -> None: + """Clear the current scope, if any.""" + + try: + del self.registry[self.scopefunc()] + except KeyError: + pass + + +class ThreadLocalRegistry(ScopedRegistry[_T]): + """A :class:`.ScopedRegistry` that uses a ``threading.local()`` + variable for storage. + + """ + + def __init__(self, createfunc: Callable[[], _T]): + self.createfunc = createfunc + self.registry = threading.local() + + def __call__(self) -> _T: + try: + return self.registry.value # type: ignore[no-any-return] + except AttributeError: + val = self.registry.value = self.createfunc() + return val + + def has(self) -> bool: + return hasattr(self.registry, "value") + + def set(self, obj: _T) -> None: + self.registry.value = obj + + def clear(self) -> None: + try: + del self.registry.value + except AttributeError: + pass + + +def has_dupes(sequence, target): + """Given a sequence and search object, return True if there's more + than one, False if zero or one of them. + + + """ + # compare to .index version below, this version introduces less function + # overhead and is usually the same speed. At 15000 items (way bigger than + # a relationship-bound collection in memory usually is) it begins to + # fall behind the other version only by microseconds. + c = 0 + for item in sequence: + if item is target: + c += 1 + if c > 1: + return True + return False + + +# .index version. the two __contains__ calls as well +# as .index() and isinstance() slow this down. +# def has_dupes(sequence, target): +# if target not in sequence: +# return False +# elif not isinstance(sequence, collections_abc.Sequence): +# return False +# +# idx = sequence.index(target) +# return target in sequence[idx + 1:] |