diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/ext | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/ext')
56 files changed, 18585 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py new file mode 100644 index 0000000..f03ed94 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py @@ -0,0 +1,11 @@ +# ext/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .. import util as _sa_util + + +_sa_util.preloaded.import_prefix("sqlalchemy.ext") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0340e5e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5e08d9b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..846e172 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0e36847 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0b1beaa --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2bd5054 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..31a156f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d7bde5e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..90b77be --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0247602 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c51955b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3d5c8d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py new file mode 100644 index 0000000..80e6fda --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py @@ -0,0 +1,2005 @@ +# ext/associationproxy.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Contain the ``AssociationProxy`` class. + +The ``AssociationProxy`` is a Python property object which provides +transparent proxied access to the endpoint of an association object. + +See the example ``examples/association/proxied_association.py``. + +""" +from __future__ import annotations + +import operator +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import ItemsView +from typing import Iterable +from typing import Iterator +from typing import KeysView +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import MutableSet +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import ValuesView + +from .. import ColumnElement +from .. import exc +from .. import inspect +from .. import orm +from .. import util +from ..orm import collections +from ..orm import InspectionAttrExtensionType +from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.base import SQLORMOperations +from ..orm.interfaces import _AttributeOptions +from ..orm.interfaces import _DCAttributeOptions +from ..orm.interfaces import _DEFAULT_ATTRIBUTE_OPTIONS +from ..sql import operators +from ..sql import or_ +from ..sql.base import _NoArg +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self +from ..util.typing import SupportsIndex +from ..util.typing import SupportsKeysAndGetItem + +if typing.TYPE_CHECKING: + from ..orm.interfaces import MapperProperty + from ..orm.interfaces import PropComparator + from ..orm.mapper import Mapper + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) +_S = TypeVar("_S", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +def association_proxy( + target_collection: str, + attr: str, + *, + creator: Optional[_CreatorProtocol] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[_InfoType] = None, + cascade_scalar_deletes: bool = False, + create_on_none_assignment: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, +) -> AssociationProxy[Any]: + r"""Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. + + The returned value is an instance of :class:`.AssociationProxy`. + + Implements a Python property representing a relationship as a collection + of simpler values, or a scalar value. The proxied property will mimic + the collection type of the target (list, dict or set), or, in the case of + a one to one relationship, a simple scalar value. + + :param target_collection: Name of the attribute that is the immediate + target. This attribute is typically mapped by + :func:`~sqlalchemy.orm.relationship` to link to a target collection, but + can also be a many-to-one or non-scalar relationship. + + :param attr: Attribute on the associated instance or instances that + are available on instances of the target object. + + :param creator: optional. + + Defines custom behavior when new items are added to the proxied + collection. + + By default, adding new items to the collection will trigger a + construction of an instance of the target object, passing the given + item as a positional argument to the target constructor. For cases + where this isn't sufficient, :paramref:`.association_proxy.creator` + can supply a callable that will construct the object in the + appropriate way, given the item that was passed. + + For list- and set- oriented collections, a single argument is + passed to the callable. For dictionary oriented collections, two + arguments are passed, corresponding to the key and value. + + The :paramref:`.association_proxy.creator` callable is also invoked + for scalar (i.e. many-to-one, one-to-one) relationships. If the + current value of the target relationship attribute is ``None``, the + callable is used to construct a new object. If an object value already + exists, the given attribute value is populated onto that object. + + .. seealso:: + + :ref:`associationproxy_creator` + + :param cascade_scalar_deletes: when True, indicates that setting + the proxied value to ``None``, or deleting it via ``del``, should + also remove the source object. Only applies to scalar attributes. + Normally, removing the proxied target will not remove the proxy + source, as this object may have other state that is still to be + kept. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`cascade_scalar_deletes` - complete usage example + + :param create_on_none_assignment: when True, indicates that setting + the proxied value to ``None`` should **create** the source object + if it does not exist, using the creator. Only applies to scalar + attributes. This is mutually exclusive + vs. the :paramref:`.assocation_proxy.cascade_scalar_deletes`. + + .. versionadded:: 2.0.18 + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the attribute established by this :class:`.AssociationProxy` + should be part of the ``__repr__()`` method as generated by the dataclass + process. + + .. versionadded:: 2.0.0b4 + + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, specifies a default-value + generation function that will take place as part of the ``__init__()`` + method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to :ref:`orm_declarative_native_dataclasses`, + indicates if this field should be marked as keyword-only when generating + the ``__init__()`` method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param info: optional, will be assigned to + :attr:`.AssociationProxy.info` if present. + + + The following additional parameters involve injection of custom behaviors + within the :class:`.AssociationProxy` object and are for advanced use + only: + + :param getset_factory: Optional. Proxied attribute access is + automatically handled by routines that get and set values based on + the `attr` argument for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, the + abstract type of the underlying collection and this proxy instance. + + :param proxy_factory: Optional. The type of collection to emulate is + determined by sniffing the target collection. If your collection + type can't be determined by duck typing or you'd like to use a + different collection implementation, you may supply a factory + function to produce those collections. Only applicable to + non-scalar relationships. + + :param proxy_bulk_set: Optional, use with proxy_factory. + + + """ + return AssociationProxy( + target_collection, + attr, + creator=creator, + getset_factory=getset_factory, + proxy_factory=proxy_factory, + proxy_bulk_set=proxy_bulk_set, + info=info, + cascade_scalar_deletes=cascade_scalar_deletes, + create_on_none_assignment=create_on_none_assignment, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + ) + + +class AssociationProxyExtensionType(InspectionAttrExtensionType): + ASSOCIATION_PROXY = "ASSOCIATION_PROXY" + """Symbol indicating an :class:`.InspectionAttr` that's + of type :class:`.AssociationProxy`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + """ + + +class _GetterProtocol(Protocol[_T_co]): + def __call__(self, instance: Any) -> _T_co: ... + + +# mypy 0.990 we are no longer allowed to make this Protocol[_T_con] +class _SetterProtocol(Protocol): ... + + +class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]): + def __call__(self, instance: Any, value: _T_con) -> None: ... + + +class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]): + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: ... + + +# mypy 0.990 we are no longer allowed to make this Protocol[_T_con] +class _CreatorProtocol(Protocol): ... + + +class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): + def __call__(self, value: _T_con) -> Any: ... + + +class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: ... + + +class _LazyCollectionProtocol(Protocol[_T]): + def __call__( + self, + ) -> Union[ + MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T] + ]: ... + + +class _GetSetFactoryProtocol(Protocol): + def __call__( + self, + collection_class: Optional[Type[Any]], + assoc_instance: AssociationProxyInstance[Any], + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... + + +class _ProxyFactoryProtocol(Protocol): + def __call__( + self, + lazy_collection: _LazyCollectionProtocol[Any], + creator: _CreatorProtocol, + value_attr: str, + parent: AssociationProxyInstance[Any], + ) -> Any: ... + + +class _ProxyBulkSetProtocol(Protocol): + def __call__( + self, proxy: _AssociationCollection[Any], collection: Iterable[Any] + ) -> None: ... + + +class _AssociationProxyProtocol(Protocol[_T]): + """describes the interface of :class:`.AssociationProxy` + without including descriptor methods in the interface.""" + + creator: Optional[_CreatorProtocol] + key: str + target_collection: str + value_attr: str + cascade_scalar_deletes: bool + create_on_none_assignment: bool + getset_factory: Optional[_GetSetFactoryProtocol] + proxy_factory: Optional[_ProxyFactoryProtocol] + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] + + @util.ro_memoized_property + def info(self) -> _InfoType: ... + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: ... + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... + + +class AssociationProxy( + interfaces.InspectionAttrInfo, + ORMDescriptor[_T], + _DCAttributeOptions, + _AssociationProxyProtocol[_T], +): + """A descriptor that presents a read/write view of an object attribute.""" + + is_attribute = True + extension_type = AssociationProxyExtensionType.ASSOCIATION_PROXY + + def __init__( + self, + target_collection: str, + attr: str, + *, + creator: Optional[_CreatorProtocol] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[_InfoType] = None, + cascade_scalar_deletes: bool = False, + create_on_none_assignment: bool = False, + attribute_options: Optional[_AttributeOptions] = None, + ): + """Construct a new :class:`.AssociationProxy`. + + The :class:`.AssociationProxy` object is typically constructed using + the :func:`.association_proxy` constructor function. See the + description of :func:`.association_proxy` for a description of all + parameters. + + + """ + self.target_collection = target_collection + self.value_attr = attr + self.creator = creator + self.getset_factory = getset_factory + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set + + if cascade_scalar_deletes and create_on_none_assignment: + raise exc.ArgumentError( + "The cascade_scalar_deletes and create_on_none_assignment " + "parameters are mutually exclusive." + ) + self.cascade_scalar_deletes = cascade_scalar_deletes + self.create_on_none_assignment = create_on_none_assignment + + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) + if info: + self.info = info # type: ignore + + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS + + @overload + def __get__( + self, instance: Literal[None], owner: Literal[None] + ) -> Self: ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> AssociationProxyInstance[_T]: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[AssociationProxyInstance[_T], _T, AssociationProxy[_T]]: + if owner is None: + return self + inst = self._as_instance(owner, instance) + if inst: + return inst.get(instance) + + assert instance is None + + return self + + def __set__(self, instance: object, values: _T) -> None: + class_ = type(instance) + self._as_instance(class_, instance).set(instance, values) + + def __delete__(self, instance: object) -> None: + class_ = type(instance) + self._as_instance(class_, instance).delete(instance) + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: + r"""Return the internal state local to a specific mapped class. + + E.g., given a class ``User``:: + + class User(Base): + # ... + + keywords = association_proxy('kws', 'keyword') + + If we access this :class:`.AssociationProxy` from + :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the + target class for this proxy as mapped by ``User``:: + + inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class + + This returns an instance of :class:`.AssociationProxyInstance` that + is specific to the ``User`` class. The :class:`.AssociationProxy` + object remains agnostic of its parent class. + + :param class\_: the class that we are returning state for. + + :param obj: optional, an instance of the class that is required + if the attribute refers to a polymorphic target, e.g. where we have + to look at the type of the actual destination object to get the + complete path. + + .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores + any state specific to a particular parent class; the state is now + stored in per-class :class:`.AssociationProxyInstance` objects. + + + """ + return self._as_instance(class_, obj) + + def _as_instance( + self, class_: Any, obj: Any + ) -> AssociationProxyInstance[_T]: + try: + inst = class_.__dict__[self.key + "_inst"] + except KeyError: + inst = None + + # avoid exception context + if inst is None: + owner = self._calc_owner(class_) + if owner is not None: + inst = AssociationProxyInstance.for_proxy(self, owner, obj) + setattr(class_, self.key + "_inst", inst) + else: + inst = None + + if inst is not None and not inst._is_canonical: + # the AssociationProxyInstance can't be generalized + # since the proxied attribute is not on the targeted + # class, only on subclasses of it, which might be + # different. only return for the specific + # object's current value + return inst._non_canonical_get_for_object(obj) # type: ignore + else: + return inst # type: ignore # TODO + + def _calc_owner(self, target_cls: Any) -> Any: + # we might be getting invoked for a subclass + # that is not mapped yet, in some declarative situations. + # save until we are mapped + try: + insp = inspect(target_cls) + except exc.NoInspectionAvailable: + # can't find a mapper, don't set owner. if we are a not-yet-mapped + # subclass, we can also scan through __mro__ to find a mapped + # class, but instead just wait for us to be called again against a + # mapped class normally. + return None + else: + return insp.mapper.class_manager.class_ + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(instance: Any) -> Optional[Any]: + return _getter(instance) if instance is not None else None + + if collection_class is dict: + + def dict_setter(instance: Any, k: Any, value: Any) -> None: + setattr(instance, attr, value) + + return getter, dict_setter + + else: + + def plain_setter(o: Any, v: Any) -> None: + setattr(o, attr, v) + + return getter, plain_setter + + def __repr__(self) -> str: + return "AssociationProxy(%r, %r)" % ( + self.target_collection, + self.value_attr, + ) + + +# the pep-673 Self type does not work in Mypy for a "hybrid" +# style method that returns type or Self, so for one specific case +# we still need to use the pre-pep-673 workaround. +_Self = TypeVar("_Self", bound="AssociationProxyInstance[Any]") + + +class AssociationProxyInstance(SQLORMOperations[_T]): + """A per-class object that serves class- and object-specific results. + + This is used by :class:`.AssociationProxy` when it is invoked + in terms of a specific class or instance of a class, i.e. when it is + used as a regular Python descriptor. + + When referring to the :class:`.AssociationProxy` as a normal Python + descriptor, the :class:`.AssociationProxyInstance` is the object that + actually serves the information. Under normal circumstances, its presence + is transparent:: + + >>> User.keywords.scalar + False + + In the special case that the :class:`.AssociationProxy` object is being + accessed directly, in order to get an explicit handle to the + :class:`.AssociationProxyInstance`, use the + :meth:`.AssociationProxy.for_class` method:: + + proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User) + + # view if proxy object is scalar or not + >>> proxy_state.scalar + False + + .. versionadded:: 1.3 + + """ # noqa + + collection_class: Optional[Type[Any]] + parent: _AssociationProxyProtocol[_T] + + def __init__( + self, + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ): + self.parent = parent + self.key = parent.key + self.owning_class = owning_class + self.target_collection = parent.target_collection + self.collection_class = None + self.target_class = target_class + self.value_attr = value_attr + + target_class: Type[Any] + """The intermediary class handled by this + :class:`.AssociationProxyInstance`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. + + """ + + @classmethod + def for_proxy( + cls, + parent: AssociationProxy[_T], + owning_class: Type[Any], + parent_instance: Any, + ) -> AssociationProxyInstance[_T]: + target_collection = parent.target_collection + value_attr = parent.value_attr + prop = cast( + "orm.RelationshipProperty[_T]", + orm.class_mapper(owning_class).get_property(target_collection), + ) + + # this was never asserted before but this should be made clear. + if not isinstance(prop, orm.RelationshipProperty): + raise NotImplementedError( + "association proxy to a non-relationship " + "intermediary is not supported" + ) from None + + target_class = prop.mapper.class_ + + try: + target_assoc = cast( + "AssociationProxyInstance[_T]", + cls._cls_unwrap_target_assoc_proxy(target_class, value_attr), + ) + except AttributeError: + # the proxied attribute doesn't exist on the target class; + # return an "ambiguous" instance that will work on a per-object + # basis + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + except Exception as err: + raise exc.InvalidRequestError( + f"Association proxy received an unexpected error when " + f"trying to retreive attribute " + f'"{target_class.__name__}.{parent.value_attr}" from ' + f'class "{target_class.__name__}": {err}' + ) from err + else: + return cls._construct_for_assoc( + target_assoc, parent, owning_class, target_class, value_attr + ) + + @classmethod + def _construct_for_assoc( + cls, + target_assoc: Optional[AssociationProxyInstance[_T]], + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ) -> AssociationProxyInstance[_T]: + if target_assoc is not None: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + attr = getattr(target_class, value_attr) + if not hasattr(attr, "_is_internal_proxy"): + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + is_object = attr._impl_uses_objects + if is_object: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + else: + return ColumnAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + def _get_property(self) -> MapperProperty[Any]: + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + @property + def _comparator(self) -> PropComparator[Any]: + return getattr( # type: ignore + self.owning_class, self.target_collection + ).comparator + + def __clause_element__(self) -> NoReturn: + raise NotImplementedError( + "The association proxy can't be used as a plain column " + "expression; it only works inside of a comparison expression" + ) + + @classmethod + def _cls_unwrap_target_assoc_proxy( + cls, target_class: Any, value_attr: str + ) -> Optional[AssociationProxyInstance[_T]]: + attr = getattr(target_class, value_attr) + assert not isinstance(attr, AssociationProxy) + if isinstance(attr, AssociationProxyInstance): + return attr + return None + + @util.memoized_property + def _unwrap_target_assoc_proxy( + self, + ) -> Optional[AssociationProxyInstance[_T]]: + return self._cls_unwrap_target_assoc_proxy( + self.target_class, self.value_attr + ) + + @property + def remote_attr(self) -> SQLORMOperations[_T]: + """The 'remote' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.local_attr` + + """ + return cast( + "SQLORMOperations[_T]", getattr(self.target_class, self.value_attr) + ) + + @property + def local_attr(self) -> SQLORMOperations[Any]: + """The 'local' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return cast( + "SQLORMOperations[Any]", + getattr(self.owning_class, self.target_collection), + ) + + @property + def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]: + """Return a tuple of ``(local_attr, remote_attr)``. + + This attribute was originally intended to facilitate using the + :meth:`_query.Query.join` method to join across the two relationships + at once, however this makes use of a deprecated calling style. + + To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with + an association proxy, the current method is to make use of the + :attr:`.AssociationProxyInstance.local_attr` and + :attr:`.AssociationProxyInstance.remote_attr` attributes separately:: + + stmt = ( + select(Parent). + join(Parent.proxied.local_attr). + join(Parent.proxied.remote_attr) + ) + + A future release may seek to provide a more succinct join pattern + for association proxy attributes. + + .. seealso:: + + :attr:`.AssociationProxyInstance.local_attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return (self.local_attr, self.remote_attr) + + @util.memoized_property + def scalar(self) -> bool: + """Return ``True`` if this :class:`.AssociationProxyInstance` + proxies a scalar relationship on the local side.""" + + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self) -> bool: + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) + + @property + def _target_is_object(self) -> bool: + raise NotImplementedError() + + _scalar_get: _GetterProtocol[_T] + _scalar_set: _PlainSetterProtocol[_T] + + def _initialize_scalar_accessors(self) -> None: + if self.parent.getset_factory: + get, set_ = self.parent.getset_factory(None, self) + else: + get, set_ = self.parent._default_getset(None) + self._scalar_get, self._scalar_set = get, cast( + "_PlainSetterProtocol[_T]", set_ + ) + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(instance: Any) -> Optional[_T]: + return _getter(instance) if instance is not None else None + + if collection_class is dict: + + def dict_setter(instance: Any, k: Any, value: _T) -> None: + setattr(instance, attr, value) + + return getter, dict_setter + else: + + def plain_setter(o: Any, v: _T) -> None: + setattr(o, attr, v) + + return getter, plain_setter + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.parent.info + + @overload + def get(self: _Self, obj: Literal[None]) -> _Self: ... + + @overload + def get(self, obj: Any) -> _T: ... + + def get( + self, obj: Any + ) -> Union[Optional[_T], AssociationProxyInstance[_T]]: + if obj is None: + return self + + proxy: _T + + if self.scalar: + target = getattr(obj, self.target_collection) + return self._scalar_get(target) + else: + try: + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. + creator_id, self_id, proxy = cast( + "Tuple[int, int, _T]", getattr(obj, self.key) + ) + except AttributeError: + pass + else: + if id(obj) == creator_id and id(self) == self_id: + assert self.collection_class is not None + return proxy + + self.collection_class, proxy = self._new( + _lazy_collection(obj, self.target_collection) + ) + setattr(obj, self.key, (id(obj), id(self), proxy)) + return proxy + + def set(self, obj: Any, values: _T) -> None: + if self.scalar: + creator = cast( + "_PlainCreatorProtocol[_T]", + ( + self.parent.creator + if self.parent.creator + else self.target_class + ), + ) + target = getattr(obj, self.target_collection) + if target is None: + if ( + values is None + and not self.parent.create_on_none_assignment + ): + return + setattr(obj, self.target_collection, creator(values)) + else: + self._scalar_set(target, values) + if values is None and self.parent.cascade_scalar_deletes: + setattr(obj, self.target_collection, None) + else: + proxy = self.get(obj) + assert self.collection_class is not None + if proxy is not values: + proxy._bulk_replace(self, values) + + def delete(self, obj: Any) -> None: + if self.owning_class is None: + self._calc_owner(obj, None) + + if self.scalar: + target = getattr(obj, self.target_collection) + if target is not None: + delattr(target, self.value_attr) + delattr(obj, self.target_collection) + + def _new( + self, lazy_collection: _LazyCollectionProtocol[_T] + ) -> Tuple[Type[Any], _T]: + creator = ( + self.parent.creator + if self.parent.creator is not None + else cast("_CreatorProtocol", self.target_class) + ) + collection_class = util.duck_type_collection(lazy_collection()) + + if collection_class is None: + raise exc.InvalidRequestError( + f"lazy collection factory did not return a " + f"valid collection type, got {collection_class}" + ) + if self.parent.proxy_factory: + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory(collection_class, self) + else: + getter, setter = self.parent._default_getset(collection_class) + + if collection_class is list: + return ( + collection_class, + cast( + _T, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ), + ) + elif collection_class is dict: + return ( + collection_class, + cast( + _T, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ), + ) + elif collection_class is set: + return ( + collection_class, + cast( + _T, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ), + ) + else: + raise exc.ArgumentError( + "could not guess which interface to use for " + 'collection_class "%s" backing "%s"; specify a ' + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class, self.target_collection) + ) + + def _set( + self, proxy: _AssociationCollection[Any], values: Iterable[Any] + ) -> None: + if self.parent.proxy_bulk_set: + self.parent.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + cast("_AssociationList[Any]", proxy).extend(values) + elif self.collection_class is dict: + cast("_AssociationDict[Any, Any]", proxy).update(values) + elif self.collection_class is set: + cast("_AssociationSet[Any]", proxy).update(values) + else: + raise exc.ArgumentError( + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) + + def _inflate(self, proxy: _AssociationCollection[Any]) -> None: + creator = ( + self.parent.creator + and self.parent.creator + or cast(_CreatorProtocol, self.target_class) + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory( + self.collection_class, self + ) + else: + getter, setter = self.parent._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + is_has = kwargs.pop("is_has", None) + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + inner = target_assoc._criterion_exists( + criterion=criterion, **kwargs + ) + return self._comparator._criterion_exists(inner) + + if self._target_is_object: + attr = getattr(self.target_class, self.value_attr) + value_expr = attr.comparator._criterion_exists(criterion, **kwargs) + else: + if kwargs: + raise exc.ArgumentError( + "Can't apply keyword arguments to column-targeted " + "association proxy; use ==" + ) + elif is_has and criterion is not None: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==" + ) + + value_expr = criterion + + return self._comparator._criterion_exists(value_expr) + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce a proxied 'any' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + self.scalar + and (not self._target_is_object or self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'any()' not implemented for scalar attributes. Use has()." + ) + return self._criterion_exists( + criterion=criterion, is_has=False, **kwargs + ) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce a proxied 'has' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + not self.scalar + or (self._target_is_object and not self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'has()' not implemented for collections. Use any()." + ) + return self._criterion_exists( + criterion=criterion, is_has=True, **kwargs + ) + + def __repr__(self) -> str: + return "%s(%r)" % (self.__class__.__name__, self.parent) + + +class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` where we cannot determine + the type of target object. + """ + + _is_canonical = False + + def _ambiguous(self) -> NoReturn: + raise AttributeError( + "Association proxy %s.%s refers to an attribute '%s' that is not " + "directly mapped on class %s; therefore this operation cannot " + "proceed since we don't know what type of object is referred " + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) + + def get(self, obj: Any) -> Any: + if obj is None: + return self + else: + return super().get(obj) + + def __eq__(self, obj: object) -> NoReturn: + self._ambiguous() + + def __ne__(self, obj: object) -> NoReturn: + self._ambiguous() + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> NoReturn: + self._ambiguous() + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> NoReturn: + self._ambiguous() + + @util.memoized_property + def _lookup_cache(self) -> Dict[Type[Any], AssociationProxyInstance[_T]]: + # mapping of <subclass>->AssociationProxyInstance. + # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist; + # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2 + return {} + + def _non_canonical_get_for_object( + self, parent_instance: Any + ) -> AssociationProxyInstance[_T]: + if parent_instance is not None: + actual_obj = getattr(parent_instance, self.target_collection) + if actual_obj is not None: + try: + insp = inspect(actual_obj) + except exc.NoInspectionAvailable: + pass + else: + mapper = insp.mapper + instance_class = mapper.class_ + if instance_class not in self._lookup_cache: + self._populate_cache(instance_class, mapper) + + try: + return self._lookup_cache[instance_class] + except KeyError: + pass + + # no object or ambiguous object given, so return "self", which + # is a proxy with generally only instance-level functionality + return self + + def _populate_cache( + self, instance_class: Any, mapper: Mapper[Any] + ) -> None: + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + if mapper.isa(prop.mapper): + target_class = instance_class + try: + target_assoc = self._cls_unwrap_target_assoc_proxy( + target_class, self.value_attr + ) + except AttributeError: + pass + else: + self._lookup_cache[instance_class] = self._construct_for_assoc( + cast("AssociationProxyInstance[_T]", target_assoc), + self.parent, + self.owning_class, + target_class, + self.value_attr, + ) + + +class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` that has an object as a target.""" + + _target_is_object: bool = True + _is_canonical = True + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: + """Produce a proxied 'contains' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any`, + :meth:`.Relationship.Comparator.has`, + and/or :meth:`.Relationship.Comparator.contains` + operators of the underlying proxied attributes. + """ + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + return self._comparator._criterion_exists( + target_assoc.contains(other) + if not target_assoc.scalar + else target_assoc == other + ) + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(other) + ) + elif self._target_is_object and self.scalar and self._value_is_scalar: + raise exc.InvalidRequestError( + "contains() doesn't apply to a scalar object endpoint; use ==" + ) + else: + return self._comparator._criterion_exists( + **{self.value_attr: other} + ) + + def __eq__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None, + ) + else: + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj + ) + + +class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` that has a database column as a + target. + """ + + _target_is_object: bool = False + _is_canonical = True + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # special case "is None" to check for no related row as well + expr = self._criterion_exists( + self.remote_attr.operate(operators.eq, other) + ) + if other is None: + return or_(expr, self._comparator == None) + else: + return expr + + def operate( + self, op: operators.OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return self._criterion_exists( + self.remote_attr.operate(op, *other, **kwargs) + ) + + +class _lazy_collection(_LazyCollectionProtocol[_T]): + def __init__(self, obj: Any, target: str): + self.parent = obj + self.target = target + + def __call__( + self, + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + return getattr(self.parent, self.target) # type: ignore[no-any-return] + + def __getstate__(self) -> Any: + return {"obj": self.parent, "target": self.target} + + def __setstate__(self, state: Any) -> None: + self.parent = state["obj"] + self.target = state["target"] + + +_IT = TypeVar("_IT", bound="Any") +"""instance type - this is the type of object inside a collection. + +this is not the same as the _T of AssociationProxy and +AssociationProxyInstance itself, which will often refer to the +collection[_IT] type. + +""" + + +class _AssociationCollection(Generic[_IT]): + getter: _GetterProtocol[_IT] + """A function. Given an associated object, return the 'value'.""" + + creator: _CreatorProtocol + """ + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: + + obj = creator(somevalue) + assert getter(obj) == somevalue + """ + + parent: AssociationProxyInstance[_IT] + setter: _SetterProtocol + """A function. Given an associated object and a value, store that + value on the object. + """ + + lazy_collection: _LazyCollectionProtocol[_IT] + """A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship())""" + + def __init__( + self, + lazy_collection: _LazyCollectionProtocol[_IT], + creator: _CreatorProtocol, + getter: _GetterProtocol[_IT], + setter: _SetterProtocol, + parent: AssociationProxyInstance[_IT], + ): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. + + """ + self.lazy_collection = lazy_collection + self.creator = creator + self.getter = getter + self.setter = setter + self.parent = parent + + if typing.TYPE_CHECKING: + col: Collection[_IT] + else: + col = property(lambda self: self.lazy_collection()) + + def __len__(self) -> int: + return len(self.col) + + def __bool__(self) -> bool: + return bool(self.col) + + def __getstate__(self) -> Any: + return {"parent": self.parent, "lazy_collection": self.lazy_collection} + + def __setstate__(self, state: Any) -> None: + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] + self.parent._inflate(self) + + def clear(self) -> None: + raise NotImplementedError() + + +class _AssociationSingleItem(_AssociationCollection[_T]): + setter: _PlainSetterProtocol[_T] + creator: _PlainCreatorProtocol[_T] + + def _create(self, value: _T) -> Any: + return self.creator(value) + + def _get(self, object_: Any) -> _T: + return self.getter(object_) + + def _bulk_replace( + self, assoc_proxy: AssociationProxyInstance[Any], values: Iterable[_IT] + ) -> None: + self.clear() + assoc_proxy._set(self, values) + + +class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): + """Generic, converting, list-to-list proxy.""" + + col: MutableSequence[_T] + + def _set(self, object_: Any, value: _T) -> None: + self.setter(object_, value) + + @overload + def __getitem__(self, index: int) -> _T: ... + + @overload + def __getitem__(self, index: slice) -> MutableSequence[_T]: ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[_T, MutableSequence[_T]]: + if not isinstance(index, slice): + return self._get(self.col[index]) + else: + return [self._get(member) for member in self.col[index]] + + @overload + def __setitem__(self, index: int, value: _T) -> None: ... + + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + + def __setitem__( + self, index: Union[int, slice], value: Union[_T, Iterable[_T]] + ) -> None: + if not isinstance(index, slice): + self._set(self.col[index], cast("_T", value)) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + start = index.start or 0 + rng = list(range(index.start or 0, stop, step)) + + sized_value = list(value) + + if step == 1: + for i in rng: + del self[start] + i = start + for item in sized_value: + self.insert(i, item) + i += 1 + else: + if len(sized_value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" + % (len(sized_value), len(rng)) + ) + for i, item in zip(rng, value): + self._set(self.col[i], item) + + @overload + def __delitem__(self, index: int) -> None: ... + + @overload + def __delitem__(self, index: slice) -> None: ... + + def __delitem__(self, index: Union[slice, int]) -> None: + del self.col[index] + + def __contains__(self, value: object) -> bool: + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __iter__(self) -> Iterator[_T]: + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + + for member in self.col: + yield self._get(member) + return + + def append(self, value: _T) -> None: + col = self.col + item = self._create(value) + col.append(item) + + def count(self, value: Any) -> int: + count = 0 + for v in self: + if v == value: + count += 1 + return count + + def extend(self, values: Iterable[_T]) -> None: + for v in values: + self.append(v) + + def insert(self, index: int, value: _T) -> None: + self.col[index:index] = [self._create(value)] + + def pop(self, index: int = -1) -> _T: + return self.getter(self.col.pop(index)) + + def remove(self, value: _T) -> None: + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self) -> NoReturn: + """Not supported, use reversed(mylist)""" + + raise NotImplementedError() + + def sort(self) -> NoReturn: + """Not supported, use sorted(mylist)""" + + raise NotImplementedError() + + def clear(self) -> None: + del self.col[0 : len(self.col)] + + def __eq__(self, other: object) -> bool: + return list(self) == other + + def __ne__(self, other: object) -> bool: + return list(self) != other + + def __lt__(self, other: List[_T]) -> bool: + return list(self) < other + + def __le__(self, other: List[_T]) -> bool: + return list(self) <= other + + def __gt__(self, other: List[_T]) -> bool: + return list(self) > other + + def __ge__(self, other: List[_T]) -> bool: + return list(self) >= other + + def __add__(self, other: List[_T]) -> List[_T]: + try: + other = list(other) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, other: List[_T]) -> List[_T]: + try: + other = list(other) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n: SupportsIndex) -> List[_T]: + if not isinstance(n, int): + return NotImplemented + return list(self) * n + + def __rmul__(self, n: SupportsIndex) -> List[_T]: + if not isinstance(n, int): + return NotImplemented + return n * list(self) + + def __iadd__(self, iterable: Iterable[_T]) -> Self: + self.extend(iterable) + return self + + def __imul__(self, n: SupportsIndex) -> Self: + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. + if not isinstance(n, int): + raise NotImplementedError() + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + + if typing.TYPE_CHECKING: + # TODO: no idea how to do this without separate "stub" + def index( + self, value: Any, start: int = ..., stop: int = ... + ) -> int: ... + + else: + + def index(self, value: Any, *arg) -> int: + ls = list(self) + return ls.index(value, *arg) + + def copy(self) -> List[_T]: + return list(self) + + def __repr__(self) -> str: + return repr(list(self)) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): + """Generic, converting, dict-to-dict proxy.""" + + setter: _DictSetterProtocol[_VT] + creator: _KeyCreatorProtocol[_VT] + col: MutableMapping[_KT, Optional[_VT]] + + def _create(self, key: _KT, value: Optional[_VT]) -> Any: + return self.creator(key, value) + + def _get(self, object_: Any) -> _VT: + return self.getter(object_) + + def _set(self, object_: Any, key: _KT, value: _VT) -> None: + return self.setter(object_, key, value) + + def __getitem__(self, key: _KT) -> _VT: + return self._get(self.col[key]) + + def __setitem__(self, key: _KT, value: _VT) -> None: + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key: _KT) -> None: + del self.col[key] + + def __contains__(self, key: object) -> bool: + return key in self.col + + def __iter__(self) -> Iterator[_KT]: + return iter(self.col.keys()) + + def clear(self) -> None: + self.col.clear() + + def __eq__(self, other: object) -> bool: + return dict(self) == other + + def __ne__(self, other: object) -> bool: + return dict(self) != other + + def __repr__(self) -> str: + return repr(dict(self)) + + @overload + def get(self, __key: _KT) -> Optional[_VT]: ... + + @overload + def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... + + def get( + self, key: _KT, default: Optional[Union[_VT, _T]] = None + ) -> Union[_VT, _T, None]: + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key: _KT, default: Optional[_VT] = None) -> _VT: + # TODO: again, no idea how to create an actual MutableMapping. + # default must allow None, return type can't include None, + # the stub explicitly allows for default of None with a cryptic message + # "This overload should be allowed only if the value type is + # compatible with None.". + if key not in self.col: + self.col[key] = self._create(key, default) + return default # type: ignore + else: + return self[key] + + def keys(self) -> KeysView[_KT]: + return self.col.keys() + + def items(self) -> ItemsView[_KT, _VT]: + return ItemsView(self) + + def values(self) -> ValuesView[_VT]: + return ValuesView(self) + + @overload + def pop(self, __key: _KT) -> _VT: ... + + @overload + def pop( + self, __key: _KT, default: Union[_VT, _T] = ... + ) -> Union[_VT, _T]: ... + + def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]: + member = self.col.pop(__key, *arg, **kw) + return self._get(member) + + def popitem(self) -> Tuple[_KT, _VT]: + item = self.col.popitem() + return (item[0], self._get(item[1])) + + @overload + def update( + self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT + ) -> None: ... + + @overload + def update( + self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + + @overload + def update(self, **kwargs: _VT) -> None: ... + + def update(self, *a: Any, **kw: Any) -> None: + up: Dict[_KT, _VT] = {} + up.update(*a, **kw) + + for key, value in up.items(): + self[key] = value + + def _bulk_replace( + self, + assoc_proxy: AssociationProxyInstance[Any], + values: Mapping[_KT, _VT], + ) -> None: + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + for key, member in values.items() or (): + if key in additions: + self[key] = member + elif key in constants: + self[key] = member + + for key in removals: + del self[key] + + def copy(self) -> Dict[_KT, _VT]: + return dict(self.items()) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func + + +class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): + """Generic, converting, set-to-set proxy.""" + + col: MutableSet[_T] + + def __len__(self) -> int: + return len(self.col) + + def __bool__(self) -> bool: + if self.col: + return True + else: + return False + + def __contains__(self, __o: object) -> bool: + for member in self.col: + if self._get(member) == __o: + return True + return False + + def __iter__(self) -> Iterator[_T]: + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or just use + the underlying collection directly from its property on the parent. + + """ + for member in self.col: + yield self._get(member) + return + + def add(self, __element: _T) -> None: + if __element not in self: + self.col.add(self._create(__element)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, __element: _T) -> None: + for member in self.col: + if self._get(member) == __element: + self.col.discard(member) + break + + def remove(self, __element: _T) -> None: + for member in self.col: + if self._get(member) == __element: + self.col.discard(member) + return + raise KeyError(__element) + + def pop(self) -> _T: + if not self.col: + raise KeyError("pop from an empty set") + member = self.col.pop() + return self._get(member) + + def update(self, *s: Iterable[_T]) -> None: + for iterable in s: + for value in iterable: + self.add(value) + + def _bulk_replace(self, assoc_proxy: Any, values: Iterable[_T]) -> None: + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + appender = self.add + remover = self.remove + + for member in values or (): + if member in additions: + appender(member) + elif member in constants: + appender(member) + + for member in removals: + remover(member) + + def __ior__( # type: ignore + self, other: AbstractSet[_S] + ) -> MutableSet[Union[_T, _S]]: + if not collections._set_binops_check_strict(self, other): + raise NotImplementedError() + for value in other: + self.add(value) + return self + + def _set(self) -> Set[_T]: + return set(iter(self)) + + def union(self, *s: Iterable[_S]) -> MutableSet[Union[_T, _S]]: + return set(self).union(*s) + + def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.union(__s) + + def difference(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).difference(*s) + + def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.difference(s) + + def difference_update(self, *s: Iterable[Any]) -> None: + for other in s: + for value in other: + self.discard(value) + + def __isub__(self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + for value in s: + self.discard(value) + return self + + def intersection(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).intersection(*s) + + def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.intersection(s) + + def intersection_update(self, *s: Iterable[Any]) -> None: + for other in s: + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __iand__(self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + want = self.intersection(s) + have: Set[_T] = set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def symmetric_difference(self, __s: Iterable[_T]) -> MutableSet[_T]: + return set(self).symmetric_difference(__s) + + def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.symmetric_difference(s) + + def symmetric_difference_update(self, other: Iterable[Any]) -> None: + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: # type: ignore # noqa: E501 + if not collections._set_binops_check_strict(self, other): + raise NotImplementedError() + + self.symmetric_difference_update(other) + return self + + def issubset(self, __s: Iterable[Any]) -> bool: + return set(self).issubset(__s) + + def issuperset(self, __s: Iterable[Any]) -> bool: + return set(self).issuperset(__s) + + def clear(self) -> None: + self.col.clear() + + def copy(self) -> AbstractSet[_T]: + return set(self) + + def __eq__(self, other: object) -> bool: + return set(self) == other + + def __ne__(self, other: object) -> bool: + return set(self) != other + + def __lt__(self, other: AbstractSet[Any]) -> bool: + return set(self) < other + + def __le__(self, other: AbstractSet[Any]) -> bool: + return set(self) <= other + + def __gt__(self, other: AbstractSet[Any]) -> bool: + return set(self) > other + + def __ge__(self, other: AbstractSet[Any]) -> bool: + return set(self) >= other + + def __repr__(self) -> str: + return repr(set(self)) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py new file mode 100644 index 0000000..78c707b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py @@ -0,0 +1,25 @@ +# ext/asyncio/__init__.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .engine import async_engine_from_config as async_engine_from_config +from .engine import AsyncConnection as AsyncConnection +from .engine import AsyncEngine as AsyncEngine +from .engine import AsyncTransaction as AsyncTransaction +from .engine import create_async_engine as create_async_engine +from .engine import create_async_pool_from_url as create_async_pool_from_url +from .result import AsyncMappingResult as AsyncMappingResult +from .result import AsyncResult as AsyncResult +from .result import AsyncScalarResult as AsyncScalarResult +from .result import AsyncTupleResult as AsyncTupleResult +from .scoping import async_scoped_session as async_scoped_session +from .session import async_object_session as async_object_session +from .session import async_session as async_session +from .session import async_sessionmaker as async_sessionmaker +from .session import AsyncAttrs as AsyncAttrs +from .session import AsyncSession as AsyncSession +from .session import AsyncSessionTransaction as AsyncSessionTransaction +from .session import close_all_sessions as close_all_sessions diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a647d42 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..785ef03 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4326d1c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5a71fac --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c6ae583 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8839d42 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0c267a0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py new file mode 100644 index 0000000..9899364 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py @@ -0,0 +1,279 @@ +# ext/asyncio/base.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import abc +import functools +from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable +from typing import ClassVar +from typing import Dict +from typing import Generator +from typing import Generic +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TypeVar +import weakref + +from . import exc as async_exc +from ... import util +from ...util.typing import Literal +from ...util.typing import Self + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + + +_PT = TypeVar("_PT", bound=Any) + + +class ReversibleProxy(Generic[_PT]): + _proxy_objects: ClassVar[ + Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]] + ] = {} + __slots__ = ("__weakref__",) + + @overload + def _assign_proxied(self, target: _PT) -> _PT: ... + + @overload + def _assign_proxied(self, target: None) -> None: ... + + def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: + if target is not None: + target_ref: weakref.ref[_PT] = weakref.ref( + target, ReversibleProxy._target_gced + ) + proxy_ref = weakref.ref( + self, + functools.partial(ReversibleProxy._target_gced, target_ref), + ) + ReversibleProxy._proxy_objects[target_ref] = proxy_ref + + return target + + @classmethod + def _target_gced( + cls, + ref: weakref.ref[_PT], + proxy_ref: Optional[weakref.ref[Self]] = None, # noqa: U100 + ) -> None: + cls._proxy_objects.pop(ref, None) + + @classmethod + def _regenerate_proxy_for_target(cls, target: _PT) -> Self: + raise NotImplementedError() + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls, + target: _PT, + regenerate: Literal[True] = ..., + ) -> Self: ... + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: ... + + @classmethod + def _retrieve_proxy_for_target( + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: + try: + proxy_ref = cls._proxy_objects[weakref.ref(target)] + except KeyError: + pass + else: + proxy = proxy_ref() + if proxy is not None: + return proxy # type: ignore + + if regenerate: + return cls._regenerate_proxy_for_target(target) + else: + return None + + +class StartableContext(Awaitable[_T_co], abc.ABC): + __slots__ = () + + @abc.abstractmethod + async def start(self, is_ctxmanager: bool = False) -> _T_co: + raise NotImplementedError() + + def __await__(self) -> Generator[Any, Any, _T_co]: + return self.start().__await__() + + async def __aenter__(self) -> _T_co: + return await self.start(is_ctxmanager=True) + + @abc.abstractmethod + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> Optional[bool]: + pass + + def _raise_for_not_started(self) -> NoReturn: + raise async_exc.AsyncContextNotStarted( + "%s context has not been started and object has not been awaited." + % (self.__class__.__name__) + ) + + +class GeneratorStartableContext(StartableContext[_T_co]): + __slots__ = ("gen",) + + gen: AsyncGenerator[_T_co, Any] + + def __init__( + self, + func: Callable[..., AsyncIterator[_T_co]], + args: Tuple[Any, ...], + kwds: Dict[str, Any], + ): + self.gen = func(*args, **kwds) # type: ignore + + async def start(self, is_ctxmanager: bool = False) -> _T_co: + try: + start_value = await util.anext_(self.gen) + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + # if not a context manager, then interrupt the generator, don't + # let it complete. this step is technically not needed, as the + # generator will close in any case at gc time. not clear if having + # this here is a good idea or not (though it helps for clarity IMO) + if not is_ctxmanager: + await self.gen.aclose() + + return start_value + + async def __aexit__( + self, typ: Any, value: Any, traceback: Any + ) -> Optional[bool]: + # vendored from contextlib.py + if typ is None: + try: + await util.anext_(self.gen) + except StopAsyncIteration: + return False + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = typ() + try: + await self.gen.athrow(value) + except StopAsyncIteration as exc: + # Suppress StopIteration *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed. + return exc is not value + except RuntimeError as exc: + # Don't re-raise the passed in exception. (issue27122) + if exc is value: + return False + # Avoid suppressing if a Stop(Async)Iteration exception + # was passed to athrow() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception + # wrapped + # by the RuntimeError is actully Stop(Async)Iteration (see + # issue29692). + if ( + isinstance(value, (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value + ): + return False + raise + except BaseException as exc: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + if exc is not value: + raise + return False + raise RuntimeError("generator didn't stop after athrow()") + + +def asyncstartablecontext( + func: Callable[..., AsyncIterator[_T_co]] +) -> Callable[..., GeneratorStartableContext[_T_co]]: + """@asyncstartablecontext decorator. + + the decorated function can be called either as ``async with fn()``, **or** + ``await fn()``. This is decidedly different from what + ``@contextlib.asynccontextmanager`` supports, and the usage pattern + is different as well. + + Typical usage:: + + @asyncstartablecontext + async def some_async_generator(<arguments>): + <setup> + try: + yield <value> + except GeneratorExit: + # return value was awaited, no context manager is present + # and caller will .close() the resource explicitly + pass + else: + <context manager cleanup> + + + Above, ``GeneratorExit`` is caught if the function were used as an + ``await``. In this case, it's essential that the cleanup does **not** + occur, so there should not be a ``finally`` block. + + If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__`` + and we were invoked as a context manager, and cleanup should proceed. + + + """ + + @functools.wraps(func) + def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]: + return GeneratorStartableContext(func, args, kwds) + + return helper + + +class ProxyComparable(ReversibleProxy[_PT]): + __slots__ = () + + @util.ro_non_memoized_property + def _proxied(self) -> _PT: + raise NotImplementedError() + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, self.__class__) + and self._proxied == other._proxied + ) + + def __ne__(self, other: Any) -> bool: + return ( + not isinstance(other, self.__class__) + or self._proxied != other._proxied + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py new file mode 100644 index 0000000..8fc8e96 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py @@ -0,0 +1,1466 @@ +# ext/asyncio/engine.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import asyncio +import contextlib +from typing import Any +from typing import AsyncIterator +from typing import Callable +from typing import Dict +from typing import Generator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc as async_exc +from .base import asyncstartablecontext +from .base import GeneratorStartableContext +from .base import ProxyComparable +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult +from ... import exc +from ... import inspection +from ... import util +from ...engine import Connection +from ...engine import create_engine as _create_engine +from ...engine import create_pool_from_url as _create_pool_from_url +from ...engine import Engine +from ...engine.base import NestedTransaction +from ...engine.base import Transaction +from ...exc import ArgumentError +from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec + +if TYPE_CHECKING: + from ...engine.cursor import CursorResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _DBAPIAnyExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import CompiledCacheType + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...engine.interfaces import Dialect + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import SchemaTranslateMapType + from ...engine.result import ScalarResult + from ...engine.url import URL + from ...pool import Pool + from ...pool import PoolProxiedConnection + from ...sql._typing import _InfoType + from ...sql.base import Executable + from ...sql.selectable import TypedReturnsRows + +_P = ParamSpec("_P") +_T = TypeVar("_T", bound=Any) + + +def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_engine` are mostly + identical to those passed to the :func:`_sa.create_engine` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 1.4 + + :param async_creator: an async callable which returns a driver-level + asyncio connection. If given, the function should take no arguments, + and return a new asyncio connection from the underlying asyncio + database driver; the connection will be wrapped in the appropriate + structures to be used with the :class:`.AsyncEngine`. Note that the + parameters specified in the URL are not applied here, and the creator + function should use its own connection parameters. + + This parameter is the asyncio equivalent of the + :paramref:`_sa.create_engine.creator` parameter of the + :func:`_sa.create_engine` function. + + .. versionadded:: 2.0.16 + + """ + + if kw.get("server_side_cursors", False): + raise async_exc.AsyncMethodRequired( + "Can't set server_side_cursors for async engine globally; " + "use the connection.stream() method for an async " + "streaming result set" + ) + kw["_is_async"] = True + async_creator = kw.pop("async_creator", None) + if async_creator: + if kw.get("creator", None): + raise ArgumentError( + "Can only specify one of 'async_creator' or 'creator', " + "not both." + ) + + def creator() -> Any: + # note that to send adapted arguments like + # prepared_statement_cache_size, user would use + # "creator" and emulate this form here + return sync_engine.dialect.dbapi.connect( # type: ignore + async_creator_fn=async_creator + ) + + kw["creator"] = creator + sync_engine = _create_engine(url, **kw) + return AsyncEngine(sync_engine) + + +def async_engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> AsyncEngine: + """Create a new AsyncEngine instance using a configuration dictionary. + + This function is analogous to the :func:`_sa.engine_from_config` function + in SQLAlchemy Core, except that the requested dialect must be an + asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`. + The argument signature of the function is identical to that + of :func:`_sa.engine_from_config`. + + .. versionadded:: 1.4.29 + + """ + options = { + key[len(prefix) :]: value + for key, value in configuration.items() + if key.startswith(prefix) + } + options["_coerce_config"] = True + options.update(kwargs) + url = options.pop("url") + return create_async_engine(url, **options) + + +def create_async_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_pool_from_url` are mostly + identical to those passed to the :func:`_sa.create_pool_from_url` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 2.0.10 + + """ + kwargs["_is_async"] = True + return _create_pool_from_url(url, **kwargs) + + +class AsyncConnectable: + __slots__ = "_slots_dispatch", "__weakref__" + + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + + +@util.create_proxy_methods( + Connection, + ":class:`_engine.Connection`", + ":class:`_asyncio.AsyncConnection`", + classmethods=[], + methods=[], + attributes=[ + "closed", + "invalidated", + "dialect", + "default_isolation_level", + ], +) +class AsyncConnection( + ProxyComparable[Connection], + StartableContext["AsyncConnection"], + AsyncConnectable, +): + """An asyncio proxy for a :class:`_engine.Connection`. + + :class:`_asyncio.AsyncConnection` is acquired using the + :meth:`_asyncio.AsyncEngine.connect` + method of :class:`_asyncio.AsyncEngine`:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + async with engine.connect() as conn: + result = await conn.execute(select(table)) + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncConnection is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncConnection that matches this one given only the + # "sync" elements. + __slots__ = ( + "engine", + "sync_engine", + "sync_connection", + ) + + def __init__( + self, + async_engine: AsyncEngine, + sync_connection: Optional[Connection] = None, + ): + self.engine = async_engine + self.sync_engine = async_engine.sync_engine + self.sync_connection = self._assign_proxied(sync_connection) + + sync_connection: Optional[Connection] + """Reference to the sync-style :class:`_engine.Connection` this + :class:`_asyncio.AsyncConnection` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncConnection` is associated with via its underlying + :class:`_engine.Connection`. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + @classmethod + def _regenerate_proxy_for_target( + cls, target: Connection + ) -> AsyncConnection: + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) + + async def start( + self, is_ctxmanager: bool = False # noqa: U100 + ) -> AsyncConnection: + """Start this :class:`_asyncio.AsyncConnection` object's context + outside of using a Python ``with:`` block. + + """ + if self.sync_connection: + raise exc.InvalidRequestError("connection is already started") + self.sync_connection = self._assign_proxied( + await greenlet_spawn(self.sync_engine.connect) + ) + return self + + @property + def connection(self) -> NoReturn: + """Not implemented for async; call + :meth:`_asyncio.AsyncConnection.get_raw_connection`. + """ + raise exc.InvalidRequestError( + "AsyncConnection.connection accessor is not implemented as the " + "attribute may need to reconnect on an invalidated connection. " + "Use the get_raw_connection() method." + ) + + async def get_raw_connection(self) -> PoolProxiedConnection: + """Return the pooled DBAPI-level connection in use by this + :class:`_asyncio.AsyncConnection`. + + This is a SQLAlchemy connection-pool proxied connection + which then has the attribute + :attr:`_pool._ConnectionFairy.driver_connection` that refers to the + actual driver connection. Its + :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead + to an :class:`_engine.AdaptedConnection` instance that + adapts the driver connection to the DBAPI protocol. + + """ + + return await greenlet_spawn(getattr, self._proxied, "connection") + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + """Return the :attr:`_engine.Connection.info` dictionary of the + underlying :class:`_engine.Connection`. + + This dictionary is freely writable for user-defined state to be + associated with the database connection. + + This attribute is only available if the :class:`.AsyncConnection` is + currently connected. If the :attr:`.AsyncConnection.closed` attribute + is ``True``, then accessing this attribute will raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.0b2 + + """ + return self._proxied.info + + @util.ro_non_memoized_property + def _proxied(self) -> Connection: + if not self.sync_connection: + self._raise_for_not_started() + return self.sync_connection + + def begin(self) -> AsyncTransaction: + """Begin a transaction prior to autobegin occurring.""" + assert self._proxied + return AsyncTransaction(self) + + def begin_nested(self) -> AsyncTransaction: + """Begin a nested transaction and return a transaction handle.""" + assert self._proxied + return AsyncTransaction(self, nested=True) + + async def invalidate( + self, exception: Optional[BaseException] = None + ) -> None: + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + See the method :meth:`_engine.Connection.invalidate` for full + detail on this method. + + """ + + return await greenlet_spawn( + self._proxied.invalidate, exception=exception + ) + + async def get_isolation_level(self) -> IsolationLevel: + return await greenlet_spawn(self._proxied.get_isolation_level) + + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + + return self._proxied.in_transaction() + + def in_nested_transaction(self) -> bool: + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + return self._proxied.in_nested_transaction() + + def get_transaction(self) -> Optional[AsyncTransaction]: + """Return an :class:`.AsyncTransaction` representing the current + transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_transaction` method to get the current + :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + + trans = self._proxied.get_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self) -> Optional[AsyncTransaction]: + """Return an :class:`.AsyncTransaction` representing the current + nested (savepoint) transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_nested_transaction` method to get the + current :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + + trans = self._proxied.get_nested_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + @overload + async def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, + **opt: Any, + ) -> AsyncConnection: ... + + @overload + async def execution_options(self, **opt: Any) -> AsyncConnection: ... + + async def execution_options(self, **opt: Any) -> AsyncConnection: + r"""Set non-SQL options for the connection which take effect + during execution. + + This returns this :class:`_asyncio.AsyncConnection` object with + the new options added. + + See :meth:`_engine.Connection.execution_options` for full details + on this method. + + """ + + conn = self._proxied + c2 = await greenlet_spawn(conn.execution_options, **opt) + assert c2 is conn + return self + + async def commit(self) -> None: + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + """ + await greenlet_spawn(self._proxied.commit) + + async def rollback(self) -> None: + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + + """ + await greenlet_spawn(self._proxied.rollback) + + async def close(self) -> None: + """Close this :class:`_asyncio.AsyncConnection`. + + This has the effect of also rolling back the transaction if one + is in place. + + """ + await greenlet_spawn(self._proxied.close) + + async def aclose(self) -> None: + """A synonym for :meth:`_asyncio.AsyncConnection.close`. + + The :meth:`_asyncio.AsyncConnection.aclose` name is specifically + to support the Python standard library ``@contextlib.aclosing`` + context manager function. + + .. versionadded:: 2.0.20 + + """ + await self.close() + + async def exec_driver_sql( + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a driver-level SQL string and return buffered + :class:`_engine.Result`. + + """ + + result = await greenlet_spawn( + self._proxied.exec_driver_sql, + statement, + parameters, + execution_options, + _require_await=True, + ) + + return await _ensure_sync_result(result, self.exec_driver_sql) + + @overload + def stream( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncResult[_T]]: ... + + @overload + def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncResult[Any]]: ... + + @asyncstartablecontext + async def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> AsyncIterator[AsyncResult[Any]]: + """Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncResult` object. + + E.g.:: + + result = await conn.stream(stmt): + async for row in result: + print(f"{row}") + + The :meth:`.AsyncConnection.stream` + method supports optional context manager use against the + :class:`.AsyncResult` object, as in:: + + async with conn.stream(stmt) as result: + async for row in result: + print(f"{row}") + + In the above pattern, the :meth:`.AsyncResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncResult` object. + + .. seealso:: + + :meth:`.AsyncConnection.stream_scalars` + + """ + if not self.dialect.supports_server_side_cursors: + raise exc.InvalidRequestError( + "Cant use `stream` or `stream_scalars` with the current " + "dialect since it does not support server side cursors." + ) + + result = await greenlet_spawn( + self._proxied.execute, + statement, + parameters, + execution_options=util.EMPTY_DICT.merge_with( + execution_options, {"stream_results": True} + ), + _require_await=True, + ) + assert result.context._is_server_side + ar = AsyncResult(result) + try: + yield ar + except GeneratorExit: + pass + else: + task = asyncio.create_task(ar.close()) + await asyncio.shield(task) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: ... + + @overload + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: ... + + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a SQL statement construct and return a buffered + :class:`_engine.Result`. + + :param object: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.ExecutableDDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + result = await greenlet_spawn( + self._proxied.execute, + statement, + parameters, + execution_options=execution_options, + _require_await=True, + ) + return await _ensure_sync_result(result, self.execute) + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[_T]: ... + + @overload + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: ... + + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + result = await self.execute( + statement, parameters, execution_options=execution_options + ) + return result.scalar() + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: ... + + @overload + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: ... + + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + r"""Executes a SQL statement construct and returns a scalar objects. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalars` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a :class:`_engine.ScalarResult` object. + + .. versionadded:: 1.4.24 + + """ + result = await self.execute( + statement, parameters, execution_options=execution_options + ) + return result.scalars() + + @overload + def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... + + @overload + def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... + + @asyncstartablecontext + async def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> AsyncIterator[AsyncScalarResult[Any]]: + r"""Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncScalarResult` object. + + E.g.:: + + result = await conn.stream_scalars(stmt) + async for scalar in result: + print(f"{scalar}") + + This method is shorthand for invoking the + :meth:`_engine.AsyncResult.scalars` method after invoking the + :meth:`_engine.Connection.stream` method. Parameters are equivalent. + + The :meth:`.AsyncConnection.stream_scalars` + method supports optional context manager use against the + :class:`.AsyncScalarResult` object, as in:: + + async with conn.stream_scalars(stmt) as result: + async for scalar in result: + print(f"{scalar}") + + In the above pattern, the :meth:`.AsyncScalarResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncScalarResult` object. + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`.AsyncConnection.stream` + + """ + + async with self.stream( + statement, parameters, execution_options=execution_options + ) as result: + yield result.scalars() + + async def run_sync( + self, + fn: Callable[Concatenate[Connection, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, + ) -> _T: + """Invoke the given synchronous (i.e. not async) callable, + passing a synchronous-style :class:`_engine.Connection` as the first + argument. + + This method allows traditional synchronous SQLAlchemy functions to + run within the context of an asyncio application. + + E.g.:: + + def do_something_with_core(conn: Connection, arg1: int, arg2: str) -> str: + '''A synchronous function that does not require awaiting + + :param conn: a Core SQLAlchemy Connection, used synchronously + + :return: an optional return value is supported + + ''' + conn.execute( + some_table.insert().values(int_col=arg1, str_col=arg2) + ) + return "success" + + + async def do_something_async(async_engine: AsyncEngine) -> None: + '''an async function that uses awaiting''' + + async with async_engine.begin() as async_conn: + # run do_something_with_core() with a sync-style + # Connection, proxied into an awaitable + return_code = await async_conn.run_sync(do_something_with_core, 5, "strval") + print(return_code) + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + The most rudimentary use of :meth:`.AsyncConnection.run_sync` is to + invoke methods such as :meth:`_schema.MetaData.create_all`, given + an :class:`.AsyncConnection` that needs to be provided to + :meth:`_schema.MetaData.create_all` as a :class:`_engine.Connection` + object:: + + # run metadata.create_all(conn) with a sync-style Connection, + # proxied into an awaitable + with async_engine.begin() as conn: + await conn.run_sync(metadata.create_all) + + .. note:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :meth:`.AsyncSession.run_sync` + + :ref:`session_run_sync` + + """ # noqa: E501 + + return await greenlet_spawn( + fn, self._proxied, *arg, _require_await=False, **kw + ) + + def __await__(self) -> Generator[Any, None, AsyncConnection]: + return self.start().__await__() + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + task = asyncio.create_task(self.close()) + await asyncio.shield(task) + + # START PROXY METHODS AsyncConnection + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + @property + def closed(self) -> Any: + r"""Return True if this connection is closed. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + """ # noqa: E501 + + return self._proxied.closed + + @property + def invalidated(self) -> Any: + r"""Return True if this connection was invalidated. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + + """ # noqa: E501 + + return self._proxied.invalidated + + @property + def dialect(self) -> Dialect: + r"""Proxy for the :attr:`_engine.Connection.dialect` attribute + on behalf of the :class:`_asyncio.AsyncConnection` class. + + """ # noqa: E501 + + return self._proxied.dialect + + @dialect.setter + def dialect(self, attr: Dialect) -> None: + self._proxied.dialect = attr + + @property + def default_isolation_level(self) -> Any: + r"""The initial-connection time isolation level associated with the + :class:`_engine.Dialect` in use. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + This value is independent of the + :paramref:`.Connection.execution_options.isolation_level` and + :paramref:`.Engine.execution_options.isolation_level` execution + options, and is determined by the :class:`_engine.Dialect` when the + first connection is created, by performing a SQL query against the + database for the current isolation level before any additional commands + have been emitted. + + Calling this accessor does not invoke any new SQL queries. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current actual isolation level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + + """ # noqa: E501 + + return self._proxied.default_isolation_level + + # END PROXY METHODS AsyncConnection + + +@util.create_proxy_methods( + Engine, + ":class:`_engine.Engine`", + ":class:`_asyncio.AsyncEngine`", + classmethods=[], + methods=[ + "clear_compiled_cache", + "update_execution_options", + "get_execution_options", + ], + attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], +) +class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): + """An asyncio proxy for a :class:`_engine.Engine`. + + :class:`_asyncio.AsyncEngine` is acquired using the + :func:`_asyncio.create_async_engine` function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncEngine is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncEngine that matches this one given only the + # "sync" elements. + __slots__ = "sync_engine" + + _connection_cls: Type[AsyncConnection] = AsyncConnection + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncEngine` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ + + def __init__(self, sync_engine: Engine): + if not sync_engine.dialect.is_async: + raise exc.InvalidRequestError( + "The asyncio extension requires an async driver to be used. " + f"The loaded {sync_engine.dialect.driver!r} is not async." + ) + self.sync_engine = self._assign_proxied(sync_engine) + + @util.ro_non_memoized_property + def _proxied(self) -> Engine: + return self.sync_engine + + @classmethod + def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: + return AsyncEngine(target) + + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[AsyncConnection]: + """Return a context manager which when entered will deliver an + :class:`_asyncio.AsyncConnection` with an + :class:`_asyncio.AsyncTransaction` established. + + E.g.:: + + async with async_engine.begin() as conn: + await conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + await conn.execute(text("my_special_procedure(5)")) + + + """ + conn = self.connect() + + async with conn: + async with conn.begin(): + yield conn + + def connect(self) -> AsyncConnection: + """Return an :class:`_asyncio.AsyncConnection` object. + + The :class:`_asyncio.AsyncConnection` will procure a database + connection from the underlying connection pool when it is entered + as an async context manager:: + + async with async_engine.connect() as conn: + result = await conn.execute(select(user_table)) + + The :class:`_asyncio.AsyncConnection` may also be started outside of a + context manager by invoking its :meth:`_asyncio.AsyncConnection.start` + method. + + """ + + return self._connection_cls(self) + + async def raw_connection(self) -> PoolProxiedConnection: + """Return a "raw" DBAPI connection from the connection pool. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return await greenlet_spawn(self.sync_engine.raw_connection) + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> AsyncEngine: ... + + @overload + def execution_options(self, **opt: Any) -> AsyncEngine: ... + + def execution_options(self, **opt: Any) -> AsyncEngine: + """Return a new :class:`_asyncio.AsyncEngine` that will provide + :class:`_asyncio.AsyncConnection` objects with the given execution + options. + + Proxied from :meth:`_engine.Engine.execution_options`. See that + method for details. + + """ + + return AsyncEngine(self.sync_engine.execution_options(**opt)) + + async def dispose(self, close: bool = True) -> None: + """Dispose of the connection pool used by this + :class:`_asyncio.AsyncEngine`. + + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. + + .. seealso:: + + :meth:`_engine.Engine.dispose` + + """ + + await greenlet_spawn(self.sync_engine.dispose, close=close) + + # START PROXY METHODS AsyncEngine + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def clear_compiled_cache(self) -> None: + r"""Clear the compiled cache associated with the dialect. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + This applies **only** to the built-in cache that is established + via the :paramref:`_engine.create_engine.query_cache_size` parameter. + It will not impact any dictionary caches that were passed via the + :paramref:`.Connection.execution_options.compiled_cache` parameter. + + .. versionadded:: 1.4 + + + """ # noqa: E501 + + return self._proxied.clear_compiled_cache() + + def update_execution_options(self, **opt: Any) -> None: + r"""Update the default execution_options dictionary + of this :class:`_engine.Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`_sa.create_engine`. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :meth:`_engine.Engine.execution_options` + + + """ # noqa: E501 + + return self._proxied.update_execution_options(**opt) + + def get_execution_options(self) -> _ExecuteOptions: + r"""Get the non-SQL options which will take effect during execution. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + .. versionadded: 1.3 + + .. seealso:: + + :meth:`_engine.Engine.execution_options` + + """ # noqa: E501 + + return self._proxied.get_execution_options() + + @property + def url(self) -> URL: + r"""Proxy for the :attr:`_engine.Engine.url` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.url + + @url.setter + def url(self, attr: URL) -> None: + self._proxied.url = attr + + @property + def pool(self) -> Pool: + r"""Proxy for the :attr:`_engine.Engine.pool` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.pool + + @pool.setter + def pool(self, attr: Pool) -> None: + self._proxied.pool = attr + + @property + def dialect(self) -> Dialect: + r"""Proxy for the :attr:`_engine.Engine.dialect` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.dialect + + @dialect.setter + def dialect(self, attr: Dialect) -> None: + self._proxied.dialect = attr + + @property + def engine(self) -> Any: + r"""Returns this :class:`.Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. + + + """ # noqa: E501 + + return self._proxied.engine + + @property + def name(self) -> Any: + r"""String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + + """ # noqa: E501 + + return self._proxied.name + + @property + def driver(self) -> Any: + r"""Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + + """ # noqa: E501 + + return self._proxied.driver + + @property + def echo(self) -> Any: + r"""When ``True``, enable log output for this element. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. + + """ # noqa: E501 + + return self._proxied.echo + + @echo.setter + def echo(self, attr: Any) -> None: + self._proxied.echo = attr + + # END PROXY METHODS AsyncEngine + + +class AsyncTransaction( + ProxyComparable[Transaction], StartableContext["AsyncTransaction"] +): + """An asyncio proxy for a :class:`_engine.Transaction`.""" + + __slots__ = ("connection", "sync_transaction", "nested") + + sync_transaction: Optional[Transaction] + connection: AsyncConnection + nested: bool + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction = None + self.nested = nested + + @classmethod + def _regenerate_proxy_for_target( + cls, target: Transaction + ) -> AsyncTransaction: + sync_connection = target.connection + sync_transaction = target + nested = isinstance(target, NestedTransaction) + + async_connection = AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + assert async_connection is not None + + obj = cls.__new__(cls) + obj.connection = async_connection + obj.sync_transaction = obj._assign_proxied(sync_transaction) + obj.nested = nested + return obj + + @util.ro_non_memoized_property + def _proxied(self) -> Transaction: + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + @property + def is_valid(self) -> bool: + return self._proxied.is_valid + + @property + def is_active(self) -> bool: + return self._proxied.is_active + + async def close(self) -> None: + """Close this :class:`.AsyncTransaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + await greenlet_spawn(self._proxied.close) + + async def rollback(self) -> None: + """Roll back this :class:`.AsyncTransaction`.""" + await greenlet_spawn(self._proxied.rollback) + + async def commit(self) -> None: + """Commit this :class:`.AsyncTransaction`.""" + + await greenlet_spawn(self._proxied.commit) + + async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction: + """Start this :class:`_asyncio.AsyncTransaction` object's context + outside of using a Python ``with:`` block. + + """ + + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.connection._proxied.begin_nested + if self.nested + else self.connection._proxied.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn(self._proxied.__exit__, type_, value, traceback) + + +@overload +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ... + + +@overload +def _get_sync_engine_or_connection( + async_engine: AsyncConnection, +) -> Connection: ... + + +def _get_sync_engine_or_connection( + async_engine: Union[AsyncEngine, AsyncConnection] +) -> Union[Engine, Connection]: + if isinstance(async_engine, AsyncConnection): + return async_engine._proxied + + try: + return async_engine.sync_engine + except AttributeError as e: + raise exc.ArgumentError( + "AsyncEngine expected, got %r" % async_engine + ) from e + + +@inspection._inspects(AsyncConnection) +def _no_insp_for_async_conn_yet( + subject: AsyncConnection, # noqa: U100 +) -> NoReturn: + raise exc.NoInspectionAvailable( + "Inspection on an AsyncConnection is currently not supported. " + "Please use ``run_sync`` to pass a callable where it's possible " + "to call ``inspect`` on the passed connection.", + code="xd3s", + ) + + +@inspection._inspects(AsyncEngine) +def _no_insp_for_async_engine_xyet( + subject: AsyncEngine, # noqa: U100 +) -> NoReturn: + raise exc.NoInspectionAvailable( + "Inspection on an AsyncEngine is currently not supported. " + "Please obtain a connection then use ``conn.run_sync`` to pass a " + "callable where it's possible to call ``inspect`` on the " + "passed connection.", + code="xd3s", + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py new file mode 100644 index 0000000..1cf6f36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py @@ -0,0 +1,21 @@ +# ext/asyncio/exc.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import exc + + +class AsyncMethodRequired(exc.InvalidRequestError): + """an API can't be used because its result would not be + compatible with async""" + + +class AsyncContextNotStarted(exc.InvalidRequestError): + """a startable context manager has not been started.""" + + +class AsyncContextAlreadyStarted(exc.InvalidRequestError): + """a startable context manager is already started.""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py new file mode 100644 index 0000000..7dcbe32 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py @@ -0,0 +1,961 @@ +# ext/asyncio/result.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import operator +from typing import Any +from typing import AsyncIterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import exc as async_exc +from ... import util +from ...engine import Result +from ...engine.result import _NO_ROW +from ...engine.result import _R +from ...engine.result import _WithKeys +from ...engine.result import FilterResult +from ...engine.result import FrozenResult +from ...engine.result import ResultMetaData +from ...engine.row import Row +from ...engine.row import RowMapping +from ...sql.base import _generative +from ...util.concurrency import greenlet_spawn +from ...util.typing import Literal +from ...util.typing import Self + +if TYPE_CHECKING: + from ...engine import CursorResult + from ...engine.result import _KeyIndexType + from ...engine.result import _UniqueFilterType + +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + + +class AsyncCommon(FilterResult[_R]): + __slots__ = () + + _real_result: Result[Any] + _metadata: ResultMetaData + + async def close(self) -> None: # type: ignore[override] + """Close this result.""" + + await greenlet_spawn(self._real_result.close) + + @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed + + +class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): + """An asyncio wrapper around a :class:`_result.Result` object. + + The :class:`_asyncio.AsyncResult` only applies to statement executions that + use a server-side cursor. It is returned only from the + :meth:`_asyncio.AsyncConnection.stream` and + :meth:`_asyncio.AsyncSession.stream` methods. + + .. note:: As is the case with :class:`_engine.Result`, this object is + used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`, + which can yield instances of ORM mapped objects either individually or + within tuple-like rows. Note that these result objects do not + deduplicate instances or rows automatically as is the case with the + legacy :class:`_orm.Query` object. For in-Python de-duplication of + instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier + method. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _real_result: Result[_TP] + + def __init__(self, real_result: Result[_TP]): + self._real_result = real_result + + self._metadata = real_result._metadata + self._unique_filter_state = real_result._unique_filter_state + self._post_creational_filter = None + + # BaseCursorResult pre-generates the "_row_getter". Use that + # if available rather than building a second one + if "_row_getter" in real_result.__dict__: + self._set_memoized_attribute( + "_row_getter", real_result.__dict__["_row_getter"] + ) + + @property + def t(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`_asyncio.AsyncResult.t` attribute is a synonym for + calling the :meth:`_asyncio.AsyncResult.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`_asyncio.AsyncResult` object + at runtime, + however annotates as returning a :class:`_asyncio.AsyncTupleResult` + object that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`_engine.Row` + objects to by typed, for those cases where the statement invoked + itself included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.AsyncTupleResult` type at typing time. + + .. seealso:: + + :attr:`_asyncio.AsyncResult.t` - shorter synonym + + :attr:`_engine.Row.t` - :class:`_engine.Row` version + + """ + + return self # type: ignore + + @_generative + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncResult`. + + Refer to :meth:`_engine.Result.unique` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row. + + Refer to :meth:`_engine.Result.columns` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + return self._column_slices(col_expressions) + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[Row[_TP]]]: + """Iterate through sub-lists of rows of the size given. + + An async iterator is returned:: + + async def scroll_results(connection): + result = await connection.stream(select(users_table)) + + async for partition in result.partitions(100): + print("list of rows: %s" % partition) + + Refer to :meth:`_engine.Result.partitions` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`_asyncio.AsyncResult.all` method. + + .. versionadded:: 2.0 + + """ + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[Row[_TP]]: + """Fetch one row. + + When all rows are exhausted, returns None. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch the first row of a result only, use the + :meth:`_asyncio.AsyncResult.first` method. To iterate through all + rows, iterate the :class:`_asyncio.AsyncResult` object directly. + + :return: a :class:`_engine.Row` object if no filters are applied, + or ``None`` if no rows remain. + + """ + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Row[_TP]]: + """Fetch many rows. + + When all rows are exhausted, returns an empty list. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the + :meth:`._asyncio.AsyncResult.partitions` method. + + :return: a list of :class:`_engine.Row` objects. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.partitions` + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[Row[_TP]]: + """Return all rows in a list. + + Closes the result set after invocation. Subsequent invocations + will return an empty list. + + :return: a list of :class:`_engine.Row` objects. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncResult[_TP]: + return self + + async def __anext__(self) -> Row[_TP]: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[Row[_TP]]: + """Fetch the first row or ``None`` if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar` method, + or combine :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.first`. + + Additionally, in contrast to the behavior of the legacy ORM + :meth:`_orm.Query.first` method, **no limit is applied** to the + SQL query which was invoked to produce this + :class:`_asyncio.AsyncResult`; + for a DBAPI driver that buffers results in memory before yielding + rows, all rows will be sent to the Python process and all but + the first row will be discarded. + + .. seealso:: + + :ref:`migration_20_unify_select` + + :return: a :class:`_engine.Row` object, or None + if no rows remain. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.scalar` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[Row[_TP]]: + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row` or ``None`` if no row + is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + @overload + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ... + + @overload + async def scalar_one(self) -> Any: ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, True, True) + + @overload + async def scalar_one_or_none( + self: AsyncResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one scalar result or ``None``. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one_or_none`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, False, True) + + async def one(self) -> Row[_TP]: + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar_one` method, or combine + :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.one`. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row`. + + :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalar_one` + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + @overload + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ... + + @overload + async def scalar(self) -> Any: ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value, or ``None`` if no rows remain. + + """ + return await greenlet_spawn(self._only_one_row, False, False, True) + + async def freeze(self) -> FrozenResult[_TP]: + """Return a callable object that will produce copies of this + :class:`_asyncio.AsyncResult` when invoked. + + The callable object returned is an instance of + :class:`_engine.FrozenResult`. + + This is used for result set caching. The method must be called + on the result when it has been unconsumed, and calling the method + will consume the result fully. When the :class:`_engine.FrozenResult` + is retrieved from a cache, it can be called any number of times where + it will produce a new :class:`_engine.Result` object each time + against its stored set of rows. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + """ + + return await greenlet_spawn(FrozenResult, self) + + @overload + def scalars( + self: AsyncResult[Tuple[_T]], index: Literal[0] + ) -> AsyncScalarResult[_T]: ... + + @overload + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ... + + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + """Return an :class:`_asyncio.AsyncScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + Refer to :meth:`_result.Result.scalars` in the synchronous + SQLAlchemy API for a complete behavioral description. + + :param index: integer or row key indicating the column to be fetched + from each row, defaults to ``0`` indicating the first column. + + :return: a new :class:`_asyncio.AsyncScalarResult` filtering object + referring to this :class:`_asyncio.AsyncResult` object. + + """ + return AsyncScalarResult(self._real_result, index) + + def mappings(self) -> AsyncMappingResult: + """Apply a mappings filter to returned rows, returning an instance of + :class:`_asyncio.AsyncMappingResult`. + + When this filter is applied, fetching rows will return + :class:`_engine.RowMapping` objects instead of :class:`_engine.Row` + objects. + + :return: a new :class:`_asyncio.AsyncMappingResult` filtering object + referring to the underlying :class:`_result.Result` object. + + """ + + return AsyncMappingResult(self._real_result) + + +class AsyncScalarResult(AsyncCommon[_R]): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values + rather than :class:`_row.Row` values. + + The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.scalars` method. + + Refer to the :class:`_result.ScalarResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _generate_rows = False + + def __init__(self, real_result: Result[Any], index: _KeyIndexType): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique( + self, + strategy: Optional[_UniqueFilterType] = None, + ) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncScalarResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[_R]: + """Return all scalar values in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncScalarResult[_R]: + return self + + async def __anext__(self) -> _R: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncMappingResult(_WithKeys, AsyncCommon[RowMapping]): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary + values rather than :class:`_engine.Row` values. + + The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.mappings` method. + + Refer to the :class:`_result.MappingResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _generate_rows = True + + _post_creational_filter = operator.attrgetter("_mapping") + + def __init__(self, result: Result[Any]): + self._real_result = result + self._unique_filter_state = result._unique_filter_state + self._metadata = result._metadata + if result._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + + def unique( + self, + strategy: Optional[_UniqueFilterType] = None, + ) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncMappingResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row.""" + return self._column_slices(col_expressions) + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[RowMapping]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[RowMapping]: + """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[RowMapping]: + """Fetch one object. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[RowMapping]: + """Fetch many rows. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[RowMapping]: + """Return all rows in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncMappingResult: + return self + + async def __anext__(self) -> RowMapping: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[RowMapping]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[RowMapping]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self) -> RowMapping: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): + """A :class:`_asyncio.AsyncResult` that's typed as returning plain + Python tuples instead of rows. + + Since :class:`_engine.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`_asyncio.AsyncResult` is + still used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_engine.Row` + objects, are returned. + + """ + ... + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def __aiter__(self) -> AsyncIterator[_R]: ... + + async def __anext__(self) -> _R: ... + + async def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + ... + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + @overload + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ... + + @overload + async def scalar_one(self) -> Any: ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one`. + + .. seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + async def scalar_one_or_none( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + async def scalar( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + async def scalar(self) -> Any: ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result + set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or ``None`` if no rows remain. + + """ + ... + + +_RT = TypeVar("_RT", bound="Result[Any]") + + +async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: + cursor_result: CursorResult[Any] + + try: + is_cursor = result._is_cursor + except AttributeError: + # legacy execute(DefaultGenerator) case + return result + + if not is_cursor: + cursor_result = getattr(result, "raw", None) # type: ignore + else: + cursor_result = result # type: ignore + if cursor_result and cursor_result.context._is_server_side: + await greenlet_spawn(cursor_result.close) + raise async_exc.AsyncMethodRequired( + "Can't use the %s.%s() method with a " + "server-side cursor. " + "Use the %s.stream() method for an async " + "streaming result set." + % ( + calling_method.__self__.__class__.__name__, + calling_method.__name__, + calling_method.__self__.__class__.__name__, + ) + ) + return result diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py new file mode 100644 index 0000000..e879a16 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py @@ -0,0 +1,1614 @@ +# ext/asyncio/scoping.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .session import _AS +from .session import async_sessionmaker +from .session import AsyncSession +from ... import exc as sa_exc +from ... import util +from ...orm.session import Session +from ...util import create_proxy_methods +from ...util import ScopedRegistry +from ...util import warn +from ...util import warn_deprecated + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .result import AsyncResult + from .result import AsyncScalarResult + from .session import AsyncSessionTransaction + from ...engine import Connection + from ...engine import CursorResult + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import RowMapping + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...engine.result import ScalarResult + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm._typing import OrmExecuteOptionsParameter + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...sql.base import Executable + from ...sql.dml import UpdateBase + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateParameter + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + + +@create_proxy_methods( + AsyncSession, + ":class:`_asyncio.AsyncSession`", + ":class:`_asyncio.scoping.async_scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "aclose", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "reset", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get_bind", + "is_modified", + "invalidate", + "merge", + "refresh", + "rollback", + "scalar", + "scalars", + "get", + "get_one", + "stream", + "stream_scalars", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], + use_intermediate_variable=["get"], +) +class async_scoped_session(Generic[_AS]): + """Provides scoped management of :class:`.AsyncSession` objects. + + See the section :ref:`asyncio_scoped_session` for usage details. + + .. versionadded:: 1.4.19 + + + """ + + _support_async = True + + session_factory: async_sessionmaker[_AS] + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.AsyncSession` is needed.""" + + registry: ScopedRegistry[_AS] + + def __init__( + self, + session_factory: async_sessionmaker[_AS], + scopefunc: Callable[[], Any], + ): + """Construct a new :class:`_asyncio.async_scoped_session`. + + :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` + instances. This is usually, but not necessarily, an instance + of :class:`_asyncio.async_sessionmaker`. + + :param scopefunc: function which defines + the current scope. A function such as ``asyncio.current_task`` + may be useful here. + + """ # noqa: E501 + + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + + @property + def _proxied(self) -> _AS: + return self.registry() + + def __call__(self, **kw: Any) -> _AS: + r"""Return the current :class:`.AsyncSession`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.AsyncSession` is not present. If the + :class:`.AsyncSession` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + async def remove(self) -> None: + """Dispose of the current :class:`.AsyncSession`, if present. + + Different from scoped_session's remove method, this method would use + await to wait for the close method of AsyncSession. + + """ + + if self.registry.has(): + await self.registry().close() + self.registry.clear() + + # START PROXY METHODS async_scoped_session + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + + + """ # noqa: E501 + + return self._proxied.__iter__() + + async def aclose(self) -> None: + r"""A synonym for :meth:`_asyncio.AsyncSession.close`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The :meth:`_asyncio.AsyncSession.aclose` name is specifically + to support the Python standard library ``@contextlib.aclosing`` + context manager function. + + .. versionadded:: 2.0.20 + + + """ # noqa: E501 + + return await self._proxied.aclose() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def begin(self) -> AsyncSessionTransaction: + r"""Return an :class:`_asyncio.AsyncSessionTransaction` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + + """ # noqa: E501 + + return self._proxied.begin() + + def begin_nested(self) -> AsyncSessionTransaction: + r"""Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. + + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + .. seealso:: + + :ref:`aiosqlite_serializable` - special workarounds required + with the SQLite asyncio driver in order for SAVEPOINT to work + correctly. + + + """ # noqa: E501 + + return self._proxied.begin_nested() + + async def close(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_asyncio.AsyncSession`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.close` - main documentation for + "close" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + + """ # noqa: E501 + + return await self._proxied.close() + + async def reset(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`, resetting the session to its initial state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. versionadded:: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.reset` - main documentation for + "reset" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + + """ # noqa: E501 + + return await self._proxied.reset() + + async def commit(self) -> None: + r"""Commit the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.commit` - main documentation for + "commit" + + """ # noqa: E501 + + return await self._proxied.commit() + + async def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + **kw: Any, + ) -> AsyncConnection: + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + This method may also be used to establish execution options for the + database connection used by the current transaction. + + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. + + .. seealso:: + + :meth:`_orm.Session.connection` - main documentation for + "connection" + + + """ # noqa: E501 + + return await self._proxied.connection( + bind_arguments=bind_arguments, + execution_options=execution_options, + **kw, + ) + + async def delete(self, instance: object) -> None: + r"""Mark an instance as deleted. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + + """ # noqa: E501 + + return await self._proxied.delete(instance) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: ... + + @overload + async def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result[Any]: + r"""Execute a statement and return a buffered + :class:`_engine.Result` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.execute` - main documentation for execute + + + """ # noqa: E501 + + return await self._proxied.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + r"""Flush all the object changes to the database. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + + """ # noqa: E501 + + return await self._proxied.flush(objects=objects) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: + r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session` + is bound. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Unlike the :meth:`_orm.Session.get_bind` method, this method is + currently **not** used by this :class:`.AsyncSession` in any way + in order to resolve engines for requests. + + .. note:: + + This method proxies directly to the :meth:`_orm.Session.get_bind` + method, however is currently **not** useful as an override target, + in contrast to that of the :meth:`_orm.Session.get_bind` method. + The example below illustrates how to implement custom + :meth:`_orm.Session.get_bind` schemes that work with + :class:`.AsyncSession` and :class:`.AsyncEngine`. + + The pattern introduced at :ref:`session_custom_partitioning` + illustrates how to apply a custom bind-lookup scheme to a + :class:`_orm.Session` given a set of :class:`_engine.Engine` objects. + To apply a corresponding :meth:`_orm.Session.get_bind` implementation + for use with a :class:`.AsyncSession` and :class:`.AsyncEngine` + objects, continue to subclass :class:`_orm.Session` and apply it to + :class:`.AsyncSession` using + :paramref:`.AsyncSession.sync_session_class`. The inner method must + continue to return :class:`_engine.Engine` instances, which can be + acquired from a :class:`_asyncio.AsyncEngine` using the + :attr:`_asyncio.AsyncEngine.sync_engine` attribute:: + + # using example from "Custom Vertical Partitioning" + + + import random + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session + + # construct async engines w/ async drivers + engines = { + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + } + + class RoutingSession(Session): + def get_bind(self, mapper=None, clause=None, **kw): + # within get_bind(), return sync engines + if mapper and issubclass(mapper.class_, MyOtherClass): + return engines['other'].sync_engine + elif self._flushing or isinstance(clause, (Update, Delete)): + return engines['leader'].sync_engine + else: + return engines[ + random.choice(['follower1','follower2']) + ].sync_engine + + # apply to AsyncSession using sync_session_class + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) + + The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, + implicitly non-blocking context in the same manner as ORM event hooks + and functions that are invoked via :meth:`.AsyncSession.run_sync`, so + routines that wish to run SQL commands inside of + :meth:`_orm.Session.get_bind` can continue to do so using + blocking-style code, which will be translated to implicitly async calls + at the point of invoking IO on the database drivers. + + + """ # noqa: E501 + + return self._proxied.get_bind( + mapper=mapper, clause=clause, bind=bind, **kw + ) + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + async def invalidate(self) -> None: + r"""Close this Session, using connection invalidation. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + For a complete description, see :meth:`_orm.Session.invalidate`. + + """ # noqa: E501 + + return await self._proxied.invalidate() + + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + r"""Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + + """ # noqa: E501 + + return await self._proxied.merge(instance, load=load, options=options) + + async def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: ForUpdateParameter = None, + ) -> None: + r"""Expire and refresh the attributes on the given instance. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + .. seealso:: + + :meth:`_orm.Session.refresh` - main documentation for refresh + + + """ # noqa: E501 + + return await self._proxied.refresh( + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def rollback(self) -> None: + r"""Rollback the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.rollback` - main documentation for + "rollback" + + """ # noqa: E501 + + return await self._proxied.rollback() + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: ... + + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + r"""Execute a statement and return a scalar result. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.scalar` - main documentation for scalar + + + """ # noqa: E501 + + return await self._proxied.scalar( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: ... + + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + r"""Execute a statement and return scalar results. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_asyncio.AsyncSession.scalars` + + .. versionadded:: 1.4.26 Added + :meth:`_asyncio.async_scoped_session.scalars` + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + + """ # noqa: E501 + + return await self._proxied.scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Union[_O, None]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + + """ # noqa: E501 + + result = await self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + return result + + async def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> _O: + r"""Return an instance based on the given primary key identifier, + or raise an exception if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. + + ..versionadded: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.get_one` - main documentation for get_one + + + """ # noqa: E501 + + return await self._proxied.get_one( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + r"""Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return await self._proxied.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: ... + + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + r"""Execute a statement and return a stream of scalar results. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + + """ # noqa: E501 + + return await self._proxied.stream_scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @property + def bind(self) -> Any: + r"""Proxy for the :attr:`_asyncio.AsyncSession.bind` attribute + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + """ # noqa: E501 + + return self._proxied.bind + + @bind.setter + def bind(self, attr: Any) -> None: + self._proxied.bind = attr + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> Any: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: Any) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> Any: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: Any) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + async def close_all(cls) -> None: + r"""Close all :class:`_asyncio.AsyncSession` sessions. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. deprecated:: 2.0 The :meth:`.AsyncSession.close_all` method is deprecated and will be removed in a future release. Please refer to :func:`_asyncio.close_all_sessions`. + + """ # noqa: E501 + + return await AsyncSession.close_all() + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.object_session`. + + + + """ # noqa: E501 + + return AsyncSession.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.util.identity_key`. + + + + """ # noqa: E501 + + return AsyncSession.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS async_scoped_session diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py new file mode 100644 index 0000000..c5fe469 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py @@ -0,0 +1,1936 @@ +# ext/asyncio/session.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import asyncio +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import engine +from .base import ReversibleProxy +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult +from ... import util +from ...orm import close_all_sessions as _sync_close_all_sessions +from ...orm import object_session +from ...orm import Session +from ...orm import SessionTransaction +from ...orm import state as _instance_state +from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec + + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .engine import AsyncEngine + from ...engine import Connection + from ...engine import CursorResult + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import RowMapping + from ...engine import ScalarResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...event import dispatcher + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm._typing import OrmExecuteOptionsParameter + from ...orm.identity import IdentityMap + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...orm.session import _SessionBindKey + from ...sql._typing import _InfoType + from ...sql.base import Executable + from ...sql.dml import UpdateBase + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateParameter + from ...sql.selectable import TypedReturnsRows + +_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] + +_P = ParamSpec("_P") +_T = TypeVar("_T", bound=Any) + + +_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) +_STREAM_OPTIONS = util.immutabledict({"stream_results": True}) + + +class AsyncAttrs: + """Mixin class which provides an awaitable accessor for all attributes. + + E.g.:: + + from __future__ import annotations + + from typing import List + + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy.ext.asyncio import AsyncAttrs + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(AsyncAttrs, DeclarativeBase): + pass + + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + bs: Mapped[List[B]] = relationship() + + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + data: Mapped[str] + + In the above example, the :class:`_asyncio.AsyncAttrs` mixin is applied to + the declarative ``Base`` class where it takes effect for all subclasses. + This mixin adds a single new attribute + :attr:`_asyncio.AsyncAttrs.awaitable_attrs` to all classes, which will + yield the value of any attribute as an awaitable. This allows attributes + which may be subject to lazy loading or deferred / unexpiry loading to be + accessed such that IO can still be emitted:: + + a1 = (await async_session.scalars(select(A).where(A.id == 5))).one() + + # use the lazy loader on ``a1.bs`` via the ``.awaitable_attrs`` + # interface, so that it may be awaited + for b1 in await a1.awaitable_attrs.bs: + print(b1) + + The :attr:`_asyncio.AsyncAttrs.awaitable_attrs` performs a call against the + attribute that is approximately equivalent to using the + :meth:`_asyncio.AsyncSession.run_sync` method, e.g.:: + + for b1 in await async_session.run_sync(lambda sess: a1.bs): + print(b1) + + .. versionadded:: 2.0.13 + + .. seealso:: + + :ref:`asyncio_orm_avoid_lazyloads` + + """ + + class _AsyncAttrGetitem: + __slots__ = "_instance" + + def __init__(self, _instance: Any): + self._instance = _instance + + def __getattr__(self, name: str) -> Awaitable[Any]: + return greenlet_spawn(getattr, self._instance, name) + + @property + def awaitable_attrs(self) -> AsyncAttrs._AsyncAttrGetitem: + """provide a namespace of all attributes on this object wrapped + as awaitables. + + e.g.:: + + + a1 = (await async_session.scalars(select(A).where(A.id == 5))).one() + + some_attribute = await a1.awaitable_attrs.some_deferred_attribute + some_collection = await a1.awaitable_attrs.some_collection + + """ # noqa: E501 + + return AsyncAttrs._AsyncAttrGetitem(self) + + +@util.create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_asyncio.AsyncSession`", + classmethods=["object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "expire", + "expire_all", + "expunge", + "expunge_all", + "is_modified", + "in_transaction", + "in_nested_transaction", + ], + attributes=[ + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class AsyncSession(ReversibleProxy[Session]): + """Asyncio version of :class:`_orm.Session`. + + The :class:`_asyncio.AsyncSession` is a proxy for a traditional + :class:`_orm.Session` instance. + + The :class:`_asyncio.AsyncSession` is **not safe for use in concurrent + tasks.**. See :ref:`session_faq_threadsafe` for background. + + .. versionadded:: 1.4 + + To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session` + implementations, see the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. + + + """ + + _is_asyncio = True + + dispatch: dispatcher[Session] + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + *, + binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None, + sync_session_class: Optional[Type[Session]] = None, + **kw: Any, + ): + r"""Construct a new :class:`_asyncio.AsyncSession`. + + All parameters other than ``sync_session_class`` are passed to the + ``sync_session_class`` callable directly to instantiate a new + :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for + parameter documentation. + + :param sync_session_class: + A :class:`_orm.Session` subclass or other callable which will be used + to construct the :class:`_orm.Session` which will be proxied. This + parameter may be used to provide custom :class:`_orm.Session` + subclasses. Defaults to the + :attr:`_asyncio.AsyncSession.sync_session_class` class-level + attribute. + + .. versionadded:: 1.4.24 + + """ + sync_bind = sync_binds = None + + if bind: + self.bind = bind + sync_bind = engine._get_sync_engine_or_connection(bind) + + if binds: + self.binds = binds + sync_binds = { + key: engine._get_sync_engine_or_connection(b) + for key, b in binds.items() + } + + if sync_session_class: + self.sync_session_class = sync_session_class + + self.sync_session = self._proxied = self._assign_proxied( + self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw) + ) + + sync_session_class: Type[Session] = Session + """The class or callable that provides the + underlying :class:`_orm.Session` instance for a particular + :class:`_asyncio.AsyncSession`. + + At the class level, this attribute is the default value for the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom + subclasses of :class:`_asyncio.AsyncSession` can override this. + + At the instance level, this attribute indicates the current class or + callable that was used to provide the :class:`_orm.Session` instance for + this :class:`_asyncio.AsyncSession` instance. + + .. versionadded:: 1.4.24 + + """ + + sync_session: Session + """Reference to the underlying :class:`_orm.Session` this + :class:`_asyncio.AsyncSession` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + + async def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: ForUpdateParameter = None, + ) -> None: + """Expire and refresh the attributes on the given instance. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + .. seealso:: + + :meth:`_orm.Session.refresh` - main documentation for refresh + + """ + + await greenlet_spawn( + self.sync_session.refresh, + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def run_sync( + self, + fn: Callable[Concatenate[Session, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, + ) -> _T: + """Invoke the given synchronous (i.e. not async) callable, + passing a synchronous-style :class:`_orm.Session` as the first + argument. + + This method allows traditional synchronous SQLAlchemy functions to + run within the context of an asyncio application. + + E.g.:: + + def some_business_method(session: Session, param: str) -> str: + '''A synchronous function that does not require awaiting + + :param session: a SQLAlchemy Session, used synchronously + + :return: an optional return value is supported + + ''' + session.add(MyObject(param=param)) + session.flush() + return "success" + + + async def do_something_async(async_engine: AsyncEngine) -> None: + '''an async function that uses awaiting''' + + with AsyncSession(async_engine) as async_session: + # run some_business_method() with a sync-style + # Session, proxied into an awaitable + return_code = await async_session.run_sync(some_business_method, param="param1") + print(return_code) + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + .. tip:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :class:`.AsyncAttrs` - a mixin for ORM mapped classes that provides + a similar feature more succinctly on a per-attribute basis + + :meth:`.AsyncConnection.run_sync` + + :ref:`session_run_sync` + """ # noqa: E501 + + return await greenlet_spawn( + fn, self.sync_session, *arg, _require_await=False, **kw + ) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: ... + + @overload + async def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result[Any]: + """Execute a statement and return a buffered + :class:`_engine.Result` object. + + .. seealso:: + + :meth:`_orm.Session.execute` - main documentation for execute + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return await _ensure_sync_result(result, self.execute) + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: ... + + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + """Execute a statement and return a scalar result. + + .. seealso:: + + :meth:`_orm.Session.scalar` - main documentation for scalar + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + return await greenlet_spawn( + self.sync_session.scalar, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: ... + + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + """Execute a statement and return scalar results. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_asyncio.AsyncSession.scalars` + + .. versionadded:: 1.4.26 Added + :meth:`_asyncio.async_scoped_session.scalars` + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + """ + + result = await self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result.scalars() + + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Union[_O, None]: + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + """ + + return await greenlet_spawn( + cast("Callable[..., _O]", self.sync_session.get), + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + + async def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> _O: + """Return an instance based on the given primary key identifier, + or raise an exception if not found. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. + + ..versionadded: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.get_one` - main documentation for get_one + + """ + + return await greenlet_spawn( + cast("Callable[..., _O]", self.sync_session.get_one), + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _STREAM_OPTIONS + ) + else: + execution_options = _STREAM_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return AsyncResult(result) + + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: ... + + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + """Execute a statement and return a stream of scalar results. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + """ + + result = await self.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result.scalars() + + async def delete(self, instance: object) -> None: + """Mark an instance as deleted. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + """ + await greenlet_spawn(self.sync_session.delete, instance) + + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + """Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + """ + return await greenlet_spawn( + self.sync_session.merge, instance, load=load, options=options + ) + + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + """Flush all the object changes to the database. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + """ + await greenlet_spawn(self.sync_session.flush, objects=objects) + + def get_transaction(self) -> Optional[AsyncSessionTransaction]: + """Return the current root transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + trans = self.sync_session.get_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: + """Return the current nested transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + + trans = self.sync_session.get_nested_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: + """Return a "bind" to which the synchronous proxied :class:`_orm.Session` + is bound. + + Unlike the :meth:`_orm.Session.get_bind` method, this method is + currently **not** used by this :class:`.AsyncSession` in any way + in order to resolve engines for requests. + + .. note:: + + This method proxies directly to the :meth:`_orm.Session.get_bind` + method, however is currently **not** useful as an override target, + in contrast to that of the :meth:`_orm.Session.get_bind` method. + The example below illustrates how to implement custom + :meth:`_orm.Session.get_bind` schemes that work with + :class:`.AsyncSession` and :class:`.AsyncEngine`. + + The pattern introduced at :ref:`session_custom_partitioning` + illustrates how to apply a custom bind-lookup scheme to a + :class:`_orm.Session` given a set of :class:`_engine.Engine` objects. + To apply a corresponding :meth:`_orm.Session.get_bind` implementation + for use with a :class:`.AsyncSession` and :class:`.AsyncEngine` + objects, continue to subclass :class:`_orm.Session` and apply it to + :class:`.AsyncSession` using + :paramref:`.AsyncSession.sync_session_class`. The inner method must + continue to return :class:`_engine.Engine` instances, which can be + acquired from a :class:`_asyncio.AsyncEngine` using the + :attr:`_asyncio.AsyncEngine.sync_engine` attribute:: + + # using example from "Custom Vertical Partitioning" + + + import random + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session + + # construct async engines w/ async drivers + engines = { + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + } + + class RoutingSession(Session): + def get_bind(self, mapper=None, clause=None, **kw): + # within get_bind(), return sync engines + if mapper and issubclass(mapper.class_, MyOtherClass): + return engines['other'].sync_engine + elif self._flushing or isinstance(clause, (Update, Delete)): + return engines['leader'].sync_engine + else: + return engines[ + random.choice(['follower1','follower2']) + ].sync_engine + + # apply to AsyncSession using sync_session_class + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) + + The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, + implicitly non-blocking context in the same manner as ORM event hooks + and functions that are invoked via :meth:`.AsyncSession.run_sync`, so + routines that wish to run SQL commands inside of + :meth:`_orm.Session.get_bind` can continue to do so using + blocking-style code, which will be translated to implicitly async calls + at the point of invoking IO on the database drivers. + + """ # noqa: E501 + + return self.sync_session.get_bind( + mapper=mapper, clause=clause, bind=bind, **kw + ) + + async def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + **kw: Any, + ) -> AsyncConnection: + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. + + This method may also be used to establish execution options for the + database connection used by the current transaction. + + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. + + .. seealso:: + + :meth:`_orm.Session.connection` - main documentation for + "connection" + + """ + + sync_connection = await greenlet_spawn( + self.sync_session.connection, + bind_arguments=bind_arguments, + execution_options=execution_options, + **kw, + ) + return engine.AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + + def begin(self) -> AsyncSessionTransaction: + """Return an :class:`_asyncio.AsyncSessionTransaction` object. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + """ + + return AsyncSessionTransaction(self) + + def begin_nested(self) -> AsyncSessionTransaction: + """Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. + + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + .. seealso:: + + :ref:`aiosqlite_serializable` - special workarounds required + with the SQLite asyncio driver in order for SAVEPOINT to work + correctly. + + """ + + return AsyncSessionTransaction(self, nested=True) + + async def rollback(self) -> None: + """Rollback the current transaction in progress. + + .. seealso:: + + :meth:`_orm.Session.rollback` - main documentation for + "rollback" + """ + await greenlet_spawn(self.sync_session.rollback) + + async def commit(self) -> None: + """Commit the current transaction in progress. + + .. seealso:: + + :meth:`_orm.Session.commit` - main documentation for + "commit" + """ + await greenlet_spawn(self.sync_session.commit) + + async def close(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.close` - main documentation for + "close" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + """ + await greenlet_spawn(self.sync_session.close) + + async def reset(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`, resetting the session to its initial state. + + .. versionadded:: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.reset` - main documentation for + "reset" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + """ + await greenlet_spawn(self.sync_session.reset) + + async def aclose(self) -> None: + """A synonym for :meth:`_asyncio.AsyncSession.close`. + + The :meth:`_asyncio.AsyncSession.aclose` name is specifically + to support the Python standard library ``@contextlib.aclosing`` + context manager function. + + .. versionadded:: 2.0.20 + + """ + await self.close() + + async def invalidate(self) -> None: + """Close this Session, using connection invalidation. + + For a complete description, see :meth:`_orm.Session.invalidate`. + """ + await greenlet_spawn(self.sync_session.invalidate) + + @classmethod + @util.deprecated( + "2.0", + "The :meth:`.AsyncSession.close_all` method is deprecated and will be " + "removed in a future release. Please refer to " + ":func:`_asyncio.close_all_sessions`.", + ) + async def close_all(cls) -> None: + """Close all :class:`_asyncio.AsyncSession` sessions.""" + await close_all_sessions() + + async def __aenter__(self: _AS) -> _AS: + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + task = asyncio.create_task(self.close()) + await asyncio.shield(task) + + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: + return _AsyncSessionContextManager(self) + + # START PROXY METHODS AsyncSession + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.__iter__() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + def in_transaction(self) -> bool: + r"""Return True if this :class:`_orm.Session` has begun a transaction. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_orm.Session.is_active` + + + + """ # noqa: E501 + + return self._proxied.in_transaction() + + def in_nested_transaction(self) -> bool: + r"""Return True if this :class:`_orm.Session` has begun a nested + transaction, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionadded:: 1.4 + + + """ # noqa: E501 + + return self._proxied.in_nested_transaction() + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> IdentityMap: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: IdentityMap) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> bool: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: bool) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.object_session`. + + + """ # noqa: E501 + + return Session.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.util.identity_key`. + + + """ # noqa: E501 + + return Session.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS AsyncSession + + +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): + """A configurable :class:`.AsyncSession` factory. + + The :class:`.async_sessionmaker` factory works in the same way as the + :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession` + objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import async_sessionmaker + + async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None: + async with async_session() as session: + session.add(SomeObject(data="object")) + session.add(SomeOtherObject(name="other object")) + await session.commit() + + async def main() -> None: + # an AsyncEngine, which the AsyncSession will use for connection + # resources + engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/') + + # create a reusable factory for new AsyncSession instances + async_session = async_sessionmaker(engine) + + await run_some_sql(async_session) + + await engine.dispose() + + The :class:`.async_sessionmaker` is useful so that different parts + of a program can create new :class:`.AsyncSession` objects with a + fixed configuration established up front. Note that :class:`.AsyncSession` + objects may also be instantiated directly when not using + :class:`.async_sessionmaker`. + + .. versionadded:: 2.0 :class:`.async_sessionmaker` provides a + :class:`.sessionmaker` class that's dedicated to the + :class:`.AsyncSession` object, including pep-484 typing support. + + .. seealso:: + + :ref:`asyncio_orm` - shows example use + + :class:`.sessionmaker` - general overview of the + :class:`.sessionmaker` architecture + + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ # noqa E501 + + class_: Type[_AS] + + @overload + def __init__( + self, + bind: Optional[_AsyncSessionBind] = ..., + *, + class_: Type[_AS], + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): ... + + @overload + def __init__( + self: "async_sessionmaker[AsyncSession]", + bind: Optional[_AsyncSessionBind] = ..., + *, + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): ... + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + *, + class_: Type[_AS] = AsyncSession, # type: ignore + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[_InfoType] = None, + **kw: Any, + ): + r"""Construct a new :class:`.async_sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.AsyncSession.__init__` docstring for more details on + parameters. + + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + self.class_ = class_ + + def begin(self) -> _AsyncSessionContextManager[_AS]: + """Produce a context manager that both provides a new + :class:`_orm.AsyncSession` as well as a transaction that commits. + + + e.g.:: + + async def main(): + Session = async_sessionmaker(some_engine) + + async with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> _AS: + """Produce a new :class:`.AsyncSession` object using the configuration + established in this :class:`.async_sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False) + session = AsyncSession() # invokes sessionmaker.__call__() + + """ # noqa E501 + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this async_sessionmaker. + + e.g.:: + + AsyncSession = async_sessionmaker(some_engine) + + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + """ # noqa E501 + + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + +class _AsyncSessionContextManager(Generic[_AS]): + __slots__ = ("async_session", "trans") + + async_session: _AS + trans: AsyncSessionTransaction + + def __init__(self, async_session: _AS): + self.async_session = async_session + + async def __aenter__(self) -> _AS: + self.trans = self.async_session.begin() + await self.trans.__aenter__() + return self.async_session + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + async def go() -> None: + await self.trans.__aexit__(type_, value, traceback) + await self.async_session.__aexit__(type_, value, traceback) + + task = asyncio.create_task(go()) + await asyncio.shield(task) + + +class AsyncSessionTransaction( + ReversibleProxy[SessionTransaction], + StartableContext["AsyncSessionTransaction"], +): + """A wrapper for the ORM :class:`_orm.SessionTransaction` object. + + This object is provided so that a transaction-holding object + for the :meth:`_asyncio.AsyncSession.begin` may be returned. + + The object supports both explicit calls to + :meth:`_asyncio.AsyncSessionTransaction.commit` and + :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an + async context manager. + + + .. versionadded:: 1.4 + + """ + + __slots__ = ("session", "sync_transaction", "nested") + + session: AsyncSession + sync_transaction: Optional[SessionTransaction] + + def __init__(self, session: AsyncSession, nested: bool = False): + self.session = session + self.nested = nested + self.sync_transaction = None + + @property + def is_active(self) -> bool: + return ( + self._sync_transaction() is not None + and self._sync_transaction().is_active + ) + + def _sync_transaction(self) -> SessionTransaction: + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + async def rollback(self) -> None: + """Roll back this :class:`_asyncio.AsyncTransaction`.""" + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self) -> None: + """Commit this :class:`_asyncio.AsyncTransaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start( + self, is_ctxmanager: bool = False + ) -> AsyncSessionTransaction: + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.session.sync_session.begin_nested # type: ignore + if self.nested + else self.session.sync_session.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn( + self._sync_transaction().__exit__, type_, value, traceback + ) + + +def async_object_session(instance: object) -> Optional[AsyncSession]: + """Return the :class:`_asyncio.AsyncSession` to which the given instance + belongs. + + This function makes use of the sync-API function + :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which + refers to the given instance, and from there links it to the original + :class:`_asyncio.AsyncSession`. + + If the :class:`_asyncio.AsyncSession` has been garbage collected, the + return value is ``None``. + + This functionality is also available from the + :attr:`_orm.InstanceState.async_session` accessor. + + :param instance: an ORM mapped instance + :return: an :class:`_asyncio.AsyncSession` object, or ``None``. + + .. versionadded:: 1.4.18 + + """ + + session = object_session(instance) + if session is not None: + return async_session(session) + else: + return None + + +def async_session(session: Session) -> Optional[AsyncSession]: + """Return the :class:`_asyncio.AsyncSession` which is proxying the given + :class:`_orm.Session` object, if any. + + :param session: a :class:`_orm.Session` instance. + :return: a :class:`_asyncio.AsyncSession` instance, or ``None``. + + .. versionadded:: 1.4.18 + + """ + return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) + + +async def close_all_sessions() -> None: + """Close all :class:`_asyncio.AsyncSession` sessions. + + .. versionadded:: 2.0.23 + + .. seealso:: + + :func:`.session.close_all_sessions` + + """ + await greenlet_spawn(_sync_close_all_sessions) + + +_instance_state._async_provider = async_session # type: ignore diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py new file mode 100644 index 0000000..bf6a5f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py @@ -0,0 +1,1658 @@ +# ext/automap.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system +which automatically generates mapped classes and relationships from a database +schema, typically though not necessarily one which is reflected. + +It is hoped that the :class:`.AutomapBase` system provides a quick +and modernized solution to the problem that the very famous +`SQLSoup <https://sqlsoup.readthedocs.io/en/latest/>`_ +also tries to solve, that of generating a quick and rudimentary object +model from an existing database on the fly. By addressing the issue strictly +at the mapper configuration level, and integrating fully with existing +Declarative class techniques, :class:`.AutomapBase` seeks to provide +a well-integrated approach to the issue of expediently auto-generating ad-hoc +mappings. + +.. tip:: The :ref:`automap_toplevel` extension is geared towards a + "zero declaration" approach, where a complete ORM model including classes + and pre-named relationships can be generated on the fly from a database + schema. For applications that still want to use explicit class declarations + including explicit relationship definitions in conjunction with reflection + of tables, the :class:`.DeferredReflection` class, described at + :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice. + +.. _automap_basic_use: + +Basic Use +========= + +The simplest usage is to reflect an existing database into a new model. +We create a new :class:`.AutomapBase` class in a similar manner as to how +we create a declarative base class, using :func:`.automap_base`. +We then call :meth:`.AutomapBase.prepare` on the resulting base class, +asking it to reflect the schema and produce mappings:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy.orm import Session + from sqlalchemy import create_engine + + Base = automap_base() + + # engine, suppose it has two tables 'user' and 'address' set up + engine = create_engine("sqlite:///mydatabase.db") + + # reflect the tables + Base.prepare(autoload_with=engine) + + # mapped classes are now created with names by default + # matching that of the table name. + User = Base.classes.user + Address = Base.classes.address + + session = Session(engine) + + # rudimentary relationships are produced + session.add(Address(email_address="foo@bar.com", user=User(name="foo"))) + session.commit() + + # collection-based relationships are by default named + # "<classname>_collection" + u1 = session.query(User).first() + print (u1.address_collection) + +Above, calling :meth:`.AutomapBase.prepare` while passing along the +:paramref:`.AutomapBase.prepare.reflect` parameter indicates that the +:meth:`_schema.MetaData.reflect` +method will be called on this declarative base +classes' :class:`_schema.MetaData` collection; then, each **viable** +:class:`_schema.Table` within the :class:`_schema.MetaData` +will get a new mapped class +generated automatically. The :class:`_schema.ForeignKeyConstraint` +objects which +link the various tables together will be used to produce new, bidirectional +:func:`_orm.relationship` objects between classes. +The classes and relationships +follow along a default naming scheme that we can customize. At this point, +our basic mapping consisting of related ``User`` and ``Address`` classes is +ready to use in the traditional way. + +.. note:: By **viable**, we mean that for a table to be mapped, it must + specify a primary key. Additionally, if the table is detected as being + a pure association table between two other tables, it will not be directly + mapped and will instead be configured as a many-to-many table between + the mappings for the two referring tables. + +Generating Mappings from an Existing MetaData +============================================= + +We can pass a pre-declared :class:`_schema.MetaData` object to +:func:`.automap_base`. +This object can be constructed in any way, including programmatically, from +a serialized file, or from itself being reflected using +:meth:`_schema.MetaData.reflect`. +Below we illustrate a combination of reflection and +explicit table declaration:: + + from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey + from sqlalchemy.ext.automap import automap_base + engine = create_engine("sqlite:///mydatabase.db") + + # produce our own MetaData object + metadata = MetaData() + + # we can reflect it ourselves from a database, using options + # such as 'only' to limit what tables we look at... + metadata.reflect(engine, only=['user', 'address']) + + # ... or just define our own Table objects with it (or combine both) + Table('user_order', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', ForeignKey('user.id')) + ) + + # we can then produce a set of mappings from this MetaData. + Base = automap_base(metadata=metadata) + + # calling prepare() just sets up mapped classes and relationships. + Base.prepare() + + # mapped classes are ready + User, Address, Order = Base.classes.user, Base.classes.address,\ + Base.classes.user_order + +.. _automap_by_module: + +Generating Mappings from Multiple Schemas +========================================= + +The :meth:`.AutomapBase.prepare` method when used with reflection may reflect +tables from one schema at a time at most, using the +:paramref:`.AutomapBase.prepare.schema` parameter to indicate the name of a +schema to be reflected from. In order to populate the :class:`.AutomapBase` +with tables from multiple schemas, :meth:`.AutomapBase.prepare` may be invoked +multiple times, each time passing a different name to the +:paramref:`.AutomapBase.prepare.schema` parameter. The +:meth:`.AutomapBase.prepare` method keeps an internal list of +:class:`_schema.Table` objects that have already been mapped, and will add new +mappings only for those :class:`_schema.Table` objects that are new since the +last time :meth:`.AutomapBase.prepare` was run:: + + e = create_engine("postgresql://scott:tiger@localhost/test") + + Base.metadata.create_all(e) + + Base = automap_base() + + Base.prepare(e) + Base.prepare(e, schema="test_schema") + Base.prepare(e, schema="test_schema_2") + +.. versionadded:: 2.0 The :meth:`.AutomapBase.prepare` method may be called + any number of times; only newly added tables will be mapped + on each run. Previously in version 1.4 and earlier, multiple calls would + cause errors as it would attempt to re-map an already mapped class. + The previous workaround approach of invoking + :meth:`_schema.MetaData.reflect` directly remains available as well. + +Automapping same-named tables across multiple schemas +----------------------------------------------------- + +For the common case where multiple schemas may have same-named tables and +therefore would generate same-named classes, conflicts can be resolved either +through use of the :paramref:`.AutomapBase.prepare.classname_for_table` hook to +apply different classnames on a per-schema basis, or by using the +:paramref:`.AutomapBase.prepare.modulename_for_table` hook, which allows +disambiguation of same-named classes by changing their effective ``__module__`` +attribute. In the example below, this hook is used to create a ``__module__`` +attribute for all classes that is of the form ``mymodule.<schemaname>``, where +the schema name ``default`` is used if no schema is present:: + + e = create_engine("postgresql://scott:tiger@localhost/test") + + Base.metadata.create_all(e) + + def module_name_for_table(cls, tablename, table): + if table.schema is not None: + return f"mymodule.{table.schema}" + else: + return f"mymodule.default" + + Base = automap_base() + + Base.prepare(e, modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema", modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema_2", modulename_for_table=module_name_for_table) + + +The same named-classes are organized into a hierarchical collection available +at :attr:`.AutomapBase.by_module`. This collection is traversed using the +dot-separated name of a particular package/module down into the desired +class name. + +.. note:: When using the :paramref:`.AutomapBase.prepare.modulename_for_table` + hook to return a new ``__module__`` that is not ``None``, the class is + **not** placed into the :attr:`.AutomapBase.classes` collection; only + classes that were not given an explicit modulename are placed here, as the + collection cannot represent same-named classes individually. + +In the example above, if the database contained a table named ``accounts`` in +all three of the default schema, the ``test_schema`` schema, and the +``test_schema_2`` schema, three separate classes will be available as:: + + Base.by_module.mymodule.default.accounts + Base.by_module.mymodule.test_schema.accounts + Base.by_module.mymodule.test_schema_2.accounts + +The default module namespace generated for all :class:`.AutomapBase` classes is +``sqlalchemy.ext.automap``. If no +:paramref:`.AutomapBase.prepare.modulename_for_table` hook is used, the +contents of :attr:`.AutomapBase.by_module` will be entirely within the +``sqlalchemy.ext.automap`` namespace (e.g. +``MyBase.by_module.sqlalchemy.ext.automap.<classname>``), which would contain +the same series of classes as what would be seen in +:attr:`.AutomapBase.classes`. Therefore it's generally only necessary to use +:attr:`.AutomapBase.by_module` when explicit ``__module__`` conventions are +present. + +.. versionadded: 2.0 + + Added the :attr:`.AutomapBase.by_module` collection, which stores + classes within a named hierarchy based on dot-separated module names, + as well as the :paramref:`.Automap.prepare.modulename_for_table` parameter + which allows for custom ``__module__`` schemes for automapped + classes. + + + +Specifying Classes Explicitly +============================= + +.. tip:: If explicit classes are expected to be prominent in an application, + consider using :class:`.DeferredReflection` instead. + +The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined +explicitly, in a way similar to that of the :class:`.DeferredReflection` class. +Classes that extend from :class:`.AutomapBase` act like regular declarative +classes, but are not immediately mapped after their construction, and are +instead mapped when we call :meth:`.AutomapBase.prepare`. The +:meth:`.AutomapBase.prepare` method will make use of the classes we've +established based on the table name we use. If our schema contains tables +``user`` and ``address``, we can define one or both of the classes to be used:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + # pre-declare User for the 'user' table + class User(Base): + __tablename__ = 'user' + + # override schema elements like Columns + user_name = Column('name', String) + + # override relationships too, if desired. + # we must use the same name that automap would use for the + # relationship, and also must refer to the class name that automap will + # generate for "address" + address_collection = relationship("address", collection_class=set) + + # reflect + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(autoload_with=engine) + + # we still have Address generated from the tablename "address", + # but User is the same as Base.classes.User now + + Address = Base.classes.address + + u1 = session.query(User).first() + print (u1.address_collection) + + # the backref is still there: + a1 = session.query(Address).first() + print (a1.user) + +Above, one of the more intricate details is that we illustrated overriding +one of the :func:`_orm.relationship` objects that automap would have created. +To do this, we needed to make sure the names match up with what automap +would normally generate, in that the relationship name would be +``User.address_collection`` and the name of the class referred to, from +automap's perspective, is called ``address``, even though we are referring to +it as ``Address`` within our usage of this class. + +Overriding Naming Schemes +========================= + +:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and +relationship names based on a schema, which means it has decision points in how +these names are determined. These three decision points are provided using +functions which can be passed to the :meth:`.AutomapBase.prepare` method, and +are known as :func:`.classname_for_table`, +:func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship`. Any or all of these +functions are provided as in the example below, where we use a "camel case" +scheme for class names and a "pluralizer" for collection names using the +`Inflect <https://pypi.org/project/inflect>`_ package:: + + import re + import inflect + + def camelize_classname(base, tablename, table): + "Produce a 'camelized' class name, e.g. " + "'words_and_underscores' -> 'WordsAndUnderscores'" + + return str(tablename[0].upper() + \ + re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:])) + + _pluralizer = inflect.engine() + def pluralize_collection(base, local_cls, referred_cls, constraint): + "Produce an 'uncamelized', 'pluralized' class name, e.g. " + "'SomeTerm' -> 'some_terms'" + + referred_name = referred_cls.__name__ + uncamelized = re.sub(r'[A-Z]', + lambda m: "_%s" % m.group(0).lower(), + referred_name)[1:] + pluralized = _pluralizer.plural(uncamelized) + return pluralized + + from sqlalchemy.ext.automap import automap_base + + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + + Base.prepare(autoload_with=engine, + classname_for_table=camelize_classname, + name_for_collection_relationship=pluralize_collection + ) + +From the above mapping, we would now have classes ``User`` and ``Address``, +where the collection from ``User`` to ``Address`` is called +``User.addresses``:: + + User, Address = Base.classes.User, Base.classes.Address + + u1 = User(addresses=[Address(email="foo@bar.com")]) + +Relationship Detection +====================== + +The vast majority of what automap accomplishes is the generation of +:func:`_orm.relationship` structures based on foreign keys. The mechanism +by which this works for many-to-one and one-to-many relationships is as +follows: + +1. A given :class:`_schema.Table`, known to be mapped to a particular class, + is examined for :class:`_schema.ForeignKeyConstraint` objects. + +2. From each :class:`_schema.ForeignKeyConstraint`, the remote + :class:`_schema.Table` + object present is matched up to the class to which it is to be mapped, + if any, else it is skipped. + +3. As the :class:`_schema.ForeignKeyConstraint` + we are examining corresponds to a + reference from the immediate mapped class, the relationship will be set up + as a many-to-one referring to the referred class; a corresponding + one-to-many backref will be created on the referred class referring + to this class. + +4. If any of the columns that are part of the + :class:`_schema.ForeignKeyConstraint` + are not nullable (e.g. ``nullable=False``), a + :paramref:`_orm.relationship.cascade` keyword argument + of ``all, delete-orphan`` will be added to the keyword arguments to + be passed to the relationship or backref. If the + :class:`_schema.ForeignKeyConstraint` reports that + :paramref:`_schema.ForeignKeyConstraint.ondelete` + is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable + set of columns, the option :paramref:`_orm.relationship.passive_deletes` + flag is set to ``True`` in the set of relationship keyword arguments. + Note that not all backends support reflection of ON DELETE. + +5. The names of the relationships are determined using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + callable functions. It is important to note that the default relationship + naming derives the name from the **the actual class name**. If you've + given a particular class an explicit name by declaring it, or specified an + alternate class naming scheme, that's the name from which the relationship + name will be derived. + +6. The classes are inspected for an existing mapped property matching these + names. If one is detected on one side, but none on the other side, + :class:`.AutomapBase` attempts to create a relationship on the missing side, + then uses the :paramref:`_orm.relationship.back_populates` + parameter in order to + point the new relationship to the other side. + +7. In the usual case where no relationship is on either side, + :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the + "many-to-one" side and matches it to the other using the + :paramref:`_orm.relationship.backref` parameter. + +8. Production of the :func:`_orm.relationship` and optionally the + :func:`.backref` + is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship` + function, which can be supplied by the end-user in order to augment + the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to + make use of custom implementations of these functions. + +Custom Relationship Arguments +----------------------------- + +The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used +to add parameters to relationships. For most cases, we can make use of the +existing :func:`.automap.generate_relationship` function to return +the object, after augmenting the given keyword dictionary with our own +arguments. + +Below is an illustration of how to send +:paramref:`_orm.relationship.cascade` and +:paramref:`_orm.relationship.passive_deletes` +options along to all one-to-many relationships:: + + from sqlalchemy.ext.automap import generate_relationship + + def _gen_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw): + if direction is interfaces.ONETOMANY: + kw['cascade'] = 'all, delete-orphan' + kw['passive_deletes'] = True + # make use of the built-in function to actually return + # the result. + return generate_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw) + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(autoload_with=engine, + generate_relationship=_gen_relationship) + +Many-to-Many relationships +-------------------------- + +:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g. +those which contain a ``secondary`` argument. The process for producing these +is as follows: + +1. A given :class:`_schema.Table` is examined for + :class:`_schema.ForeignKeyConstraint` + objects, before any mapped class has been assigned to it. + +2. If the table contains two and exactly two + :class:`_schema.ForeignKeyConstraint` + objects, and all columns within this table are members of these two + :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a + "secondary" table, and will **not be mapped directly**. + +3. The two (or one, for self-referential) external tables to which the + :class:`_schema.Table` + refers to are matched to the classes to which they will be + mapped, if any. + +4. If mapped classes for both sides are located, a many-to-many bi-directional + :func:`_orm.relationship` / :func:`.backref` + pair is created between the two + classes. + +5. The override logic for many-to-many works the same as that of one-to-many/ + many-to-one; the :func:`.generate_relationship` function is called upon + to generate the structures and existing attributes will be maintained. + +Relationships with Inheritance +------------------------------ + +:mod:`.sqlalchemy.ext.automap` will not generate any relationships between +two classes that are in an inheritance relationship. That is, with two +classes given as follows:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on': type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __mapper_args__ = { + 'polymorphic_identity':'engineer', + } + +The foreign key from ``Engineer`` to ``Employee`` is used not for a +relationship, but to establish joined inheritance between the two classes. + +Note that this means automap will not generate *any* relationships +for foreign keys that link from a subclass to a superclass. If a mapping +has actual relationships from subclass to superclass as well, those +need to be explicit. Below, as we have two separate foreign keys +from ``Engineer`` to ``Employee``, we need to set up both the relationship +we want as well as the ``inherit_condition``, as these are not things +SQLAlchemy can guess:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on':type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + favorite_employee_id = Column(Integer, ForeignKey('employee.id')) + + favorite_employee = relationship(Employee, + foreign_keys=favorite_employee_id) + + __mapper_args__ = { + 'polymorphic_identity':'engineer', + 'inherit_condition': id == Employee.id + } + +Handling Simple Naming Conflicts +-------------------------------- + +In the case of naming conflicts during mapping, override any of +:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship` as needed. For example, if +automap is attempting to name a many-to-one relationship the same as an +existing column, an alternate convention can be conditionally selected. Given +a schema: + +.. sourcecode:: sql + + CREATE TABLE table_a ( + id INTEGER PRIMARY KEY + ); + + CREATE TABLE table_b ( + id INTEGER PRIMARY KEY, + table_a INTEGER, + FOREIGN KEY(table_a) REFERENCES table_a(id) + ); + +The above schema will first automap the ``table_a`` table as a class named +``table_a``; it will then automap a relationship onto the class for ``table_b`` +with the same name as this related class, e.g. ``table_a``. This +relationship name conflicts with the mapping column ``table_b.table_a``, +and will emit an error on mapping. + +We can resolve this conflict by using an underscore as follows:: + + def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + name = referred_cls.__name__.lower() + local_table = local_cls.__table__ + if name in local_table.columns: + newname = name + "_" + warnings.warn( + "Already detected name %s present. using %s" % + (name, newname)) + return newname + return name + + + Base.prepare(autoload_with=engine, + name_for_scalar_relationship=name_for_scalar_relationship) + +Alternatively, we can change the name on the column side. The columns +that are mapped can be modified using the technique described at +:ref:`mapper_column_distinct_names`, by assigning the column explicitly +to a new name:: + + Base = automap_base() + + class TableB(Base): + __tablename__ = 'table_b' + _table_a = Column('table_a', ForeignKey('table_a.id')) + + Base.prepare(autoload_with=engine) + + +Using Automap with Explicit Declarations +======================================== + +As noted previously, automap has no dependency on reflection, and can make +use of any collection of :class:`_schema.Table` objects within a +:class:`_schema.MetaData` +collection. From this, it follows that automap can also be used +generate missing relationships given an otherwise complete model that fully +defines table metadata:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import Column, Integer, String, ForeignKey + + Base = automap_base() + + class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True) + name = Column(String) + + class Address(Base): + __tablename__ = 'address' + + id = Column(Integer, primary_key=True) + email = Column(String) + user_id = Column(ForeignKey('user.id')) + + # produce relationships + Base.prepare() + + # mapping is complete, with "address_collection" and + # "user" relationships + a1 = Address(email='u1') + a2 = Address(email='u2') + u1 = User(address_collection=[a1, a2]) + assert a1.user is u1 + +Above, given mostly complete ``User`` and ``Address`` mappings, the +:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a +bidirectional relationship pair ``Address.user`` and +``User.address_collection`` to be generated on the mapped classes. + +Note that when subclassing :class:`.AutomapBase`, +the :meth:`.AutomapBase.prepare` method is required; if not called, the classes +we've declared are in an un-mapped state. + + +.. _automap_intercepting_columns: + +Intercepting Column Definitions +=============================== + +The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an +event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept +the information reflected about a database column before the :class:`_schema.Column` +object is constructed. For example if we wanted to map columns using a +naming convention such as ``"attr_<columnname>"``, the event could +be applied as:: + + @event.listens_for(Base.metadata, "column_reflect") + def column_reflect(inspector, table, column_info): + # set column.key = "attr_<lower_case_name>" + column_info['key'] = "attr_%s" % column_info['name'].lower() + + # run reflection + Base.prepare(autoload_with=engine) + +.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event + may be applied to a :class:`_schema.MetaData` object. + +.. seealso:: + + :meth:`_events.DDLEvents.column_reflect` + + :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation + + +""" # noqa +from __future__ import annotations + +import dataclasses +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import util +from ..orm import backref +from ..orm import declarative_base as _declarative_base +from ..orm import exc as orm_exc +from ..orm import interfaces +from ..orm import relationship +from ..orm.decl_base import _DeferredMapperConfig +from ..orm.mapper import _CONFIGURE_MUTEX +from ..schema import ForeignKeyConstraint +from ..sql import and_ +from ..util import Properties +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.base import Engine + from ..orm.base import RelationshipDirection + from ..orm.relationships import ORMBackrefArgument + from ..orm.relationships import Relationship + from ..sql.schema import Column + from ..sql.schema import MetaData + from ..sql.schema import Table + from ..util import immutabledict + + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class PythonNameForTableType(Protocol): + def __call__( + self, base: Type[Any], tablename: str, table: Table + ) -> str: ... + + +def classname_for_table( + base: Type[Any], + tablename: str, + table: Table, +) -> str: + """Return the class name that should be used, given the name + of a table. + + The default implementation is:: + + return str(tablename) + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.classname_for_table` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param tablename: string name of the :class:`_schema.Table`. + + :param table: the :class:`_schema.Table` object itself. + + :return: a string class name. + + .. note:: + + In Python 2, the string used for the class name **must** be a + non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute + of :class:`_schema.Table` is typically a Python unicode subclass, + so the + ``str()`` function should be applied to this name, after accounting for + any non-ASCII characters. + + """ + return str(tablename) + + +class NameForScalarRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: ... + + +def name_for_scalar_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: + """Return the attribute name that should be used to refer from one + class to another, for a scalar object reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + + +class NameForCollectionRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: ... + + +def name_for_collection_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: + """Return the attribute name that should be used to refer from one + class to another, for a collection reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + "_collection" + + Alternate implementations + can be specified using the + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + "_collection" + + +class GenerateRelationshipType(Protocol): + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Relationship[Any]: ... + + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> ORMBackrefArgument: ... + + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Union[ORMBackrefArgument, Relationship[Any]]: ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Relationship[Any]: ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> ORMBackrefArgument: ... + + +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Union[Relationship[Any], ORMBackrefArgument]: + r"""Generate a :func:`_orm.relationship` or :func:`.backref` + on behalf of two + mapped classes. + + An alternate implementation of this function can be specified using the + :paramref:`.AutomapBase.prepare.generate_relationship` parameter. + + The default implementation of this function is as follows:: + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param direction: indicate the "direction" of the relationship; this will + be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`. + + :param return_fn: the function that is used by default to create the + relationship. This will be either :func:`_orm.relationship` or + :func:`.backref`. The :func:`.backref` function's result will be used to + produce a new :func:`_orm.relationship` in a second step, + so it is critical + that user-defined implementations correctly differentiate between the two + functions, if a custom relationship function is being used. + + :param attrname: the attribute name to which this relationship is being + assigned. If the value of :paramref:`.generate_relationship.return_fn` is + the :func:`.backref` function, then this name is the name that is being + assigned to the backref. + + :param local_cls: the "local" class to which this relationship or backref + will be locally present. + + :param referred_cls: the "referred" class to which the relationship or + backref refers to. + + :param \**kw: all additional keyword arguments are passed along to the + function. + + :return: a :func:`_orm.relationship` or :func:`.backref` construct, + as dictated + by the :paramref:`.generate_relationship.return_fn` parameter. + + """ + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + +ByModuleProperties = Properties[Union["ByModuleProperties", Type[Any]]] + + +class AutomapBase: + """Base class for an "automap" schema. + + The :class:`.AutomapBase` class can be compared to the "declarative base" + class that is produced by the :func:`.declarative.declarative_base` + function. In practice, the :class:`.AutomapBase` class is always used + as a mixin along with an actual declarative base. + + A new subclassable :class:`.AutomapBase` is typically instantiated + using the :func:`.automap_base` function. + + .. seealso:: + + :ref:`automap_toplevel` + + """ + + __abstract__ = True + + classes: ClassVar[Properties[Type[Any]]] + """An instance of :class:`.util.Properties` containing classes. + + This object behaves much like the ``.c`` collection on a table. Classes + are present under the name they were given, e.g.:: + + Base = automap_base() + Base.prepare(autoload_with=some_engine) + + User, Address = Base.classes.User, Base.classes.Address + + For class names that overlap with a method name of + :class:`.util.Properties`, such as ``items()``, the getitem form + is also supported:: + + Item = Base.classes["items"] + + """ + + by_module: ClassVar[ByModuleProperties] + """An instance of :class:`.util.Properties` containing a hierarchal + structure of dot-separated module names linked to classes. + + This collection is an alternative to the :attr:`.AutomapBase.classes` + collection that is useful when making use of the + :paramref:`.AutomapBase.prepare.modulename_for_table` parameter, which will + apply distinct ``__module__`` attributes to generated classes. + + The default ``__module__`` an automap-generated class is + ``sqlalchemy.ext.automap``; to access this namespace using + :attr:`.AutomapBase.by_module` looks like:: + + User = Base.by_module.sqlalchemy.ext.automap.User + + If a class had a ``__module__`` of ``mymodule.account``, accessing + this namespace looks like:: + + MyClass = Base.by_module.mymodule.account.MyClass + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`automap_by_module` + + """ + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + _sa_automapbase_bookkeeping: ClassVar[_Bookkeeping] + + @classmethod + @util.deprecated_params( + engine=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.engine` parameter " + "is deprecated and will be removed in a future release. " + "Please use the " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "parameter.", + ), + reflect=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.reflect` " + "parameter is deprecated and will be removed in a future " + "release. Reflection is enabled when " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "is passed.", + ), + ) + def prepare( + cls: Type[AutomapBase], + autoload_with: Optional[Engine] = None, + engine: Optional[Any] = None, + reflect: bool = False, + schema: Optional[str] = None, + classname_for_table: Optional[PythonNameForTableType] = None, + modulename_for_table: Optional[PythonNameForTableType] = None, + collection_class: Optional[Any] = None, + name_for_scalar_relationship: Optional[ + NameForScalarRelationshipType + ] = None, + name_for_collection_relationship: Optional[ + NameForCollectionRelationshipType + ] = None, + generate_relationship: Optional[GenerateRelationshipType] = None, + reflection_options: Union[ + Dict[_KT, _VT], immutabledict[_KT, _VT] + ] = util.EMPTY_DICT, + ) -> None: + """Extract mapped classes and relationships from the + :class:`_schema.MetaData` and perform mappings. + + For full documentation and examples see + :ref:`automap_basic_use`. + + :param autoload_with: an :class:`_engine.Engine` or + :class:`_engine.Connection` with which + to perform schema reflection; when specified, the + :meth:`_schema.MetaData.reflect` method will be invoked within + the scope of this method. + + :param engine: legacy; use :paramref:`.AutomapBase.autoload_with`. + Used to indicate the :class:`_engine.Engine` or + :class:`_engine.Connection` with which to reflect tables with, + if :paramref:`.AutomapBase.reflect` is True. + + :param reflect: legacy; use :paramref:`.AutomapBase.autoload_with`. + Indicates that :meth:`_schema.MetaData.reflect` should be invoked. + + :param classname_for_table: callable function which will be used to + produce new class names, given a table name. Defaults to + :func:`.classname_for_table`. + + :param modulename_for_table: callable function which will be used to + produce the effective ``__module__`` for an internally generated + class, to allow for multiple classes of the same name in a single + automap base which would be in different "modules". + + Defaults to ``None``, which will indicate that ``__module__`` will not + be set explicitly; the Python runtime will use the value + ``sqlalchemy.ext.automap`` for these classes. + + When assigning ``__module__`` to generated classes, they can be + accessed based on dot-separated module names using the + :attr:`.AutomapBase.by_module` collection. Classes that have + an explicit ``__module_`` assigned using this hook do **not** get + placed into the :attr:`.AutomapBase.classes` collection, only + into :attr:`.AutomapBase.by_module`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`automap_by_module` + + :param name_for_scalar_relationship: callable function which will be + used to produce relationship names for scalar relationships. Defaults + to :func:`.name_for_scalar_relationship`. + + :param name_for_collection_relationship: callable function which will + be used to produce relationship names for collection-oriented + relationships. Defaults to :func:`.name_for_collection_relationship`. + + :param generate_relationship: callable function which will be used to + actually generate :func:`_orm.relationship` and :func:`.backref` + constructs. Defaults to :func:`.generate_relationship`. + + :param collection_class: the Python collection class that will be used + when a new :func:`_orm.relationship` + object is created that represents a + collection. Defaults to ``list``. + + :param schema: Schema name to reflect when reflecting tables using + the :paramref:`.AutomapBase.prepare.autoload_with` parameter. The name + is passed to the :paramref:`_schema.MetaData.reflect.schema` parameter + of :meth:`_schema.MetaData.reflect`. When omitted, the default schema + in use by the database connection is used. + + .. note:: The :paramref:`.AutomapBase.prepare.schema` + parameter supports reflection of a single schema at a time. + In order to include tables from many schemas, use + multiple calls to :meth:`.AutomapBase.prepare`. + + For an overview of multiple-schema automap including the use + of additional naming conventions to resolve table name + conflicts, see the section :ref:`automap_by_module`. + + .. versionadded:: 2.0 :meth:`.AutomapBase.prepare` supports being + directly invoked any number of times, keeping track of tables + that have already been processed to avoid processing them + a second time. + + :param reflection_options: When present, this dictionary of options + will be passed to :meth:`_schema.MetaData.reflect` + to supply general reflection-specific options like ``only`` and/or + dialect-specific options like ``oracle_resolve_synonyms``. + + .. versionadded:: 1.4 + + """ + + for mr in cls.__mro__: + if "_sa_automapbase_bookkeeping" in mr.__dict__: + automap_base = cast("Type[AutomapBase]", mr) + break + else: + assert False, "Can't locate automap base in class hierarchy" + + glbls = globals() + if classname_for_table is None: + classname_for_table = glbls["classname_for_table"] + if name_for_scalar_relationship is None: + name_for_scalar_relationship = glbls[ + "name_for_scalar_relationship" + ] + if name_for_collection_relationship is None: + name_for_collection_relationship = glbls[ + "name_for_collection_relationship" + ] + if generate_relationship is None: + generate_relationship = glbls["generate_relationship"] + if collection_class is None: + collection_class = list + + if autoload_with: + reflect = True + + if engine: + autoload_with = engine + + if reflect: + assert autoload_with + opts = dict( + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + if reflection_options: + opts.update(reflection_options) + cls.metadata.reflect(autoload_with, **opts) # type: ignore[arg-type] # noqa: E501 + + with _CONFIGURE_MUTEX: + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ] = { + cast("Table", m.local_table): m + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) + } + + many_to_many: List[ + Tuple[Table, Table, List[ForeignKeyConstraint], Table] + ] + many_to_many = [] + + bookkeeping = automap_base._sa_automapbase_bookkeeping + metadata_tables = cls.metadata.tables + + for table_key in set(metadata_tables).difference( + bookkeeping.table_keys + ): + table = metadata_tables[table_key] + bookkeeping.table_keys.add(table_key) + + lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) + if lcl_m2m is not None: + assert rem_m2m is not None + assert m2m_const is not None + many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) + elif not table.primary_key: + continue + elif table not in table_to_map_config: + clsdict: Dict[str, Any] = {"__table__": table} + if modulename_for_table is not None: + new_module = modulename_for_table( + cls, table.name, table + ) + if new_module is not None: + clsdict["__module__"] = new_module + else: + new_module = None + + newname = classname_for_table(cls, table.name, table) + if new_module is None and newname in cls.classes: + util.warn( + "Ignoring duplicate class name " + f"'{newname}' " + "received in automap base for table " + f"{table.key} without " + "``__module__`` being set; consider using the " + "``modulename_for_table`` hook" + ) + continue + + mapped_cls = type( + newname, + (automap_base,), + clsdict, + ) + map_config = _DeferredMapperConfig.config_for_cls( + mapped_cls + ) + assert map_config.cls.__name__ == newname + if new_module is None: + cls.classes[newname] = mapped_cls + + by_module_properties: ByModuleProperties = cls.by_module + for token in map_config.cls.__module__.split("."): + if token not in by_module_properties: + by_module_properties[token] = util.Properties({}) + + props = by_module_properties[token] + + # we can assert this because the clsregistry + # module would have raised if there was a mismatch + # between modules/classes already. + # see test_cls_schema_name_conflict + assert isinstance(props, Properties) + by_module_properties = props + + by_module_properties[map_config.cls.__name__] = mapped_cls + + table_to_map_config[table] = map_config + + for map_config in table_to_map_config.values(): + _relationships_for_fks( + automap_base, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: + _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for map_config in _DeferredMapperConfig.classes_for_base( + automap_base + ): + map_config.map() + + _sa_decl_prepare = True + """Indicate that the mapping of classes should be deferred. + + The presence of this attribute name indicates to declarative + that the call to mapper() should not occur immediately; instead, + information about the table and attributes to be mapped are gathered + into an internal structure called _DeferredMapperConfig. These + objects can be collected later using classes_for_base(), additional + mapping decisions can be made, and then the map() method will actually + apply the mapping. + + The only real reason this deferral of the whole + thing is needed is to support primary key columns that aren't reflected + yet when the class is declared; everything else can theoretically be + added to the mapper later. However, the _DeferredMapperConfig is a + nice interface in any case which exists at that not usually exposed point + at which declarative has the class and the Table but hasn't called + mapper() yet. + + """ + + @classmethod + def _sa_raise_deferred_config(cls) -> NoReturn: + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AutomapBase. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) + + +@dataclasses.dataclass +class _Bookkeeping: + __slots__ = ("table_keys",) + + table_keys: Set[str] + + +def automap_base( + declarative_base: Optional[Type[Any]] = None, **kw: Any +) -> Any: + r"""Produce a declarative automap base. + + This function produces a new base class that is a product of the + :class:`.AutomapBase` class as well a declarative base produced by + :func:`.declarative.declarative_base`. + + All parameters other than ``declarative_base`` are keyword arguments + that are passed directly to the :func:`.declarative.declarative_base` + function. + + :param declarative_base: an existing class produced by + :func:`.declarative.declarative_base`. When this is passed, the function + no longer invokes :func:`.declarative.declarative_base` itself, and all + other keyword arguments are ignored. + + :param \**kw: keyword arguments are passed along to + :func:`.declarative.declarative_base`. + + """ + if declarative_base is None: + Base = _declarative_base(**kw) + else: + Base = declarative_base + + return type( + Base.__name__, + (AutomapBase, Base), + { + "__abstract__": True, + "classes": util.Properties({}), + "by_module": util.Properties({}), + "_sa_automapbase_bookkeeping": _Bookkeeping(set()), + }, + ) + + +def _is_many_to_many( + automap_base: Type[Any], table: Table +) -> Tuple[ + Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]] +]: + fk_constraints = [ + const + for const in table.constraints + if isinstance(const, ForeignKeyConstraint) + ] + if len(fk_constraints) != 2: + return None, None, None + + cols: List[Column[Any]] = sum( + [ + [fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints + ], + [], + ) + + if set(cols) != set(table.c): + return None, None, None + + return ( + fk_constraints[0].elements[0].column.table, + fk_constraints[1].elements[0].column.table, + fk_constraints, + ) + + +def _relationships_for_fks( + automap_base: Type[Any], + map_config: _DeferredMapperConfig, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForScalarRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + local_table = cast("Optional[Table]", map_config.local_table) + local_cls = cast( + "Optional[Type[Any]]", map_config.cls + ) # derived from a weakref, may be None + + if local_table is None or local_cls is None: + return + for constraint in local_table.constraints: + if isinstance(constraint, ForeignKeyConstraint): + fks = constraint.elements + referred_table = fks[0].column.table + referred_cfg = table_to_map_config.get(referred_table, None) + if referred_cfg is None: + continue + referred_cls = referred_cfg.cls + + if local_cls is not referred_cls and issubclass( + local_cls, referred_cls + ): + continue + + relationship_name = name_for_scalar_relationship( + automap_base, local_cls, referred_cls, constraint + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, constraint + ) + + o2m_kws: Dict[str, Union[str, bool]] = {} + nullable = False not in {fk.parent.nullable for fk in fks} + if not nullable: + o2m_kws["cascade"] = "all, delete-orphan" + + if ( + constraint.ondelete + and constraint.ondelete.lower() == "cascade" + ): + o2m_kws["passive_deletes"] = True + else: + if ( + constraint.ondelete + and constraint.ondelete.lower() == "set null" + ): + o2m_kws["passive_deletes"] = True + + create_backref = backref_name not in referred_cfg.properties + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.ONETOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + **o2m_kws, + ) + else: + backref_obj = None + rel = generate_relationship( + automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, + referred_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements], + ) + if rel is not None: + map_config.properties[relationship_name] = rel + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name # type: ignore[union-attr] + + +def _m2m_relationship( + automap_base: Type[Any], + lcl_m2m: Table, + rem_m2m: Table, + m2m_const: List[ForeignKeyConstraint], + table: Table, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForCollectionRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + map_config = table_to_map_config.get(lcl_m2m, None) + referred_cfg = table_to_map_config.get(rem_m2m, None) + if map_config is None or referred_cfg is None: + return + + local_cls = map_config.cls + referred_cls = referred_cfg.cls + + relationship_name = name_for_collection_relationship( + automap_base, local_cls, referred_cls, m2m_const[0] + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, m2m_const[1] + ) + + create_backref = backref_name not in referred_cfg.properties + + if table in table_to_map_config: + overlaps = "__*" + else: + overlaps = None + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + overlaps=overlaps, + ) + else: + backref_obj = None + + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), # type: ignore [arg-type] + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), # type: ignore [arg-type] + backref=backref_obj, + collection_class=collection_class, + ) + if rel is not None: + map_config.properties[relationship_name] = rel + + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), # type: ignore [arg-type] + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), # type: ignore [arg-type] + back_populates=relationship_name, + collection_class=collection_class, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name # type: ignore[union-attr] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py new file mode 100644 index 0000000..60f7ae6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py @@ -0,0 +1,574 @@ +# ext/baked.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Baked query extension. + +Provides a creational pattern for the :class:`.query.Query` object which +allows the fully constructed object, Core select statement, and string +compiled result to be fully cached. + + +""" + +import collections.abc as collections_abc +import logging + +from .. import exc as sa_exc +from .. import util +from ..orm import exc as orm_exc +from ..orm.query import Query +from ..orm.session import Session +from ..sql import func +from ..sql import literal_column +from ..sql import util as sql_util + + +log = logging.getLogger(__name__) + + +class Bakery: + """Callable which returns a :class:`.BakedQuery`. + + This object is returned by the class method + :meth:`.BakedQuery.bakery`. It exists as an object + so that the "cache" can be easily inspected. + + .. versionadded:: 1.2 + + + """ + + __slots__ = "cls", "cache" + + def __init__(self, cls_, cache): + self.cls = cls_ + self.cache = cache + + def __call__(self, initial_fn, *args): + return self.cls(self.cache, initial_fn, args) + + +class BakedQuery: + """A builder object for :class:`.query.Query` objects.""" + + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" + + def __init__(self, bakery, initial_fn, args=()): + self._cache_key = () + self._update_cache_key(initial_fn, args) + self.steps = [initial_fn] + self._spoiled = False + self._bakery = bakery + + @classmethod + def bakery(cls, size=200, _size_alert=None): + """Construct a new bakery. + + :return: an instance of :class:`.Bakery` + + """ + + return Bakery(cls, util.LRUCache(size, size_alert=_size_alert)) + + def _clone(self): + b1 = BakedQuery.__new__(BakedQuery) + b1._cache_key = self._cache_key + b1.steps = list(self.steps) + b1._bakery = self._bakery + b1._spoiled = self._spoiled + return b1 + + def _update_cache_key(self, fn, args=()): + self._cache_key += (fn.__code__,) + args + + def __iadd__(self, other): + if isinstance(other, tuple): + self.add_criteria(*other) + else: + self.add_criteria(other) + return self + + def __add__(self, other): + if isinstance(other, tuple): + return self.with_criteria(*other) + else: + return self.with_criteria(other) + + def add_criteria(self, fn, *args): + """Add a criteria function to this :class:`.BakedQuery`. + + This is equivalent to using the ``+=`` operator to + modify a :class:`.BakedQuery` in-place. + + """ + self._update_cache_key(fn, args) + self.steps.append(fn) + return self + + def with_criteria(self, fn, *args): + """Add a criteria function to a :class:`.BakedQuery` cloned from this + one. + + This is equivalent to using the ``+`` operator to + produce a new :class:`.BakedQuery` with modifications. + + """ + return self._clone().add_criteria(fn, *args) + + def for_session(self, session): + """Return a :class:`_baked.Result` object for this + :class:`.BakedQuery`. + + This is equivalent to calling the :class:`.BakedQuery` as a + Python callable, e.g. ``result = my_baked_query(session)``. + + """ + return Result(self, session) + + def __call__(self, session): + return self.for_session(session) + + def spoil(self, full=False): + """Cancel any query caching that will occur on this BakedQuery object. + + The BakedQuery can continue to be used normally, however additional + creational functions will not be cached; they will be called + on every invocation. + + This is to support the case where a particular step in constructing + a baked query disqualifies the query from being cacheable, such + as a variant that relies upon some uncacheable value. + + :param full: if False, only functions added to this + :class:`.BakedQuery` object subsequent to the spoil step will be + non-cached; the state of the :class:`.BakedQuery` up until + this point will be pulled from the cache. If True, then the + entire :class:`_query.Query` object is built from scratch each + time, with all creational functions being called on each + invocation. + + """ + if not full and not self._spoiled: + _spoil_point = self._clone() + _spoil_point._cache_key += ("_query_only",) + self.steps = [_spoil_point._retrieve_baked_query] + self._spoiled = True + return self + + def _effective_key(self, session): + """Return the key that actually goes into the cache dictionary for + this :class:`.BakedQuery`, taking into account the given + :class:`.Session`. + + This basically means we also will include the session's query_class, + as the actual :class:`_query.Query` object is part of what's cached + and needs to match the type of :class:`_query.Query` that a later + session will want to use. + + """ + return self._cache_key + (session._query_cls,) + + def _with_lazyload_options(self, options, effective_path, cache_path=None): + """Cloning version of _add_lazyload_options.""" + q = self._clone() + q._add_lazyload_options(options, effective_path, cache_path=cache_path) + return q + + def _add_lazyload_options(self, options, effective_path, cache_path=None): + """Used by per-state lazy loaders to add options to the + "lazy load" query from a parent query. + + Creates a cache key based on given load path and query options; + if a repeatable cache key cannot be generated, the query is + "spoiled" so that it won't use caching. + + """ + + key = () + + if not cache_path: + cache_path = effective_path + + for opt in options: + if opt._is_legacy_option or opt._is_compile_state: + ck = opt._generate_cache_key() + if ck is None: + self.spoil(full=True) + else: + assert not ck[1], ( + "loader options with variable bound parameters " + "not supported with baked queries. Please " + "use new-style select() statements for cached " + "ORM queries." + ) + key += ck[0] + + self.add_criteria( + lambda q: q._with_current_path(effective_path).options(*options), + cache_path.path, + key, + ) + + def _retrieve_baked_query(self, session): + query = self._bakery.get(self._effective_key(session), None) + if query is None: + query = self._as_query(session) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) + return query.with_session(session) + + def _bake(self, session): + query = self._as_query(session) + query.session = None + + # in 1.4, this is where before_compile() event is + # invoked + statement = query._statement_20() + + # if the query is not safe to cache, we still do everything as though + # we did cache it, since the receiver of _bake() assumes subqueryload + # context was set up, etc. + # + # note also we want to cache the statement itself because this + # allows the statement itself to hold onto its cache key that is + # used by the Connection, which in itself is more expensive to + # generate than what BakedQuery was able to provide in 1.3 and prior + + if statement._compile_options._bake_ok: + self._bakery[self._effective_key(session)] = ( + query, + statement, + ) + + return query, statement + + def to_query(self, query_or_session): + """Return the :class:`_query.Query` object for use as a subquery. + + This method should be used within the lambda callable being used + to generate a step of an enclosing :class:`.BakedQuery`. The + parameter should normally be the :class:`_query.Query` object that + is passed to the lambda:: + + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) + + main_bq = self.bakery(lambda s: s.query(Address)) + main_bq += lambda q: q.filter( + sub_bq.to_query(q).exists()) + + In the case where the subquery is used in the first callable against + a :class:`.Session`, the :class:`.Session` is also accepted:: + + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) + + main_bq = self.bakery( + lambda s: s.query( + Address.id, sub_bq.to_query(q).scalar_subquery()) + ) + + :param query_or_session: a :class:`_query.Query` object or a class + :class:`.Session` object, that is assumed to be within the context + of an enclosing :class:`.BakedQuery` callable. + + + .. versionadded:: 1.3 + + + """ + + if isinstance(query_or_session, Session): + session = query_or_session + elif isinstance(query_or_session, Query): + session = query_or_session.session + if session is None: + raise sa_exc.ArgumentError( + "Given Query needs to be associated with a Session" + ) + else: + raise TypeError( + "Query or Session object expected, got %r." + % type(query_or_session) + ) + return self._as_query(session) + + def _as_query(self, session): + query = self.steps[0](session) + + for step in self.steps[1:]: + query = step(query) + + return query + + +class Result: + """Invokes a :class:`.BakedQuery` against a :class:`.Session`. + + The :class:`_baked.Result` object is where the actual :class:`.query.Query` + object gets created, or retrieved from the cache, + against a target :class:`.Session`, and is then invoked for results. + + """ + + __slots__ = "bq", "session", "_params", "_post_criteria" + + def __init__(self, bq, session): + self.bq = bq + self.session = session + self._params = {} + self._post_criteria = [] + + def params(self, *args, **kw): + """Specify parameters to be replaced into the string SQL statement.""" + + if len(args) == 1: + kw.update(args[0]) + elif len(args) > 0: + raise sa_exc.ArgumentError( + "params() takes zero or one positional argument, " + "which is a dictionary." + ) + self._params.update(kw) + return self + + def _using_post_criteria(self, fns): + if fns: + self._post_criteria.extend(fns) + return self + + def with_post_criteria(self, fn): + """Add a criteria function that will be applied post-cache. + + This adds a function that will be run against the + :class:`_query.Query` object after it is retrieved from the + cache. This currently includes **only** the + :meth:`_query.Query.params` and :meth:`_query.Query.execution_options` + methods. + + .. warning:: :meth:`_baked.Result.with_post_criteria` + functions are applied + to the :class:`_query.Query` + object **after** the query's SQL statement + object has been retrieved from the cache. Only + :meth:`_query.Query.params` and + :meth:`_query.Query.execution_options` + methods should be used. + + + .. versionadded:: 1.2 + + + """ + return self._using_post_criteria([fn]) + + def _as_query(self): + q = self.bq._as_query(self.session).params(self._params) + for fn in self._post_criteria: + q = fn(q) + return q + + def __str__(self): + return str(self._as_query()) + + def __iter__(self): + return self._iter().__iter__() + + def _iter(self): + bq = self.bq + + if not self.session.enable_baked_queries or bq._spoiled: + return self._as_query()._iter() + + query, statement = bq._bakery.get( + bq._effective_key(self.session), (None, None) + ) + if query is None: + query, statement = bq._bake(self.session) + + if self._params: + q = query.params(self._params) + else: + q = query + for fn in self._post_criteria: + q = fn(q) + + params = q._params + execution_options = dict(q._execution_options) + execution_options.update( + { + "_sa_orm_load_options": q.load_options, + "compiled_cache": bq._bakery, + } + ) + + result = self.session.execute( + statement, params, execution_options=execution_options + ) + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() + + return result + + def count(self): + """return the 'count'. + + Equivalent to :meth:`_query.Query.count`. + + Note this uses a subquery to ensure an accurate count regardless + of the structure of the original statement. + + """ + + col = func.count(literal_column("*")) + bq = self.bq.with_criteria(lambda q: q._legacy_from_self(col)) + return bq.for_session(self.session).params(self._params).scalar() + + def scalar(self): + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + Equivalent to :meth:`_query.Query.scalar`. + + """ + try: + ret = self.one() + if not isinstance(ret, collections_abc.Sequence): + return ret + return ret[0] + except orm_exc.NoResultFound: + return None + + def first(self): + """Return the first row. + + Equivalent to :meth:`_query.Query.first`. + + """ + + bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) + return ( + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ._iter() + .first() + ) + + def one(self): + """Return exactly one result or raise an exception. + + Equivalent to :meth:`_query.Query.one`. + + """ + return self._iter().one() + + def one_or_none(self): + """Return one or zero results, or raise an exception for multiple + rows. + + Equivalent to :meth:`_query.Query.one_or_none`. + + """ + return self._iter().one_or_none() + + def all(self): + """Return all rows. + + Equivalent to :meth:`_query.Query.all`. + + """ + return self._iter().all() + + def get(self, ident): + """Retrieve an object based on identity. + + Equivalent to :meth:`_query.Query.get`. + + """ + + query = self.bq.steps[0](self.session) + return query._get_impl(ident, self._load_on_pk_identity) + + def _load_on_pk_identity(self, session, query, primary_key_identity, **kw): + """Load the given primary key identity from the database.""" + + mapper = query._raw_columns[0]._annotations["parententity"] + + _get_clause, _get_params = mapper._get_clause + + def setup(query): + _lcl_get_clause = _get_clause + q = query._clone() + q._get_condition() + q._order_by = None + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in primary_key_identity: + nones = { + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + } + _lcl_get_clause = sql_util.adapt_criterion_to_null( + _lcl_get_clause, nones + ) + + # TODO: can mapper._get_clause be pre-adapted? + q._where_criteria = ( + sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}), + ) + + for fn in self._post_criteria: + q = fn(q) + return q + + # cache the query against a key that includes + # which positions in the primary key are NULL + # (remember, we can map to an OUTER JOIN) + bq = self.bq + + # add the clause we got from mapper._get_clause to the cache + # key so that if a race causes multiple calls to _get_clause, + # we've cached on ours + bq = bq._clone() + bq._cache_key += (_get_clause,) + + bq = bq.with_criteria( + setup, tuple(elem is None for elem in primary_key_identity) + ) + + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } + + result = list(bq.for_session(self.session).params(**params)) + l = len(result) + if l > 1: + raise orm_exc.MultipleResultsFound() + elif l: + return result[0] + else: + return None + + +bakery = BakedQuery.bakery diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py new file mode 100644 index 0000000..01462ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py @@ -0,0 +1,555 @@ +# ext/compiler.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more +:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or +more callables defining its compilation:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, +the base expression element for named column objects. The ``compiles`` +decorator registers itself with the ``MyColumn`` class so that it is invoked +when the object is compiled to a string:: + + from sqlalchemy import select + + s = select(MyColumn('x'), MyColumn('y')) + print(str(s)) + +Produces:: + + SELECT [x], [y] + +Dialect-specific compilation rules +================================== + +Compilers can also be made dialect-specific. The appropriate compiler will be +invoked for the dialect in use:: + + from sqlalchemy.schema import DDLElement + + class AlterColumn(DDLElement): + inherit_cache = False + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + @compiles(AlterColumn) + def visit_alter_column(element, compiler, **kw): + return "ALTER COLUMN %s ..." % element.column.name + + @compiles(AlterColumn, 'postgresql') + def visit_alter_column(element, compiler, **kw): + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, + element.column.name) + +The second ``visit_alter_table`` will be invoked when any ``postgresql`` +dialect is used. + +.. _compilerext_compiling_subelements: + +Compiling sub-elements of a custom expression construct +======================================================= + +The ``compiler`` argument is the +:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object +can be inspected for any information about the in-progress compilation, +including ``compiler.dialect``, ``compiler.statement`` etc. The +:class:`~sqlalchemy.sql.compiler.SQLCompiler` and +:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +method which can be used for compilation of embedded attributes:: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + + insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5)) + print(insert) + +Produces:: + + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z + FROM mytable WHERE mytable.x > :x_1)" + +.. note:: + + The above ``InsertFromSelect`` construct is only an example, this actual + functionality is already available using the + :meth:`_expression.Insert.from_select` method. + + +Cross Compiling between SQL and DDL compilers +--------------------------------------------- + +SQL and DDL constructs are each compiled using different base compilers - +``SQLCompiler`` and ``DDLCompiler``. A common need is to access the +compilation rules of SQL expressions from within a DDL expression. The +``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as +below where we generate a CHECK constraint that embeds a SQL expression:: + + @compiles(MyConstraint) + def compile_my_constraint(constraint, ddlcompiler, **kw): + kw['literal_binds'] = True + return "CONSTRAINT %s CHECK (%s)" % ( + constraint.name, + ddlcompiler.sql_compiler.process( + constraint.expression, **kw) + ) + +Above, we add an additional flag to the process step as called by +:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This +indicates that any SQL expression which refers to a :class:`.BindParameter` +object or other "literal" object such as those which refer to strings or +integers should be rendered **in-place**, rather than being referred to as +a bound parameter; when emitting DDL, bound parameters are typically not +supported. + + +Changing the default compilation of existing constructs +======================================================= + +The compiler extension applies just as well to the existing constructs. When +overriding the compilation of a built in SQL construct, the @compiles +decorator is invoked upon the appropriate class (be sure to use the class, +i.e. ``Insert`` or ``Select``, instead of the creation function such +as ``insert()`` or ``select()``). + +Within the new compilation function, to get at the "original" compilation +routine, use the appropriate visit_XXX method - this +because compiler.process() will call upon the overriding routine and cause +an endless loop. Such as, to add "prefix" to all insert statements:: + + from sqlalchemy.sql.expression import Insert + + @compiles(Insert) + def prefix_inserts(insert, compiler, **kw): + return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) + +The above compiler will prefix all INSERT statements with "some prefix" when +compiled. + +.. _type_compilation_extension: + +Changing Compilation of Types +============================= + +``compiler`` works for types, too, such as below where we implement the +MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: + + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') + def compile_varchar(element, compiler, **kw): + if element.length == 'max': + return "VARCHAR('max')" + else: + return compiler.visit_VARCHAR(element, **kw) + + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) + +Subclassing Guidelines +====================== + +A big part of using the compiler extension is subclassing SQLAlchemy +expression constructs. To make this easier, the expression and +schema packages feature a set of "bases" intended for common tasks. +A synopsis is as follows: + +* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root + expression class. Any SQL expression can be derived from this base, and is + probably the best choice for longer constructs such as specialized INSERT + statements. + +* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all + "column-like" elements. Anything that you'd place in the "columns" clause of + a SELECT statement (as well as order by and group by) can derive from this - + the object will automatically have Python "comparison" behavior. + + :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a + ``type`` member which is expression's return type. This can be established + at the instance level in the constructor, or at the class level if its + generally constant:: + + class timestamp(ColumnElement): + type = TIMESTAMP() + inherit_cache = True + +* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a + ``ColumnElement`` and a "from clause" like object, and represents a SQL + function or stored procedure type of call. Since most databases support + statements along the line of "SELECT FROM <some function>" + ``FunctionElement`` adds in the ability to be used in the FROM clause of a + ``select()`` construct:: + + from sqlalchemy.sql.expression import FunctionElement + + class coalesce(FunctionElement): + name = 'coalesce' + inherit_cache = True + + @compiles(coalesce) + def compile(element, compiler, **kw): + return "coalesce(%s)" % compiler.process(element.clauses, **kw) + + @compiles(coalesce, 'oracle') + def compile(element, compiler, **kw): + if len(element.clauses) > 2: + raise TypeError("coalesce only supports two arguments on Oracle") + return "nvl(%s)" % compiler.process(element.clauses, **kw) + +* :class:`.ExecutableDDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. Compilation of + :class:`.ExecutableDDLElement` subclasses is issued by a + :class:`.DDLCompiler` instead of a :class:`.SQLCompiler`. + :class:`.ExecutableDDLElement` can also be used as an event hook in + conjunction with event hooks like :meth:`.DDLEvents.before_create` and + :meth:`.DDLEvents.after_create`, allowing the construct to be invoked + automatically during CREATE TABLE and DROP TABLE sequences. + + .. seealso:: + + :ref:`metadata_ddl_toplevel` - contains examples of associating + :class:`.DDL` objects (which are themselves :class:`.ExecutableDDLElement` + instances) with :class:`.DDLEvents` event hooks. + +* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which + should be used with any expression class that represents a "standalone" + SQL statement that can be passed directly to an ``execute()`` method. It + is already implicit within ``DDLElement`` and ``FunctionElement``. + +Most of the above constructs also respond to SQL statement caching. A +subclassed construct will want to define the caching behavior for the object, +which usually means setting the flag ``inherit_cache`` to the value of +``False`` or ``True``. See the next section :ref:`compilerext_caching` +for background. + + +.. _compilerext_caching: + +Enabling Caching Support for Custom Constructs +============================================== + +SQLAlchemy as of version 1.4 includes a +:ref:`SQL compilation caching facility <sql_caching>` which will allow +equivalent SQL constructs to cache their stringified form, along with other +structural information used to fetch results from the statement. + +For reasons discussed at :ref:`caching_caveats`, the implementation of this +caching system takes a conservative approach towards including custom SQL +constructs and/or subclasses within the caching system. This includes that +any user-defined SQL constructs, including all the examples for this +extension, will not participate in caching by default unless they positively +assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache` +attribute when set to ``True`` at the class level of a specific subclass +will indicate that instances of this class may be safely cached, using the +cache key generation scheme of the immediate superclass. This applies +for example to the "synopsis" example indicated previously:: + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, the ``MyColumn`` class does not include any new state that +affects its SQL compilation; the cache key of ``MyColumn`` instances will +make use of that of the ``ColumnClause`` superclass, meaning it will take +into account the class of the object (``MyColumn``), the string name and +datatype of the object:: + + >>> MyColumn("some_name", String())._generate_cache_key() + CacheKey( + key=('0', <class '__main__.MyColumn'>, + 'name', 'some_name', + 'type', (<class 'sqlalchemy.sql.sqltypes.String'>, + ('length', None), ('collation', None)) + ), bindparams=[]) + +For objects that are likely to be **used liberally as components within many +larger statements**, such as :class:`_schema.Column` subclasses and custom SQL +datatypes, it's important that **caching be enabled as much as possible**, as +this may otherwise negatively affect performance. + +An example of an object that **does** contain state which affects its SQL +compilation is the one illustrated at :ref:`compilerext_compiling_subelements`; +this is an "INSERT FROM SELECT" construct that combines together a +:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of +which independently affect the SQL string generation of the construct. For +this class, the example illustrates that it simply does not participate in +caching:: + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + +While it is also possible that the above ``InsertFromSelect`` could be made to +produce a cache key that is composed of that of the :class:`_schema.Table` and +:class:`_sql.Select` components together, the API for this is not at the moment +fully public. However, for an "INSERT FROM SELECT" construct, which is only +used by itself for specific operations, caching is not as critical as in the +previous example. + +For objects that are **used in relative isolation and are generally +standalone**, such as custom :term:`DML` constructs like an "INSERT FROM +SELECT", **caching is generally less critical** as the lack of caching for such +a construct will have only localized implications for that specific operation. + + +Further Examples +================ + +"UTC timestamp" function +------------------------- + +A function that works like "CURRENT_TIMESTAMP" except applies the +appropriate conversions so that the time is in UTC time. Timestamps are best +stored in relational databases as UTC, without time zones. UTC so that your +database doesn't think time has gone backwards in the hour when daylight +savings ends, without timezones because timezones are like character +encodings - they're best applied only at the endpoints of an application +(i.e. convert to UTC upon user input, re-apply desired timezone upon display). + +For PostgreSQL and Microsoft SQL Server:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import DateTime + + class utcnow(expression.FunctionElement): + type = DateTime() + inherit_cache = True + + @compiles(utcnow, 'postgresql') + def pg_utcnow(element, compiler, **kw): + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + @compiles(utcnow, 'mssql') + def ms_utcnow(element, compiler, **kw): + return "GETUTCDATE()" + +Example usage:: + + from sqlalchemy import ( + Table, Column, Integer, String, DateTime, MetaData + ) + metadata = MetaData() + event = Table("event", metadata, + Column("id", Integer, primary_key=True), + Column("description", String(50), nullable=False), + Column("timestamp", DateTime, server_default=utcnow()) + ) + +"GREATEST" function +------------------- + +The "GREATEST" function is given any number of arguments and returns the one +that is of the highest value - its equivalent to Python's ``max`` +function. A SQL standard version versus a CASE based version which only +accommodates two arguments:: + + from sqlalchemy.sql import expression, case + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import Numeric + + class greatest(expression.FunctionElement): + type = Numeric() + name = 'greatest' + inherit_cache = True + + @compiles(greatest) + def default_greatest(element, compiler, **kw): + return compiler.visit_function(element) + + @compiles(greatest, 'sqlite') + @compiles(greatest, 'mssql') + @compiles(greatest, 'oracle') + def case_greatest(element, compiler, **kw): + arg1, arg2 = list(element.clauses) + return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw) + +Example usage:: + + Session.query(Account).\ + filter( + greatest( + Account.checking_balance, + Account.savings_balance) > 10000 + ) + +"false" expression +------------------ + +Render a "false" constant expression, rendering as "0" on platforms that +don't have a "false" constant:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + + class sql_false(expression.ColumnElement): + inherit_cache = True + + @compiles(sql_false) + def default_false(element, compiler, **kw): + return "false" + + @compiles(sql_false, 'mssql') + @compiles(sql_false, 'mysql') + @compiles(sql_false, 'oracle') + def int_false(element, compiler, **kw): + return "0" + +Example usage:: + + from sqlalchemy import select, union_all + + exp = union_all( + select(users.c.name, sql_false().label("enrolled")), + select(customers.c.name, customers.c.enrolled) + ) + +""" +from .. import exc +from ..sql import sqltypes + + +def compiles(class_, *specs): + """Register a function as a compiler for a + given :class:`_expression.ClauseElement` type.""" + + def decorate(fn): + # get an existing @compiles handler + existing = class_.__dict__.get("_compiler_dispatcher", None) + + # get the original handler. All ClauseElement classes have one + # of these, but some TypeEngine classes will not. + existing_dispatch = getattr(class_, "_compiler_dispatch", None) + + if not existing: + existing = _dispatcher() + + if existing_dispatch: + + def _wrap_existing_dispatch(element, compiler, **kw): + try: + return existing_dispatch(element, compiler, **kw) + except exc.UnsupportedCompilationError as uce: + raise exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ) from uce + + existing.specs["default"] = _wrap_existing_dispatch + + # TODO: why is the lambda needed ? + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) + + if specs: + for s in specs: + existing.specs[s] = fn + + else: + existing.specs["default"] = fn + return fn + + return decorate + + +def deregister(class_): + """Remove all custom compilers associated with a given + :class:`_expression.ClauseElement` type. + + """ + + if hasattr(class_, "_compiler_dispatcher"): + class_._compiler_dispatch = class_._original_compiler_dispatch + del class_._compiler_dispatcher + + +class _dispatcher: + def __init__(self): + self.specs = {} + + def __call__(self, element, compiler, **kw): + # TODO: yes, this could also switch off of DBAPI in use. + fn = self.specs.get(compiler.dialect.name, None) + if not fn: + try: + fn = self.specs["default"] + except KeyError as ke: + raise exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ) from ke + + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py new file mode 100644 index 0000000..37da403 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py @@ -0,0 +1,65 @@ +# ext/declarative/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from .extensions import AbstractConcreteBase +from .extensions import ConcreteBase +from .extensions import DeferredReflection +from ... import util +from ...orm.decl_api import as_declarative as _as_declarative +from ...orm.decl_api import declarative_base as _declarative_base +from ...orm.decl_api import DeclarativeMeta +from ...orm.decl_api import declared_attr +from ...orm.decl_api import has_inherited_table as _has_inherited_table +from ...orm.decl_api import synonym_for as _synonym_for + + +@util.moved_20( + "The ``declarative_base()`` function is now available as " + ":func:`sqlalchemy.orm.declarative_base`." +) +def declarative_base(*arg, **kw): + return _declarative_base(*arg, **kw) + + +@util.moved_20( + "The ``as_declarative()`` function is now available as " + ":func:`sqlalchemy.orm.as_declarative`" +) +def as_declarative(*arg, **kw): + return _as_declarative(*arg, **kw) + + +@util.moved_20( + "The ``has_inherited_table()`` function is now available as " + ":func:`sqlalchemy.orm.has_inherited_table`." +) +def has_inherited_table(*arg, **kw): + return _has_inherited_table(*arg, **kw) + + +@util.moved_20( + "The ``synonym_for()`` function is now available as " + ":func:`sqlalchemy.orm.synonym_for`" +) +def synonym_for(*arg, **kw): + return _synonym_for(*arg, **kw) + + +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3b81c5f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..198346b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py new file mode 100644 index 0000000..c0f7e34 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py @@ -0,0 +1,548 @@ +# ext/declarative/extensions.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Public API functions and helpers for declarative.""" +from __future__ import annotations + +import collections +import contextlib +from typing import Any +from typing import Callable +from typing import TYPE_CHECKING +from typing import Union + +from ... import exc as sa_exc +from ...engine import Connection +from ...engine import Engine +from ...orm import exc as orm_exc +from ...orm import relationships +from ...orm.base import _mapper_or_none +from ...orm.clsregistry import _resolver +from ...orm.decl_base import _DeferredMapperConfig +from ...orm.util import polymorphic_union +from ...schema import Table +from ...util import OrderedDict + +if TYPE_CHECKING: + from ...sql.schema import MetaData + + +class ConcreteBase: + """A helper class for 'concrete' declarative mappings. + + :class:`.ConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.ConcreteBase` produces a mapped + table for the class itself. Compare to :class:`.AbstractConcreteBase`, + which does not. + + Example:: + + from sqlalchemy.ext.declarative import ConcreteBase + + class Employee(ConcreteBase, Base): + __tablename__ = 'employee' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', + 'concrete':True} + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + + The name of the discriminator column used by :func:`.polymorphic_union` + defaults to the name ``type``. To suit the use case of a mapping where an + actual column in a mapped table is already named ``type``, the + discriminator name can be configured by setting the + ``_concrete_discriminator_name`` attribute:: + + class Employee(ConcreteBase, Base): + _concrete_discriminator_name = '_concrete_discriminator' + + .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name`` + attribute to :class:`_declarative.ConcreteBase` so that the + virtual discriminator column name can be customized. + + .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute + need only be placed on the basemost class to take correct effect for + all subclasses. An explicit error message is now raised if the + mapped column names conflict with the discriminator name, whereas + in the 1.3.x series there would be some warnings and then a non-useful + query would be generated. + + .. seealso:: + + :class:`.AbstractConcreteBase` + + :ref:`concrete_inheritance` + + + """ + + @classmethod + def _create_polymorphic_union(cls, mappers, discriminator_name): + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + discriminator_name, + "pjoin", + ) + + @classmethod + def __declare_first__(cls): + m = cls.__mapper__ + if m.with_polymorphic: + return + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + + mappers = list(m.self_and_descendants) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + m._set_with_polymorphic(("*", pjoin)) + m._set_polymorphic_on(pjoin.c[discriminator_name]) + + +class AbstractConcreteBase(ConcreteBase): + """A helper class for 'concrete' declarative mappings. + + :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_first__()`` function, which is essentially + a hook for the :meth:`.before_configured` event. + + :class:`.AbstractConcreteBase` applies :class:`_orm.Mapper` for its + immediately inheriting class, as would occur for any other + declarative mapped class. However, the :class:`_orm.Mapper` is not + mapped to any particular :class:`.Table` object. Instead, it's + mapped directly to the "polymorphic" selectable produced by + :func:`.polymorphic_union`, and performs no persistence operations on its + own. Compare to :class:`.ConcreteBase`, which maps its + immediately inheriting class to an actual + :class:`.Table` that stores rows directly. + + .. note:: + + The :class:`.AbstractConcreteBase` delays the mapper creation of the + base class until all the subclasses have been defined, + as it needs to create a mapping against a selectable that will include + all subclass tables. In order to achieve this, it waits for the + **mapper configuration event** to occur, at which point it scans + through all the configured subclasses and sets up a mapping that will + query against all subclasses at once. + + While this event is normally invoked automatically, in the case of + :class:`.AbstractConcreteBase`, it may be necessary to invoke it + explicitly after **all** subclass mappings are defined, if the first + operation is to be a query against this base class. To do so, once all + the desired classes have been configured, the + :meth:`_orm.registry.configure` method on the :class:`_orm.registry` + in use can be invoked, which is available in relation to a particular + declarative base class:: + + Base.registry.configure() + + Example:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Base(DeclarativeBase): + pass + + class Employee(AbstractConcreteBase, Base): + pass + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True + } + + Base.registry.configure() + + The abstract base class is handled by declarative in a special way; + at class configuration time, it behaves like a declarative mixin + or an ``__abstract__`` base class. Once classes are configured + and mappings are produced, it then gets mapped itself, but + after all of its descendants. This is a very unique system of mapping + not found in any other SQLAlchemy API feature. + + Using this approach, we can specify columns and properties + that will take place on mapped subclasses, in the way that + we normally do as in :ref:`declarative_mixins`:: + + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + + class Employee(AbstractConcreteBase, Base): + strict_attrs = True + + employee_id = Column(Integer, primary_key=True) + + @declared_attr + def company_id(cls): + return Column(ForeignKey('company.id')) + + @declared_attr + def company(cls): + return relationship("Company") + + class Manager(Employee): + __tablename__ = 'manager' + + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True + } + + Base.registry.configure() + + When we make use of our mappings however, both ``Manager`` and + ``Employee`` will have an independently usable ``.company`` attribute:: + + session.execute( + select(Employee).filter(Employee.company.has(id=5)) + ) + + :param strict_attrs: when specified on the base class, "strict" attribute + mode is enabled which attempts to limit ORM mapped attributes on the + base class to only those that are immediately present, while still + preserving "polymorphic" loading behavior. + + .. versionadded:: 2.0 + + .. seealso:: + + :class:`.ConcreteBase` + + :ref:`concrete_inheritance` + + :ref:`abstract_concrete_base` + + """ + + __no_table__ = True + + @classmethod + def __declare_first__(cls): + cls._sa_decl_prepare_nocascade() + + @classmethod + def _sa_decl_prepare_nocascade(cls): + if getattr(cls, "__mapper__", None): + return + + to_map = _DeferredMapperConfig.config_for_cls(cls) + + # can't rely on 'self_and_descendants' here + # since technically an immediate subclass + # might not be mapped, but a subclass + # may be. + mappers = [] + stack = list(cls.__subclasses__()) + while stack: + klass = stack.pop() + stack.extend(klass.__subclasses__()) + mn = _mapper_or_none(klass) + if mn is not None: + mappers.append(mn) + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + + # For columns that were declared on the class, these + # are normally ignored with the "__no_table__" mapping, + # unless they have a different attribute key vs. col name + # and are in the properties argument. + # In that case, ensure we update the properties entry + # to the correct column from the pjoin target table. + declared_cols = set(to_map.declared_columns) + declared_col_keys = {c.key for c in declared_cols} + for k, v in list(to_map.properties.items()): + if v in declared_cols: + to_map.properties[k] = pjoin.c[v.key] + declared_col_keys.remove(v.key) + + to_map.local_table = pjoin + + strict_attrs = cls.__dict__.get("strict_attrs", False) + + m_args = to_map.mapper_args_fn or dict + + def mapper_args(): + args = m_args() + args["polymorphic_on"] = pjoin.c[discriminator_name] + args["polymorphic_abstract"] = True + if strict_attrs: + args["include_properties"] = ( + set(pjoin.primary_key) + | declared_col_keys + | {discriminator_name} + ) + args["with_polymorphic"] = ("*", pjoin) + return args + + to_map.mapper_args_fn = mapper_args + + to_map.map() + + stack = [cls] + while stack: + scls = stack.pop(0) + stack.extend(scls.__subclasses__()) + sm = _mapper_or_none(scls) + if sm and sm.concrete and sm.inherits is None: + for sup_ in scls.__mro__[1:]: + sup_sm = _mapper_or_none(sup_) + if sup_sm: + sm._set_concrete_base(sup_sm) + break + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AbstractConcreteBase and " + "has a mapping pending until all subclasses are defined. " + "Call the sqlalchemy.orm.configure_mappers() function after " + "all subclasses have been defined to " + "complete the mapping of this class." + % orm_exc._safe_cls_name(cls), + ) + + +class DeferredReflection: + """A helper class for construction of mappings based on + a deferred reflection step. + + Normally, declarative can be used with reflection by + setting a :class:`_schema.Table` object using autoload_with=engine + as the ``__table__`` attribute on a declarative class. + The caveat is that the :class:`_schema.Table` must be fully + reflected, or at the very least have a primary key column, + at the point at which a normal declarative mapping is + constructed, meaning the :class:`_engine.Engine` must be available + at class declaration time. + + The :class:`.DeferredReflection` mixin moves the construction + of mappers to be at a later point, after a specific + method is called which first reflects all :class:`_schema.Table` + objects created so far. Classes can define it as such:: + + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.declarative import DeferredReflection + Base = declarative_base() + + class MyClass(DeferredReflection, Base): + __tablename__ = 'mytable' + + Above, ``MyClass`` is not yet mapped. After a series of + classes have been defined in the above fashion, all tables + can be reflected and mappings created using + :meth:`.prepare`:: + + engine = create_engine("someengine://...") + DeferredReflection.prepare(engine) + + The :class:`.DeferredReflection` mixin can be applied to individual + classes, used as the base for the declarative base itself, + or used in a custom abstract class. Using an abstract base + allows that only a subset of classes to be prepared for a + particular prepare step, which is necessary for applications + that use more than one engine. For example, if an application + has two engines, you might use two bases, and prepare each + separately, e.g.:: + + class ReflectedOne(DeferredReflection, Base): + __abstract__ = True + + class ReflectedTwo(DeferredReflection, Base): + __abstract__ = True + + class MyClass(ReflectedOne): + __tablename__ = 'mytable' + + class MyOtherClass(ReflectedOne): + __tablename__ = 'myothertable' + + class YetAnotherClass(ReflectedTwo): + __tablename__ = 'yetanothertable' + + # ... etc. + + Above, the class hierarchies for ``ReflectedOne`` and + ``ReflectedTwo`` can be configured separately:: + + ReflectedOne.prepare(engine_one) + ReflectedTwo.prepare(engine_two) + + .. seealso:: + + :ref:`orm_declarative_reflected_deferred_reflection` - in the + :ref:`orm_declarative_table_config_toplevel` section. + + """ + + @classmethod + def prepare( + cls, bind: Union[Engine, Connection], **reflect_kw: Any + ) -> None: + r"""Reflect all :class:`_schema.Table` objects for all current + :class:`.DeferredReflection` subclasses + + :param bind: :class:`_engine.Engine` or :class:`_engine.Connection` + instance + + ..versionchanged:: 2.0.16 a :class:`_engine.Connection` is also + accepted. + + :param \**reflect_kw: additional keyword arguments passed to + :meth:`_schema.MetaData.reflect`, such as + :paramref:`_schema.MetaData.reflect.views`. + + .. versionadded:: 2.0.16 + + """ + + to_map = _DeferredMapperConfig.classes_for_base(cls) + + metadata_to_table = collections.defaultdict(set) + + # first collect the primary __table__ for each class into a + # collection of metadata/schemaname -> table names + for thingy in to_map: + if thingy.local_table is not None: + metadata_to_table[ + (thingy.local_table.metadata, thingy.local_table.schema) + ].add(thingy.local_table.name) + + # then reflect all those tables into their metadatas + + if isinstance(bind, Connection): + conn = bind + ctx = contextlib.nullcontext(enter_result=conn) + elif isinstance(bind, Engine): + ctx = bind.connect() + else: + raise sa_exc.ArgumentError( + f"Expected Engine or Connection, got {bind!r}" + ) + + with ctx as conn: + for (metadata, schema), table_names in metadata_to_table.items(): + metadata.reflect( + conn, + only=table_names, + schema=schema, + extend_existing=True, + autoload_replace=False, + **reflect_kw, + ) + + metadata_to_table.clear() + + # .map() each class, then go through relationships and look + # for secondary + for thingy in to_map: + thingy.map() + + mapper = thingy.cls.__mapper__ + metadata = mapper.class_.metadata + + for rel in mapper._props.values(): + if ( + isinstance(rel, relationships.RelationshipProperty) + and rel._init_args.secondary._is_populated() + ): + secondary_arg = rel._init_args.secondary + + if isinstance(secondary_arg.argument, Table): + secondary_table = secondary_arg.argument + metadata_to_table[ + ( + secondary_table.metadata, + secondary_table.schema, + ) + ].add(secondary_table.name) + elif isinstance(secondary_arg.argument, str): + _, resolve_arg = _resolver(rel.parent.class_, rel) + + resolver = resolve_arg( + secondary_arg.argument, True + ) + metadata_to_table[ + (metadata, thingy.local_table.schema) + ].add(secondary_arg.argument) + + resolver._resolvers += ( + cls._sa_deferred_table_resolver(metadata), + ) + + secondary_arg.argument = resolver() + + for (metadata, schema), table_names in metadata_to_table.items(): + metadata.reflect( + conn, + only=table_names, + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + + @classmethod + def _sa_deferred_table_resolver( + cls, metadata: MetaData + ) -> Callable[[str], Table]: + def _resolve(key: str) -> Table: + # reflection has already occurred so this Table would have + # its contents already + return Table(key, metadata) + + return _resolve + + _sa_decl_prepare = True + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of DeferredReflection. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py new file mode 100644 index 0000000..d8ee819 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py @@ -0,0 +1,481 @@ +# ext/horizontal_shard.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. + +For a usage example, see the :ref:`examples_sharding` example included in +the source distribution. + +.. deepalchemy:: The horizontal sharding extension is an advanced feature, + involving a complex statement -> database interaction as well as + use of semi-public APIs for non-trivial cases. Simpler approaches to + refering to multiple database "shards", most commonly using a distinct + :class:`_orm.Session` per "shard", should always be considered first + before using this more complex and less-production-tested system. + + + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import event +from .. import exc +from .. import inspect +from .. import util +from ..orm import PassiveFlag +from ..orm._typing import OrmExecuteOptionsParameter +from ..orm.interfaces import ORMOption +from ..orm.mapper import Mapper +from ..orm.query import Query +from ..orm.session import _BindArguments +from ..orm.session import _PKIdentityArgument +from ..orm.session import Session +from ..util.typing import Protocol +from ..util.typing import Self + +if TYPE_CHECKING: + from ..engine.base import Connection + from ..engine.base import Engine + from ..engine.base import OptionEngine + from ..engine.result import IteratorResult + from ..engine.result import Result + from ..orm import LoaderCallableStatus + from ..orm._typing import _O + from ..orm.bulk_persistence import BulkUDCompileState + from ..orm.context import QueryContext + from ..orm.session import _EntityBindKey + from ..orm.session import _SessionBind + from ..orm.session import ORMExecuteState + from ..orm.state import InstanceState + from ..sql import Executable + from ..sql._typing import _TP + from ..sql.elements import ClauseElement + +__all__ = ["ShardedSession", "ShardedQuery"] + +_T = TypeVar("_T", bound=Any) + + +ShardIdentifier = str + + +class ShardChooser(Protocol): + def __call__( + self, + mapper: Optional[Mapper[_T]], + instance: Any, + clause: Optional[ClauseElement], + ) -> Any: ... + + +class IdentityChooser(Protocol): + def __call__( + self, + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: ... + + +class ShardedQuery(Query[_T]): + """Query class used with :class:`.ShardedSession`. + + .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy + :class:`.Query` class. The :class:`.ShardedSession` now supports + 2.0 style execution via the :meth:`.ShardedSession.execute` method. + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + assert isinstance(self.session, ShardedSession) + + self.identity_chooser = self.session.identity_chooser + self.execute_chooser = self.session.execute_chooser + self._shard_id = None + + def set_shard(self, shard_id: ShardIdentifier) -> Self: + """Return a new query, limited to a single shard ID. + + All subsequent operations with the returned query will + be against the single shard regardless of other state. + + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: + + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} + ) + + """ + return self.execution_options(_sa_shard_id=shard_id) + + +class ShardedSession(Session): + shard_chooser: ShardChooser + identity_chooser: IdentityChooser + execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] + + def __init__( + self, + shard_chooser: ShardChooser, + identity_chooser: Optional[IdentityChooser] = None, + execute_chooser: Optional[ + Callable[[ORMExecuteState], Iterable[Any]] + ] = None, + shards: Optional[Dict[str, Any]] = None, + query_cls: Type[Query[_T]] = ShardedQuery, + *, + id_chooser: Optional[ + Callable[[Query[_T], Iterable[_T]], Iterable[Any]] + ] = None, + query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, + **kwargs: Any, + ) -> None: + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped + instance, and possibly a SQL clause, returns a shard ID. This id + may be based off of the attributes present within the object, or on + some round-robin scheme. If the scheme is based on a selection, it + should set whatever state on the instance to mark it in the future as + participating in that shard. + + :param identity_chooser: A callable, passed a Mapper and primary key + argument, which should return a list of shard ids where this + primary key might reside. + + .. versionchanged:: 2.0 The ``identity_chooser`` parameter + supersedes the ``id_chooser`` parameter. + + :param execute_chooser: For a given :class:`.ORMExecuteState`, + returns the list of shard_ids + where the query should be issued. Results from all shards returned + will be combined together into a single listing. + + .. versionchanged:: 1.4 The ``execute_chooser`` parameter + supersedes the ``query_chooser`` parameter. + + :param shards: A dictionary of string shard names + to :class:`~sqlalchemy.engine.Engine` objects. + + """ + super().__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) + self.shard_chooser = shard_chooser + + if id_chooser: + _id_chooser = id_chooser + util.warn_deprecated( + "The ``id_chooser`` parameter is deprecated; " + "please use ``identity_chooser``.", + "2.0", + ) + + def _legacy_identity_chooser( + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + return _id_chooser(q, primary_key) + + self.identity_chooser = _legacy_identity_chooser + elif identity_chooser: + self.identity_chooser = identity_chooser + else: + raise exc.ArgumentError( + "identity_chooser or id_chooser is required" + ) + + if query_chooser: + _query_chooser = query_chooser + util.warn_deprecated( + "The ``query_chooser`` parameter is deprecated; " + "please use ``execute_chooser``.", + "1.4", + ) + if execute_chooser: + raise exc.ArgumentError( + "Can't pass query_chooser and execute_chooser " + "at the same time." + ) + + def _default_execute_chooser( + orm_context: ORMExecuteState, + ) -> Iterable[Any]: + return _query_chooser(orm_context.statement) + + if execute_chooser is None: + execute_chooser = _default_execute_chooser + + if execute_chooser is None: + raise exc.ArgumentError( + "execute_chooser or query_chooser is required" + ) + self.execute_chooser = execute_chooser + self.__shards: Dict[ShardIdentifier, _SessionBind] = {} + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def _identity_lookup( + self, + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Optional[Any] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Union[Optional[_O], LoaderCallableStatus]: + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all + possible identity tokens (e.g. shard ids). + + .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from + the :class:`_query.Query` object to the :class:`.Session`. + + """ + + if identity_token is not None: + obj = super()._identity_lookup( + mapper, + primary_key_identity, + identity_token=identity_token, + **kw, + ) + + return obj + else: + for shard_id in self.identity_chooser( + mapper, + primary_key_identity, + lazy_loaded_from=lazy_loaded_from, + execution_options=execution_options, + bind_arguments=dict(bind_arguments) if bind_arguments else {}, + ): + obj2 = super()._identity_lookup( + mapper, + primary_key_identity, + identity_token=shard_id, + lazy_loaded_from=lazy_loaded_from, + **kw, + ) + if obj2 is not None: + return obj2 + + return None + + def _choose_shard_and_assign( + self, + mapper: Optional[_EntityBindKey[_O]], + instance: Any, + **kw: Any, + ) -> Any: + if instance is not None: + state = inspect(instance) + if state.key: + token = state.key[2] + assert token is not None + return token + elif state.identity_token: + return state.identity_token + + assert isinstance(mapper, Mapper) + shard_id = self.shard_chooser(mapper, instance, **kw) + if instance is not None: + state.identity_token = shard_id + return shard_id + + def connection_callable( # type: ignore [override] + self, + mapper: Optional[Mapper[_T]] = None, + instance: Optional[Any] = None, + shard_id: Optional[ShardIdentifier] = None, + **kw: Any, + ) -> Connection: + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. + + """ + + if shard_id is None: + shard_id = self._choose_shard_and_assign(mapper, instance) + + if self.in_transaction(): + trans = self.get_transaction() + assert trans is not None + return trans.connection(mapper, shard_id=shard_id) + else: + bind = self.get_bind( + mapper=mapper, shard_id=shard_id, instance=instance + ) + + if isinstance(bind, Engine): + return bind.connect(**kw) + else: + assert isinstance(bind, Connection) + return bind + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + shard_id: Optional[ShardIdentifier] = None, + instance: Optional[Any] = None, + clause: Optional[ClauseElement] = None, + **kw: Any, + ) -> _SessionBind: + if shard_id is None: + shard_id = self._choose_shard_and_assign( + mapper, instance=instance, clause=clause + ) + assert shard_id is not None + return self.__shards[shard_id] + + def bind_shard( + self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine] + ) -> None: + self.__shards[shard_id] = bind + + +class set_shard_id(ORMOption): + """a loader option for statements to apply a specific shard id to the + primary query as well as for additional relationship and column + loaders. + + The :class:`_horizontal.set_shard_id` option may be applied using + the :meth:`_sql.Executable.options` method of any executable statement:: + + stmt = ( + select(MyObject). + where(MyObject.name == 'some name'). + options(set_shard_id("shard1")) + ) + + Above, the statement when invoked will limit to the "shard1" shard + identifier for the primary query as well as for all relationship and + column loading strategies, including eager loaders such as + :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`, + and the lazy relationship loader :func:`_orm.lazyload`. + + In this way, the :class:`_horizontal.set_shard_id` option has much wider + scope than using the "shard_id" argument within the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. + + + .. versionadded:: 2.0.0 + + """ + + __slots__ = ("shard_id", "propagate_to_loaders") + + def __init__( + self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True + ): + """Construct a :class:`_horizontal.set_shard_id` option. + + :param shard_id: shard identifier + :param propagate_to_loaders: if left at its default of ``True``, the + shard option will take place for lazy loaders such as + :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option + will not be propagated to loaded objects. Note that :func:`_orm.defer` + always limits to the shard_id of the parent row in any case, so the + parameter only has a net effect on the behavior of the + :func:`_orm.lazyload` strategy. + + """ + self.shard_id = shard_id + self.propagate_to_loaders = propagate_to_loaders + + +def execute_and_instances( + orm_context: ORMExecuteState, +) -> Union[Result[_T], IteratorResult[_TP]]: + active_options: Union[ + None, + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + BulkUDCompileState.default_update_options, + Type[BulkUDCompileState.default_update_options], + ] + + if orm_context.is_select: + active_options = orm_context.load_options + + elif orm_context.is_update or orm_context.is_delete: + active_options = orm_context.update_delete_options + else: + active_options = None + + session = orm_context.session + assert isinstance(session, ShardedSession) + + def iter_for_shard( + shard_id: ShardIdentifier, + ) -> Union[Result[_T], IteratorResult[_TP]]: + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["shard_id"] = shard_id + + orm_context.update_execution_options(identity_token=shard_id) + return orm_context.invoke_statement(bind_arguments=bind_arguments) + + for orm_opt in orm_context._non_compile_orm_options: + # TODO: if we had an ORMOption that gets applied at ORM statement + # execution time, that would allow this to be more generalized. + # for now just iterate and look for our options + if isinstance(orm_opt, set_shard_id): + shard_id = orm_opt.shard_id + break + else: + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id) + else: + partial = [] + for shard_id in session.execute_chooser(orm_context): + result_ = iter_for_shard(shard_id) + partial.append(result_) + return partial[0].merge(*partial[1:]) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py new file mode 100644 index 0000000..25b74d8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py @@ -0,0 +1,1514 @@ +# ext/hybrid.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define attributes on ORM-mapped classes that have "hybrid" behavior. + +"hybrid" means the attribute has distinct behaviors defined at the +class level and at the instance level. + +The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of +method decorator and has minimal dependencies on the rest of SQLAlchemy. +Its basic theory of operation can work with any descriptor-based expression +system. + +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce SQL +expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +may receive the class directly, depending on context:: + + from __future__ import annotations + + from sqlalchemy.ext.hybrid import hybrid_method + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + class Interval(Base): + __tablename__ = 'interval' + + id: Mapped[int] = mapped_column(primary_key=True) + start: Mapped[int] + end: Mapped[int] + + def __init__(self, start: int, end: int): + self.start = start + self.end = end + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @hybrid_method + def contains(self, point: int) -> bool: + return (self.start <= point) & (point <= self.end) + + @hybrid_method + def intersects(self, other: Interval) -> bool: + return self.contains(other.start) | self.contains(other.end) + + +Above, the ``length`` property returns the difference between the +``end`` and ``start`` attributes. With an instance of ``Interval``, +this subtraction occurs in Python, using normal Python descriptor +mechanics:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +returns a new SQL expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(Interval.length)) + {printsql}SELECT interval."end" - interval.start AS length + FROM interval{stop} + + + >>> print(select(Interval).filter(Interval.length > 10)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval."end" - interval.start > :param_1 + +Filtering methods such as :meth:`.Select.filter_by` are supported +with hybrid attributes as well: + +.. sourcecode:: pycon+sql + + >>> print(select(Interval).filter_by(length=5)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval."end" - interval.start = :param_1 + +The ``Interval`` class example also illustrates two methods, +``contains()`` and ``intersects()``, decorated with +:class:`.hybrid_method`. This decorator applies the same idea to +methods that :class:`.hybrid_property` applies to attributes. The +methods return boolean values, and take advantage of the Python ``|`` +and ``&`` bitwise operators to produce equivalent instance-level and +SQL expression-level boolean behavior: + +.. sourcecode:: pycon+sql + + >>> i1.contains(6) + True + >>> i1.contains(15) + False + >>> i1.intersects(Interval(7, 18)) + True + >>> i1.intersects(Interval(25, 29)) + False + + >>> print(select(Interval).filter(Interval.contains(15))) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval.start <= :start_1 AND interval."end" > :end_1{stop} + + >>> ia = aliased(Interval) + >>> print(select(Interval, ia).filter(Interval.intersects(ia))) + {printsql}SELECT interval.id, interval.start, + interval."end", interval_1.id AS interval_1_id, + interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end + FROM interval, interval AS interval_1 + WHERE interval.start <= interval_1.start + AND interval."end" > interval_1.start + OR interval.start <= interval_1."end" + AND interval."end" > interval_1."end"{stop} + +.. _hybrid_distinct_expression: + +Defining Expression Behavior Distinct from Attribute Behavior +-------------------------------------------------------------- + +In the previous section, our usage of the ``&`` and ``|`` bitwise operators +within the ``Interval.contains`` and ``Interval.intersects`` methods was +fortunate, considering our functions operated on two boolean values to return a +new one. In many cases, the construction of an in-Python function and a +SQLAlchemy SQL expression have enough differences that two separate Python +expressions should be defined. The :mod:`~sqlalchemy.ext.hybrid` decorator +defines a **modifier** :meth:`.hybrid_property.expression` for this purpose. As an +example we'll define the radius of the interval, which requires the usage of +the absolute value function:: + + from sqlalchemy import ColumnElement + from sqlalchemy import Float + from sqlalchemy import func + from sqlalchemy import type_coerce + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + +In the above example, the :class:`.hybrid_property` first assigned to the +name ``Interval.radius`` is amended by a subsequent method called +``Interval._radius_expression``, using the decorator +``@radius.inplace.expression``, which chains together two modifiers +:attr:`.hybrid_property.inplace` and :attr:`.hybrid_property.expression`. +The use of :attr:`.hybrid_property.inplace` indicates that the +:meth:`.hybrid_property.expression` modifier should mutate the +existing hybrid object at ``Interval.radius`` in place, without creating a +new object. Notes on this modifier and its +rationale are discussed in the next section :ref:`hybrid_pep484_naming`. +The use of ``@classmethod`` is optional, and is strictly to give typing +tools a hint that ``cls`` in this case is expected to be the ``Interval`` +class, and not an instance of ``Interval``. + +.. note:: :attr:`.hybrid_property.inplace` as well as the use of ``@classmethod`` + for proper typing support are available as of SQLAlchemy 2.0.4, and will + not work in earlier versions. + +With ``Interval.radius`` now including an expression element, the SQL +function ``ABS()`` is returned when accessing ``Interval.radius`` +at the class level: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(Interval).filter(Interval.radius > 5)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1 + + +.. _hybrid_pep484_naming: + +Using ``inplace`` to create pep-484 compliant hybrid properties +--------------------------------------------------------------- + +In the previous section, a :class:`.hybrid_property` decorator is illustrated +which includes two separate method-level functions being decorated, both +to produce a single object attribute referenced as ``Interval.radius``. +There are actually several different modifiers we can use for +:class:`.hybrid_property` including :meth:`.hybrid_property.expression`, +:meth:`.hybrid_property.setter` and :meth:`.hybrid_property.update_expression`. + +SQLAlchemy's :class:`.hybrid_property` decorator intends that adding on these +methods may be done in the identical manner as Python's built-in +``@property`` decorator, where idiomatic use is to continue to redefine the +attribute repeatedly, using the **same attribute name** each time, as in the +example below that illustrates the use of :meth:`.hybrid_property.setter` and +:meth:`.hybrid_property.expression` for the ``Interval.radius`` descriptor:: + + # correct use, however is not accepted by pep-484 tooling + + class Interval(Base): + # ... + + @hybrid_property + def radius(self): + return abs(self.length) / 2 + + @radius.setter + def radius(self, value): + self.length = value * 2 + + @radius.expression + def radius(cls): + return type_coerce(func.abs(cls.length) / 2, Float) + +Above, there are three ``Interval.radius`` methods, but as each are decorated, +first by the :class:`.hybrid_property` decorator and then by the +``@radius`` name itself, the end effect is that ``Interval.radius`` is +a single attribute with three different functions contained within it. +This style of use is taken from `Python's documented use of @property +<https://docs.python.org/3/library/functions.html#property>`_. +It is important to note that the way both ``@property`` as well as +:class:`.hybrid_property` work, a **copy of the descriptor is made each time**. +That is, each call to ``@radius.expression``, ``@radius.setter`` etc. +make a new object entirely. This allows the attribute to be re-defined in +subclasses without issue (see :ref:`hybrid_reuse_subclass` later in this +section for how this is used). + +However, the above approach is not compatible with typing tools such as +mypy and pyright. Python's own ``@property`` decorator does not have this +limitation only because +`these tools hardcode the behavior of @property +<https://github.com/python/typing/discussions/1102>`_, meaning this syntax +is not available to SQLAlchemy under :pep:`484` compliance. + +In order to produce a reasonable syntax while remaining typing compliant, +the :attr:`.hybrid_property.inplace` decorator allows the same +decorator to be re-used with different method names, while still producing +a single decorator under one name:: + + # correct use which is also accepted by pep-484 tooling + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + # for example only + self.length = value * 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + +Using :attr:`.hybrid_property.inplace` further qualifies the use of the +decorator that a new copy should not be made, thereby maintaining the +``Interval.radius`` name while allowing additional methods +``Interval._radius_setter`` and ``Interval._radius_expression`` to be +differently named. + + +.. versionadded:: 2.0.4 Added :attr:`.hybrid_property.inplace` to allow + less verbose construction of composite :class:`.hybrid_property` objects + while not having to use repeated method names. Additionally allowed the + use of ``@classmethod`` within :attr:`.hybrid_property.expression`, + :attr:`.hybrid_property.update_expression`, and + :attr:`.hybrid_property.comparator` to allow typing tools to identify + ``cls`` as a class and not an instance in the method signature. + + +Defining Setters +---------------- + +The :meth:`.hybrid_property.setter` modifier allows the construction of a +custom setter method, that can modify values on the object:: + + class Interval(Base): + # ... + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @length.inplace.setter + def _length_setter(self, value: int) -> None: + self.end = self.start + value + +The ``length(self, value)`` method is now called upon set:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + >>> i1.length = 12 + >>> i1.end + 17 + +.. _hybrid_bulk_update: + +Allowing Bulk ORM Update +------------------------ + +A hybrid can define a custom "UPDATE" handler for when using +ORM-enabled updates, allowing the hybrid to be used in the +SET clause of the update. + +Normally, when using a hybrid with :func:`_sql.update`, the SQL +expression is used as the column that's the target of the SET. If our +``Interval`` class had a hybrid ``start_point`` that linked to +``Interval.start``, this could be substituted directly:: + + from sqlalchemy import update + stmt = update(Interval).values({Interval.start_point: 10}) + +However, when using a composite hybrid like ``Interval.length``, this +hybrid represents more than one column. We can set up a handler that will +accommodate a value passed in the VALUES expression which can affect +this, using the :meth:`.hybrid_property.update_expression` decorator. +A handler that works similarly to our setter would be:: + + from typing import List, Tuple, Any + + class Interval(Base): + # ... + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @length.inplace.setter + def _length_setter(self, value: int) -> None: + self.end = self.start + value + + @length.inplace.update_expression + def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: + return [ + (cls.end, cls.start + value) + ] + +Above, if we use ``Interval.length`` in an UPDATE expression, we get +a hybrid SET expression: + +.. sourcecode:: pycon+sql + + + >>> from sqlalchemy import update + >>> print(update(Interval).values({Interval.length: 25})) + {printsql}UPDATE interval SET "end"=(interval.start + :start_1) + +This SET expression is accommodated by the ORM automatically. + +.. seealso:: + + :ref:`orm_expression_update_delete` - includes background on ORM-enabled + UPDATE statements + + +Working with Relationships +-------------------------- + +There's no essential difference when creating hybrids that work with +related objects as opposed to column-based data. The need for distinct +expressions tends to be greater. The two variants we'll illustrate +are the "join-dependent" hybrid, and the "correlated subquery" hybrid. + +Join-Dependent Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Consider the following declarative +mapping which relates a ``User`` to a ``SavingsAccount``:: + + from __future__ import annotations + + from decimal import Decimal + from typing import cast + from typing import List + from typing import Optional + + from sqlalchemy import ForeignKey + from sqlalchemy import Numeric + from sqlalchemy import String + from sqlalchemy import SQLColumnExpression + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass + + + class SavingsAccount(Base): + __tablename__ = 'account' + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") + + class User(Base): + __tablename__ = 'user' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) + + @hybrid_property + def balance(self) -> Optional[Decimal]: + if self.accounts: + return self.accounts[0].balance + else: + return None + + @balance.inplace.setter + def _balance_setter(self, value: Optional[Decimal]) -> None: + assert value is not None + + if not self.accounts: + account = SavingsAccount(owner=self) + else: + account = self.accounts[0] + account.balance = value + + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: + return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance) + +The above hybrid property ``balance`` works with the first +``SavingsAccount`` entry in the list of accounts for this user. The +in-Python getter/setter methods can treat ``accounts`` as a Python +list available on ``self``. + +.. tip:: The ``User.balance`` getter in the above example accesses the + ``self.acccounts`` collection, which will normally be loaded via the + :func:`.selectinload` loader strategy configured on the ``User.balance`` + :func:`_orm.relationship`. The default loader strategy when not otherwise + stated on :func:`_orm.relationship` is :func:`.lazyload`, which emits SQL on + demand. When using asyncio, on-demand loaders such as :func:`.lazyload` are + not supported, so care should be taken to ensure the ``self.accounts`` + collection is accessible to this hybrid accessor when using asyncio. + +At the expression level, it's expected that the ``User`` class will +be used in an appropriate context such that an appropriate join to +``SavingsAccount`` will be present: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(User, User.balance). + ... join(User.accounts).filter(User.balance > 5000)) + {printsql}SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" JOIN account ON "user".id = account.user_id + WHERE account.balance > :balance_1 + +Note however, that while the instance level accessors need to worry +about whether ``self.accounts`` is even present, this issue expresses +itself differently at the SQL expression level, where we basically +would use an outer join: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> from sqlalchemy import or_ + >>> print (select(User, User.balance).outerjoin(User.accounts). + ... filter(or_(User.balance < 5000, User.balance == None))) + {printsql}SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id + WHERE account.balance < :balance_1 OR account.balance IS NULL + +Correlated Subquery Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We can, of course, forego being dependent on the enclosing query's usage +of joins in favor of the correlated subquery, which can portably be packed +into a single column expression. A correlated subquery is more portable, but +often performs more poorly at the SQL level. Using the same technique +illustrated at :ref:`mapper_column_property_sql_expressions`, +we can adjust our ``SavingsAccount`` example to aggregate the balances for +*all* accounts, and use a correlated subquery for the column expression:: + + from __future__ import annotations + + from decimal import Decimal + from typing import List + + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy import Numeric + from sqlalchemy import select + from sqlalchemy import SQLColumnExpression + from sqlalchemy import String + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass + + + class SavingsAccount(Base): + __tablename__ = 'account' + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") + + class User(Base): + __tablename__ = 'user' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) + + @hybrid_property + def balance(self) -> Decimal: + return sum((acc.balance for acc in self.accounts), start=Decimal("0")) + + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Decimal]: + return ( + select(func.sum(SavingsAccount.balance)) + .where(SavingsAccount.user_id == cls.id) + .label("total_balance") + ) + + +The above recipe will give us the ``balance`` column which renders +a correlated SELECT: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(User).filter(User.balance > 400)) + {printsql}SELECT "user".id, "user".name + FROM "user" + WHERE ( + SELECT sum(account.balance) AS sum_1 FROM account + WHERE account.user_id = "user".id + ) > :param_1 + + +.. _hybrid_custom_comparators: + +Building Custom Comparators +--------------------------- + +The hybrid property also includes a helper that allows construction of +custom comparators. A comparator object allows one to customize the +behavior of each SQLAlchemy expression operator individually. They +are useful when creating custom types that have some highly +idiosyncratic behavior on the SQL side. + +.. note:: The :meth:`.hybrid_property.comparator` decorator introduced + in this section **replaces** the use of the + :meth:`.hybrid_property.expression` decorator. + They cannot be used together. + +The example class below allows case-insensitive comparisons on the attribute +named ``word_insensitive``:: + + from __future__ import annotations + + from typing import Any + + from sqlalchemy import ColumnElement + from sqlalchemy import func + from sqlalchemy.ext.hybrid import Comparator + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + + class CaseInsensitiveComparator(Comparator[str]): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return func.lower(self.__clause_element__()) == func.lower(other) + + class SearchWord(Base): + __tablename__ = 'searchword' + + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> str: + return self.word.lower() + + @word_insensitive.inplace.comparator + @classmethod + def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator: + return CaseInsensitiveComparator(cls.word) + +Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()`` +SQL function to both sides: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id, searchword.word + FROM searchword + WHERE lower(searchword.word) = lower(:lower_1) + + +The ``CaseInsensitiveComparator`` above implements part of the +:class:`.ColumnOperators` interface. A "coercion" operation like +lowercasing can be applied to all comparison operations (i.e. ``eq``, +``lt``, ``gt``, etc.) using :meth:`.Operators.operate`:: + + class CaseInsensitiveComparator(Comparator): + def operate(self, op, other, **kwargs): + return op( + func.lower(self.__clause_element__()), + func.lower(other), + **kwargs, + ) + +.. _hybrid_reuse_subclass: + +Reusing Hybrid Properties across Subclasses +------------------------------------------- + +A hybrid can be referred to from a superclass, to allow modifying +methods like :meth:`.hybrid_property.getter`, :meth:`.hybrid_property.setter` +to be used to redefine those methods on a subclass. This is similar to +how the standard Python ``@property`` object works:: + + class FirstNameOnly(Base): + # ... + + first_name: Mapped[str] + + @hybrid_property + def name(self) -> str: + return self.first_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name = value + + class FirstNameLastName(FirstNameOnly): + # ... + + last_name: Mapped[str] + + # 'inplace' is not used here; calling getter creates a copy + # of FirstNameOnly.name that is local to FirstNameLastName + @FirstNameOnly.name.getter + def name(self) -> str: + return self.first_name + ' ' + self.last_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name, self.last_name = value.split(' ', 1) + +Above, the ``FirstNameLastName`` class refers to the hybrid from +``FirstNameOnly.name`` to repurpose its getter and setter for the subclass. + +When overriding :meth:`.hybrid_property.expression` and +:meth:`.hybrid_property.comparator` alone as the first reference to the +superclass, these names conflict with the same-named accessors on the class- +level :class:`.QueryableAttribute` object returned at the class level. To +override these methods when referring directly to the parent class descriptor, +add the special qualifier :attr:`.hybrid_property.overrides`, which will de- +reference the instrumented attribute back to the hybrid object:: + + class FirstNameLastName(FirstNameOnly): + # ... + + last_name: Mapped[str] + + @FirstNameOnly.name.overrides.expression + @classmethod + def name(cls): + return func.concat(cls.first_name, ' ', cls.last_name) + + +Hybrid Value Objects +-------------------- + +Note in our previous example, if we were to compare the ``word_insensitive`` +attribute of a ``SearchWord`` instance to a plain Python string, the plain +Python string would not be coerced to lower case - the +``CaseInsensitiveComparator`` we built, being returned by +``@word_insensitive.comparator``, only applies to the SQL side. + +A more comprehensive form of the custom comparator is to construct a *Hybrid +Value Object*. This technique applies the target value or expression to a value +object which is then returned by the accessor in all cases. The value object +allows control of all operations upon the value as well as how compared values +are treated, both on the SQL expression side as well as the Python value side. +Replacing the previous ``CaseInsensitiveComparator`` class with a new +``CaseInsensitiveWord`` class:: + + class CaseInsensitiveWord(Comparator): + "Hybrid value representing a lower case representation of a word." + + def __init__(self, word): + if isinstance(word, basestring): + self.word = word.lower() + elif isinstance(word, CaseInsensitiveWord): + self.word = word.word + else: + self.word = func.lower(word) + + def operate(self, op, other, **kwargs): + if not isinstance(other, CaseInsensitiveWord): + other = CaseInsensitiveWord(other) + return op(self.word, other.word, **kwargs) + + def __clause_element__(self): + return self.word + + def __str__(self): + return self.word + + key = 'word' + "Label to apply to Query tuple results" + +Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may +be a SQL function, or may be a Python native. By overriding ``operate()`` and +``__clause_element__()`` to work in terms of ``self.word``, all comparison +operations will work against the "converted" form of ``word``, whether it be +SQL side or Python side. Our ``SearchWord`` class can now deliver the +``CaseInsensitiveWord`` object unconditionally from a single hybrid call:: + + class SearchWord(Base): + __tablename__ = 'searchword' + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> CaseInsensitiveWord: + return CaseInsensitiveWord(self.word) + +The ``word_insensitive`` attribute now has case-insensitive comparison behavior +universally, including SQL expression vs. Python expression (note the Python +value is converted to lower case on the Python side here): + +.. sourcecode:: pycon+sql + + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word + FROM searchword + WHERE lower(searchword.word) = :lower_1 + +SQL expression versus SQL expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.orm import aliased + >>> sw1 = aliased(SearchWord) + >>> sw2 = aliased(SearchWord) + >>> print( + ... select(sw1.word_insensitive, sw2.word_insensitive).filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... ) + ... ) + {printsql}SELECT lower(searchword_1.word) AS lower_1, + lower(searchword_2.word) AS lower_2 + FROM searchword AS searchword_1, searchword AS searchword_2 + WHERE lower(searchword_1.word) > lower(searchword_2.word) + +Python only expression:: + + >>> ws1 = SearchWord(word="SomeWord") + >>> ws1.word_insensitive == "sOmEwOrD" + True + >>> ws1.word_insensitive == "XOmEwOrX" + False + >>> print(ws1.word_insensitive) + someword + +The Hybrid Value pattern is very useful for any kind of value that may have +multiple representations, such as timestamps, time deltas, units of +measurement, currencies and encrypted passwords. + +.. seealso:: + + `Hybrids and Value Agnostic Types + <https://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_ + - on the techspot.zzzeek.org blog + + `Value Agnostic Types, Part II + <https://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ - + on the techspot.zzzeek.org blog + + +""" # noqa + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Generic +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import util +from ..orm import attributes +from ..orm import InspectionAttrExtensionType +from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.attributes import QueryableAttribute +from ..sql import roles +from ..sql._typing import is_has_clause_element +from ..sql.elements import ColumnElement +from ..sql.elements import SQLCoreOperations +from ..util.typing import Concatenate +from ..util.typing import Literal +from ..util.typing import ParamSpec +from ..util.typing import Protocol +from ..util.typing import Self + +if TYPE_CHECKING: + from ..orm.interfaces import MapperProperty + from ..orm.util import AliasedInsp + from ..sql import SQLColumnExpression + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _HasClauseElement + from ..sql._typing import _InfoType + from ..sql.operators import OperatorType + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_T = TypeVar("_T", bound=Any) +_TE = TypeVar("_TE", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) + + +class HybridExtensionType(InspectionAttrExtensionType): + HYBRID_METHOD = "HYBRID_METHOD" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ + + HYBRID_PROPERTY = "HYBRID_PROPERTY" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ + + +class _HybridGetterType(Protocol[_T_co]): + def __call__(s, self: Any) -> _T_co: ... + + +class _HybridSetterType(Protocol[_T_con]): + def __call__(s, self: Any, value: _T_con) -> None: ... + + +class _HybridUpdaterType(Protocol[_T_con]): + def __call__( + s, + cls: Any, + value: Union[_T_con, _ColumnExpressionArgument[_T_con]], + ) -> List[Tuple[_DMLColumnArgument, Any]]: ... + + +class _HybridDeleterType(Protocol[_T_co]): + def __call__(s, self: Any) -> None: ... + + +class _HybridExprCallableType(Protocol[_T_co]): + def __call__( + s, cls: Any + ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ... + + +class _HybridComparatorCallableType(Protocol[_T]): + def __call__(self, cls: Any) -> Comparator[_T]: ... + + +class _HybridClassLevelAccessor(QueryableAttribute[_T]): + """Describe the object returned by a hybrid_property() when + called as a class-level descriptor. + + """ + + if TYPE_CHECKING: + + def getter( + self, fget: _HybridGetterType[_T] + ) -> hybrid_property[_T]: ... + + def setter( + self, fset: _HybridSetterType[_T] + ) -> hybrid_property[_T]: ... + + def deleter( + self, fdel: _HybridDeleterType[_T] + ) -> hybrid_property[_T]: ... + + @property + def overrides(self) -> hybrid_property[_T]: ... + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: ... + + +class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): + """A decorator which allows definition of a Python object method with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HybridExtensionType.HYBRID_METHOD + + def __init__( + self, + func: Callable[Concatenate[Any, _P], _R], + expr: Optional[ + Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ] = None, + ): + """Create a new :class:`.hybrid_method`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_method + + class SomeClass: + @hybrid_method + def value(self, x, y): + return self._value + x + y + + @value.expression + @classmethod + def value(cls, x, y): + return func.some_function(cls._value, x, y) + + """ + self.func = func + if expr is not None: + self.expression(expr) + else: + self.expression(func) # type: ignore + + @property + def inplace(self) -> Self: + """Return the inplace mutator for this :class:`.hybrid_method`. + + The :class:`.hybrid_method` class already performs "in place" mutation + when the :meth:`.hybrid_method.expression` decorator is called, + so this attribute returns Self. + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return self + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> Callable[_P, SQLCoreOperations[_R]]: ... + + @overload + def __get__( + self, instance: object, owner: Type[object] + ) -> Callable[_P, _R]: ... + + def __get__( + self, instance: Optional[object], owner: Type[object] + ) -> Union[Callable[_P, _R], Callable[_P, SQLCoreOperations[_R]]]: + if instance is None: + return self.expr.__get__(owner, owner) # type: ignore + else: + return self.func.__get__(instance, owner) # type: ignore + + def expression( + self, expr: Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ) -> hybrid_method[_P, _R]: + """Provide a modifying decorator that defines a + SQL-expression producing method.""" + + self.expr = expr + if not self.expr.__doc__: + self.expr.__doc__ = self.func.__doc__ + return self + + +def _unwrap_classmethod(meth: _T) -> _T: + if isinstance(meth, classmethod): + return meth.__func__ # type: ignore + else: + return meth + + +class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): + """A decorator which allows definition of a Python descriptor with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HybridExtensionType.HYBRID_PROPERTY + + __name__: str + + def __init__( + self, + fget: _HybridGetterType[_T], + fset: Optional[_HybridSetterType[_T]] = None, + fdel: Optional[_HybridDeleterType[_T]] = None, + expr: Optional[_HybridExprCallableType[_T]] = None, + custom_comparator: Optional[Comparator[_T]] = None, + update_expr: Optional[_HybridUpdaterType[_T]] = None, + ): + """Create a new :class:`.hybrid_property`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_property + + class SomeClass: + @hybrid_property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + """ + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = _unwrap_classmethod(expr) + self.custom_comparator = _unwrap_classmethod(custom_comparator) + self.update_expr = _unwrap_classmethod(update_expr) + util.update_wrapper(self, fget) + + @overload + def __get__(self, instance: Any, owner: Literal[None]) -> Self: ... + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> _HybridClassLevelAccessor[_T]: ... + + @overload + def __get__(self, instance: object, owner: Type[object]) -> _T: ... + + def __get__( + self, instance: Optional[object], owner: Optional[Type[object]] + ) -> Union[hybrid_property[_T], _HybridClassLevelAccessor[_T], _T]: + if owner is None: + return self + elif instance is None: + return self._expr_comparator(owner) + else: + return self.fget(instance) + + def __set__(self, instance: object, value: Any) -> None: + if self.fset is None: + raise AttributeError("can't set attribute") + self.fset(instance, value) + + def __delete__(self, instance: object) -> None: + if self.fdel is None: + raise AttributeError("can't delete attribute") + self.fdel(instance) + + def _copy(self, **kw: Any) -> hybrid_property[_T]: + defaults = { + key: value + for key, value in self.__dict__.items() + if not key.startswith("_") + } + defaults.update(**kw) + return type(self)(**defaults) + + @property + def overrides(self) -> Self: + """Prefix for a method that is overriding an existing attribute. + + The :attr:`.hybrid_property.overrides` accessor just returns + this hybrid object, which when called at the class level from + a parent class, will de-reference the "instrumented attribute" + normally returned at this level, and allow modifying decorators + like :meth:`.hybrid_property.expression` and + :meth:`.hybrid_property.comparator` + to be used without conflicting with the same-named attributes + normally present on the :class:`.QueryableAttribute`:: + + class SuperClass: + # ... + + @hybrid_property + def foobar(self): + return self._foobar + + class SubClass(SuperClass): + # ... + + @SuperClass.foobar.overrides.expression + def foobar(cls): + return func.subfoobar(self._foobar) + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`hybrid_reuse_subclass` + + """ + return self + + class _InPlace(Generic[_TE]): + """A builder helper for .hybrid_property. + + .. versionadded:: 2.0.4 + + """ + + __slots__ = ("attr",) + + def __init__(self, attr: hybrid_property[_TE]): + self.attr = attr + + def _set(self, **kw: Any) -> hybrid_property[_TE]: + for k, v in kw.items(): + setattr(self.attr, k, _unwrap_classmethod(v)) + return self.attr + + def getter(self, fget: _HybridGetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fget=fget) + + def setter(self, fset: _HybridSetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fset=fset) + + def deleter( + self, fdel: _HybridDeleterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(update_expr=meth) + + @property + def inplace(self) -> _InPlace[_T]: + """Return the inplace mutator for this :class:`.hybrid_property`. + + This is to allow in-place mutation of the hybrid, allowing the first + hybrid method of a certain name to be re-used in order to add + more methods without having to name those methods the same, e.g.:: + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + self.length = value * 2 + + @radius.inplace.expression + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return hybrid_property._InPlace(self) + + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a getter method. + + .. versionadded:: 1.2 + + """ + + return self._copy(fget=fget) + + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a setter method.""" + + return self._copy(fset=fset) + + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a deletion method.""" + + return self._copy(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a SQL-expression + producing method. + + When a hybrid is invoked at the class level, the SQL expression given + here is wrapped inside of a specialized :class:`.QueryableAttribute`, + which is the same kind of object used by the ORM to represent other + mapped attributes. The reason for this is so that other class-level + attributes such as docstrings and a reference to the hybrid itself may + be maintained within the structure that's returned, without any + modifications to the original SQL expression passed in. + + .. note:: + + When referring to a hybrid property from an owning class (e.g. + ``SomeClass.some_hybrid``), an instance of + :class:`.QueryableAttribute` is returned, representing the + expression or comparator object as well as this hybrid object. + However, that object itself has accessors called ``expression`` and + ``comparator``; so when attempting to override these decorators on a + subclass, it may be necessary to qualify it using the + :attr:`.hybrid_property.overrides` modifier first. See that + modifier for details. + + .. seealso:: + + :ref:`hybrid_distinct_expression` + + """ + + return self._copy(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a custom + comparator producing method. + + The return value of the decorated method should be an instance of + :class:`~.hybrid.Comparator`. + + .. note:: The :meth:`.hybrid_property.comparator` decorator + **replaces** the use of the :meth:`.hybrid_property.expression` + decorator. They cannot be used together. + + When a hybrid is invoked at the class level, the + :class:`~.hybrid.Comparator` object given here is wrapped inside of a + specialized :class:`.QueryableAttribute`, which is the same kind of + object used by the ORM to represent other mapped attributes. The + reason for this is so that other class-level attributes such as + docstrings and a reference to the hybrid itself may be maintained + within the structure that's returned, without any modifications to the + original comparator object passed in. + + .. note:: + + When referring to a hybrid property from an owning class (e.g. + ``SomeClass.some_hybrid``), an instance of + :class:`.QueryableAttribute` is returned, representing the + expression or comparator object as this hybrid object. However, + that object itself has accessors called ``expression`` and + ``comparator``; so when attempting to override these decorators on a + subclass, it may be necessary to qualify it using the + :attr:`.hybrid_property.overrides` modifier first. See that + modifier for details. + + """ + return self._copy(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines an UPDATE tuple + producing method. + + The method accepts a single value, which is the value to be + rendered into the SET clause of an UPDATE statement. The method + should then process this value into individual column expressions + that fit into the ultimate SET clause, and return them as a + sequence of 2-tuples. Each tuple + contains a column expression as the key and a value to be rendered. + + E.g.:: + + class Person(Base): + # ... + + first_name = Column(String) + last_name = Column(String) + + @hybrid_property + def fullname(self): + return first_name + " " + last_name + + @fullname.update_expression + def fullname(cls, value): + fname, lname = value.split(" ", 1) + return [ + (cls.first_name, fname), + (cls.last_name, lname) + ] + + .. versionadded:: 1.2 + + """ + return self._copy(update_expr=meth) + + @util.memoized_property + def _expr_comparator( + self, + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + if self.custom_comparator is not None: + return self._get_comparator(self.custom_comparator) + elif self.expr is not None: + return self._get_expr(self.expr) + else: + return self._get_expr(cast(_HybridExprCallableType[_T], self.fget)) + + def _get_expr( + self, expr: _HybridExprCallableType[_T] + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + def _expr(cls: Any) -> ExprComparator[_T]: + return ExprComparator(cls, expr(cls), self) + + util.update_wrapper(_expr, expr) + + return self._get_comparator(_expr) + + def _get_comparator( + self, comparator: Any + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + proxy_attr = attributes.create_proxied_attribute(self) + + def expr_comparator( + owner: Type[object], + ) -> _HybridClassLevelAccessor[_T]: + # because this is the descriptor protocol, we don't really know + # what our attribute name is. so search for it through the + # MRO. + for lookup in owner.__mro__: + if self.__name__ in lookup.__dict__: + if lookup.__dict__[self.__name__] is self: + name = self.__name__ + break + else: + name = attributes._UNKNOWN_ATTR_KEY # type: ignore[assignment] + + return cast( + "_HybridClassLevelAccessor[_T]", + proxy_attr( + owner, + name, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ), + ) + + return expr_comparator + + +class Comparator(interfaces.PropComparator[_T]): + """A helper class that allows easy construction of custom + :class:`~.orm.interfaces.PropComparator` + classes for usage with hybrids.""" + + def __init__( + self, expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]] + ): + self.expression = expression + + def __clause_element__(self) -> roles.ColumnsClauseRole: + expr = self.expression + if is_has_clause_element(expr): + ret_expr = expr.__clause_element__() + else: + if TYPE_CHECKING: + assert isinstance(expr, ColumnElement) + ret_expr = expr + + if TYPE_CHECKING: + # see test_hybrid->test_expression_isnt_clause_element + # that exercises the usual place this is caught if not + # true + assert isinstance(ret_expr, ColumnElement) + return ret_expr + + @util.non_memoized_property + def property(self) -> interfaces.MapperProperty[_T]: + raise NotImplementedError() + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> Comparator[_T]: + # interesting.... + return self + + +class ExprComparator(Comparator[_T]): + def __init__( + self, + cls: Type[Any], + expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]], + hybrid: hybrid_property[_T], + ): + self.cls = cls + self.expression = expression + self.hybrid = hybrid + + def __getattr__(self, key: str) -> Any: + return getattr(self.expression, key) + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.hybrid.info + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(self.expression, attributes.QueryableAttribute): + return self.expression._bulk_update_tuples(value) + elif self.hybrid.update_expr is not None: + return self.hybrid.update_expr(self.cls, value) + else: + return [(self.expression, value)] + + @util.non_memoized_property + def property(self) -> MapperProperty[_T]: + # this accessor is not normally used, however is accessed by things + # like ORM synonyms if the hybrid is used in this context; the + # .property attribute is not necessarily accessible + return self.expression.property # type: ignore + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.expression, *other, **kwargs) + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.expression, **kwargs) # type: ignore diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py new file mode 100644 index 0000000..3c41930 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py @@ -0,0 +1,341 @@ +# ext/indexable.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Define attributes on ORM-mapped classes that have "index" attributes for +columns with :class:`_types.Indexable` types. + +"index" means the attribute is associated with an element of an +:class:`_types.Indexable` column with the predefined index to access it. +The :class:`_types.Indexable` types include types such as +:class:`_types.ARRAY`, :class:`_types.JSON` and +:class:`_postgresql.HSTORE`. + + + +The :mod:`~sqlalchemy.ext.indexable` extension provides +:class:`_schema.Column`-like interface for any element of an +:class:`_types.Indexable` typed column. In simple cases, it can be +treated as a :class:`_schema.Column` - mapped attribute. + +Synopsis +======== + +Given ``Person`` as a model with a primary key and JSON data field. +While this field may have any number of elements encoded within it, +we would like to refer to the element called ``name`` individually +as a dedicated attribute which behaves like a standalone column:: + + from sqlalchemy import Column, JSON, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.indexable import index_property + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + name = index_property('data', 'name') + + +Above, the ``name`` attribute now behaves like a mapped column. We +can compose a new ``Person`` and set the value of ``name``:: + + >>> person = Person(name='Alchemist') + +The value is now accessible:: + + >>> person.name + 'Alchemist' + +Behind the scenes, the JSON field was initialized to a new blank dictionary +and the field was set:: + + >>> person.data + {"name": "Alchemist'} + +The field is mutable in place:: + + >>> person.name = 'Renamed' + >>> person.name + 'Renamed' + >>> person.data + {'name': 'Renamed'} + +When using :class:`.index_property`, the change that we make to the indexable +structure is also automatically tracked as history; we no longer need +to use :class:`~.mutable.MutableDict` in order to track this change +for the unit of work. + +Deletions work normally as well:: + + >>> del person.name + >>> person.data + {} + +Above, deletion of ``person.name`` deletes the value from the dictionary, +but not the dictionary itself. + +A missing key will produce ``AttributeError``:: + + >>> person = Person() + >>> person.name + ... + AttributeError: 'name' + +Unless you set a default value:: + + >>> class Person(Base): + >>> __tablename__ = 'person' + >>> + >>> id = Column(Integer, primary_key=True) + >>> data = Column(JSON) + >>> + >>> name = index_property('data', 'name', default=None) # See default + + >>> person = Person() + >>> print(person.name) + None + + +The attributes are also accessible at the class level. +Below, we illustrate ``Person.name`` used to generate +an indexed SQL criteria:: + + >>> from sqlalchemy.orm import Session + >>> session = Session() + >>> query = session.query(Person).filter(Person.name == 'Alchemist') + +The above query is equivalent to:: + + >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist') + +Multiple :class:`.index_property` objects can be chained to produce +multiple levels of indexing:: + + from sqlalchemy import Column, JSON, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.indexable import index_property + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + birthday = index_property('data', 'birthday') + year = index_property('birthday', 'year') + month = index_property('birthday', 'month') + day = index_property('birthday', 'day') + +Above, a query such as:: + + q = session.query(Person).filter(Person.year == '1980') + +On a PostgreSQL backend, the above query will render as:: + + SELECT person.id, person.data + FROM person + WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s + +Default Values +============== + +:class:`.index_property` includes special behaviors for when the indexed +data structure does not exist, and a set operation is called: + +* For an :class:`.index_property` that is given an integer index value, + the default data structure will be a Python list of ``None`` values, + at least as long as the index value; the value is then set at its + place in the list. This means for an index value of zero, the list + will be initialized to ``[None]`` before setting the given value, + and for an index value of five, the list will be initialized to + ``[None, None, None, None, None]`` before setting the fifth element + to the given value. Note that an existing list is **not** extended + in place to receive a value. + +* for an :class:`.index_property` that is given any other kind of index + value (e.g. strings usually), a Python dictionary is used as the + default data structure. + +* The default data structure can be set to any Python callable using the + :paramref:`.index_property.datatype` parameter, overriding the previous + rules. + + +Subclassing +=========== + +:class:`.index_property` can be subclassed, in particular for the common +use case of providing coercion of values or SQL expressions as they are +accessed. Below is a common recipe for use with a PostgreSQL JSON type, +where we want to also include automatic casting plus ``astext()``:: + + class pg_json_property(index_property): + def __init__(self, attr_name, index, cast_type): + super(pg_json_property, self).__init__(attr_name, index) + self.cast_type = cast_type + + def expr(self, model): + expr = super(pg_json_property, self).expr(model) + return expr.astext.cast(self.cast_type) + +The above subclass can be used with the PostgreSQL-specific +version of :class:`_postgresql.JSON`:: + + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.dialects.postgresql import JSON + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + age = pg_json_property('data', 'age', Integer) + +The ``age`` attribute at the instance level works as before; however +when rendering SQL, PostgreSQL's ``->>`` operator will be used +for indexed access, instead of the usual index operator of ``->``:: + + >>> query = session.query(Person).filter(Person.age < 20) + +The above query will render:: + + SELECT person.id, person.data + FROM person + WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s + +""" # noqa +from .. import inspect +from ..ext.hybrid import hybrid_property +from ..orm.attributes import flag_modified + + +__all__ = ["index_property"] + + +class index_property(hybrid_property): # noqa + """A property generator. The generated property describes an object + attribute that corresponds to an :class:`_types.Indexable` + column. + + .. seealso:: + + :mod:`sqlalchemy.ext.indexable` + + """ + + _NO_DEFAULT_ARGUMENT = object() + + def __init__( + self, + attr_name, + index, + default=_NO_DEFAULT_ARGUMENT, + datatype=None, + mutable=True, + onebased=True, + ): + """Create a new :class:`.index_property`. + + :param attr_name: + An attribute name of an `Indexable` typed column, or other + attribute that returns an indexable structure. + :param index: + The index to be used for getting and setting this value. This + should be the Python-side index value for integers. + :param default: + A value which will be returned instead of `AttributeError` + when there is not a value at given index. + :param datatype: default datatype to use when the field is empty. + By default, this is derived from the type of index used; a + Python list for an integer index, or a Python dictionary for + any other style of index. For a list, the list will be + initialized to a list of None values that is at least + ``index`` elements long. + :param mutable: if False, writes and deletes to the attribute will + be disallowed. + :param onebased: assume the SQL representation of this value is + one-based; that is, the first index in SQL is 1, not zero. + """ + + if mutable: + super().__init__(self.fget, self.fset, self.fdel, self.expr) + else: + super().__init__(self.fget, None, None, self.expr) + self.attr_name = attr_name + self.index = index + self.default = default + is_numeric = isinstance(index, int) + onebased = is_numeric and onebased + + if datatype is not None: + self.datatype = datatype + else: + if is_numeric: + self.datatype = lambda: [None for x in range(index + 1)] + else: + self.datatype = dict + self.onebased = onebased + + def _fget_default(self, err=None): + if self.default == self._NO_DEFAULT_ARGUMENT: + raise AttributeError(self.attr_name) from err + else: + return self.default + + def fget(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + return self._fget_default() + try: + value = column_value[self.index] + except (KeyError, IndexError) as err: + return self._fget_default(err) + else: + return value + + def fset(self, instance, value): + attr_name = self.attr_name + column_value = getattr(instance, attr_name, None) + if column_value is None: + column_value = self.datatype() + setattr(instance, attr_name, column_value) + column_value[self.index] = value + setattr(instance, attr_name, column_value) + if attr_name in inspect(instance).mapper.attrs: + flag_modified(instance, attr_name) + + def fdel(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + raise AttributeError(self.attr_name) + try: + del column_value[self.index] + except KeyError as err: + raise AttributeError(self.attr_name) from err + else: + setattr(instance, attr_name, column_value) + flag_modified(instance, attr_name) + + def expr(self, model): + column = getattr(model, self.attr_name) + index = self.index + if self.onebased: + index += 1 + return column[index] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py new file mode 100644 index 0000000..5f3c712 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py @@ -0,0 +1,450 @@ +# ext/instrumentation.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Extensible class instrumentation. + +The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate +systems of class instrumentation within the ORM. Class instrumentation +refers to how the ORM places attributes on the class which maintain +data and track changes to that data, as well as event hooks installed +on the class. + +.. note:: + The extension package is provided for the benefit of integration + with other object management packages, which already perform + their own instrumentation. It is not intended for general use. + +For examples of how the instrumentation extension is used, +see the example :ref:`examples_instrumentation`. + +""" +import weakref + +from .. import util +from ..orm import attributes +from ..orm import base as orm_base +from ..orm import collections +from ..orm import exc as orm_exc +from ..orm import instrumentation as orm_instrumentation +from ..orm import util as orm_util +from ..orm.instrumentation import _default_dict_getter +from ..orm.instrumentation import _default_manager_getter +from ..orm.instrumentation import _default_opt_manager_getter +from ..orm.instrumentation import _default_state_getter +from ..orm.instrumentation import ClassManager +from ..orm.instrumentation import InstrumentationFactory + + +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" +"""Attribute, elects custom instrumentation when present on a mapped class. + +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an :class:`.InstrumentationManager` or subclass + - An object implementing all or some of InstrumentationManager (TODO) + - A dictionary of callables, implementing all or some of the above (TODO) + - An instance of a :class:`.ClassManager` or subclass + +This attribute is consulted by SQLAlchemy instrumentation +resolution, once the :mod:`sqlalchemy.ext.instrumentation` module +has been imported. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" + + +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) + + +instrumentation_finders = [find_native_user_instrumentation_hook] +"""An extensible sequence of callables which return instrumentation +implementations + +When a class is registered, each callable will be passed a class object. +If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + + +class ExtendedInstrumentationRegistry(InstrumentationFactory): + """Extends :class:`.InstrumentationFactory` with additional + bookkeeping, to accommodate multiple types of + class managers. + + """ + + _manager_finders = weakref.WeakKeyDictionary() + _state_finders = weakref.WeakKeyDictionary() + _dict_finders = weakref.WeakKeyDictionary() + _extended = False + + def _locate_extended_factory(self, class_): + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + manager = self._extended_class_manager(class_, factory) + return manager, factory + else: + return None, None + + def _check_conflicts(self, class_, factory): + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) + if existing_factories: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) + + def _extended_class_manager(self, class_, factory): + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + + if factory != ClassManager and not self._extended: + # somebody invoked a custom ClassManager. + # reinstall global "getter" functions with the more + # expensive ones. + self._extended = True + _install_instrumented_lookups() + + self._manager_finders[class_] = manager.manager_getter() + self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() + return manager + + def _collect_management_factories_for(self, cls): + """Return a collection of factories in play or specified for a + hierarchy. + + Traverses the entire inheritance graph of a cls and returns a + collection of instrumentation factories for those classes. Factories + are extracted from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. + + """ + hierarchy = util.class_hierarchy(cls) + factories = set() + for member in hierarchy: + manager = self.opt_manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + def unregister(self, class_): + super().unregister(class_) + if class_ in self._manager_finders: + del self._manager_finders[class_] + del self._state_finders[class_] + del self._dict_finders[class_] + + def opt_manager_of_class(self, cls): + try: + finder = self._manager_finders.get( + cls, _default_opt_manager_getter + ) + except TypeError: + # due to weakref lookup on invalid object + return None + else: + return finder(cls) + + def manager_of_class(self, cls): + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + raise orm_exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) + else: + manager = finder(cls) + if manager is None: + raise orm_exc.UnmappedClassError( + cls, + f"Can't locate an instrumentation manager for class {cls}", + ) + return manager + + def state_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._state_finders.get( + instance.__class__, _default_state_getter + )(instance) + + def dict_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._dict_finders.get( + instance.__class__, _default_dict_getter + )(instance) + + +orm_instrumentation._instrumentation_factory = _instrumentation_factory = ( + ExtendedInstrumentationRegistry() +) +orm_instrumentation.instrumentation_finders = instrumentation_finders + + +class InstrumentationManager: + """User-defined class instrumentation extension. + + :class:`.InstrumentationManager` can be subclassed in order + to change + how class instrumentation proceeds. This class exists for + the purposes of integration with other object management + frameworks which would like to entirely modify the + instrumentation methodology of the ORM, and is not intended + for regular usage. For interception of class instrumentation + events, see :class:`.InstrumentationEvents`. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + """ + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. + + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, "_default_class_manager", manager) + + def unregister(self, class_, manager): + delattr(class_, "_default_class_manager") + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def post_configure_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, "_default_state", state) + + def remove_state(self, class_, instance): + delattr(instance, "_default_state") + + def state_getter(self, class_): + return lambda instance: getattr(instance, "_default_state") + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override): + self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + + ClassManager.__init__(self, class_) + + def manage(self): + self._adapted.manage(self.class_, self) + + def unregister(self): + self._adapted.unregister(self.class_, self) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def post_configure_attribute(self, key): + super().post_configure_attribute(key) + self._adapted.post_configure_attribute(self.class_, key, self[key]) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return self._adapted.instrument_collection_class( + self.class_, key, collection_class + ) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, "initialize_collection", None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection( + self, key, state, factory + ) + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + """ + if self.has_state(instance): + return False + else: + return self.setup_instance(instance) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + + if state is None: + state = self._state_constructor(instance, self) + + # the given instance is assumed to have no state + self._adapted.install_state(self.class_, instance, state) + return state + + def teardown_instance(self, instance): + self._adapted.remove_state(self.class_, instance) + + def has_state(self, instance): + try: + self._get_state(instance) + except orm_exc.NO_STATE: + return False + else: + return True + + def state_getter(self): + return self._get_state + + def dict_getter(self): + return self._get_dict + + +def _install_instrumented_lookups(): + """Replace global class/object management functions + with ExtendedInstrumentationRegistry implementations, which + allow multiple types of class managers to be present, + at the cost of performance. + + This function is called only by ExtendedInstrumentationRegistry + and unit tests specific to this behavior. + + The _reinstall_default_lookups() function can be called + after this one to re-establish the default functions. + + """ + _install_lookups( + dict( + instance_state=_instrumentation_factory.state_of, + instance_dict=_instrumentation_factory.dict_of, + manager_of_class=_instrumentation_factory.manager_of_class, + opt_manager_of_class=_instrumentation_factory.opt_manager_of_class, + ) + ) + + +def _reinstall_default_lookups(): + """Restore simplified lookups.""" + _install_lookups( + dict( + instance_state=_default_state_getter, + instance_dict=_default_dict_getter, + manager_of_class=_default_manager_getter, + opt_manager_of_class=_default_opt_manager_getter, + ) + ) + _instrumentation_factory._extended = False + + +def _install_lookups(lookups): + global instance_state, instance_dict + global manager_of_class, opt_manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + opt_manager_of_class = lookups["opt_manager_of_class"] + orm_base.instance_state = attributes.instance_state = ( + orm_instrumentation.instance_state + ) = instance_state + orm_base.instance_dict = attributes.instance_dict = ( + orm_instrumentation.instance_dict + ) = instance_dict + orm_base.manager_of_class = attributes.manager_of_class = ( + orm_instrumentation.manager_of_class + ) = manager_of_class + orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = ( + attributes.opt_manager_of_class + ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py new file mode 100644 index 0000000..7da5075 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py @@ -0,0 +1,1073 @@ +# ext/mutable.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Provide support for tracking of in-place changes to scalar values, +which are propagated into ORM change events on owning parent objects. + +.. _mutable_scalars: + +Establishing Mutability on Scalar Column Values +=============================================== + +A typical example of a "mutable" structure is a Python dictionary. +Following the example introduced in :ref:`types_toplevel`, we +begin with a custom type that marshals Python dictionaries into +JSON strings before being persisted:: + + from sqlalchemy.types import TypeDecorator, VARCHAR + import json + + class JSONEncodedDict(TypeDecorator): + "Represents an immutable structure as a json-encoded string." + + impl = VARCHAR + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + +The usage of ``json`` is only for the purposes of example. The +:mod:`sqlalchemy.ext.mutable` extension can be used +with any type whose target Python type may be mutable, including +:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc. + +When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself +tracks all parents which reference it. Below, we illustrate a simple +version of the :class:`.MutableDict` dictionary object, which applies +the :class:`.Mutable` mixin to a plain Python dictionary:: + + from sqlalchemy.ext.mutable import Mutable + + class MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + +The above dictionary class takes the approach of subclassing the Python +built-in ``dict`` to produce a dict +subclass which routes all mutation events through ``__setitem__``. There are +variants on this approach, such as subclassing ``UserDict.UserDict`` or +``collections.MutableMapping``; the part that's important to this example is +that the :meth:`.Mutable.changed` method is called whenever an in-place +change to the datastructure takes place. + +We also redefine the :meth:`.Mutable.coerce` method which will be used to +convert any values that are not instances of ``MutableDict``, such +as the plain dictionaries returned by the ``json`` module, into the +appropriate type. Defining this method is optional; we could just as well +created our ``JSONEncodedDict`` such that it always returns an instance +of ``MutableDict``, and additionally ensured that all calling code +uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not +overridden, any values applied to a parent object which are not instances +of the mutable type will raise a ``ValueError``. + +Our new ``MutableDict`` type offers a class method +:meth:`~.Mutable.as_mutable` which we can use within column metadata +to associate with types. This method grabs the given type object or +class and associates a listener that will detect all future mappings +of this type, applying event listening instrumentation to the mapped +attribute. Such as, with classical table metadata:: + + from sqlalchemy import Table, Column, Integer + + my_data = Table('my_data', metadata, + Column('id', Integer, primary_key=True), + Column('data', MutableDict.as_mutable(JSONEncodedDict)) + ) + +Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict`` +(if the type object was not an instance already), which will intercept any +attributes which are mapped against this type. Below we establish a simple +mapping against the ``my_data`` table:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + +The ``MyDataClass.data`` member will now be notified of in place changes +to its value. + +Any in-place changes to the ``MyDataClass.data`` member +will flag the attribute as "dirty" on the parent object:: + + >>> from sqlalchemy.orm import Session + + >>> sess = Session(some_engine) + >>> m1 = MyDataClass(data={'value1':'foo'}) + >>> sess.add(m1) + >>> sess.commit() + + >>> m1.data['value1'] = 'bar' + >>> assert m1 in sess.dirty + True + +The ``MutableDict`` can be associated with all future instances +of ``JSONEncodedDict`` in one step, using +:meth:`~.Mutable.associate_with`. This is similar to +:meth:`~.Mutable.as_mutable` except it will intercept all occurrences +of ``MutableDict`` in all mappings unconditionally, without +the need to declare it individually:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + MutableDict.associate_with(JSONEncodedDict) + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(JSONEncodedDict) + + +Supporting Pickling +-------------------- + +The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the +placement of a ``weakref.WeakKeyDictionary`` upon the value object, which +stores a mapping of parent mapped objects keyed to the attribute name under +which they are associated with this value. ``WeakKeyDictionary`` objects are +not picklable, due to the fact that they contain weakrefs and function +callbacks. In our case, this is a good thing, since if this dictionary were +picklable, it could lead to an excessively large pickle size for our value +objects that are pickled by themselves outside of the context of the parent. +The developer responsibility here is only to provide a ``__getstate__`` method +that excludes the :meth:`~MutableBase._parents` collection from the pickle +stream:: + + class MyMutableType(Mutable): + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_parents', None) + return d + +With our dictionary example, we need to return the contents of the dict itself +(and also restore them on __setstate__):: + + class MutableDict(Mutable, dict): + # .... + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) + +In the case that our mutable value object is pickled as it is attached to one +or more parent objects that are also part of the pickle, the :class:`.Mutable` +mixin will re-establish the :attr:`.Mutable._parents` collection on each value +object as the owning parents themselves are unpickled. + +Receiving Events +---------------- + +The :meth:`.AttributeEvents.modified` event handler may be used to receive +an event when a mutable scalar emits a change event. This event handler +is called when the :func:`.attributes.flag_modified` function is called +from within the mutable extension:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy import event + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + + @event.listens_for(MyDataClass.data, "modified") + def modified_json(instance, initiator): + print("json value modified:", instance.data) + +.. _mutable_composites: + +Establishing Mutability on Composites +===================================== + +Composites are a special ORM feature which allow a single scalar attribute to +be assigned an object value which represents information "composed" from one +or more columns from the underlying mapped table. The usual example is that of +a geometric "point", and is introduced in :ref:`mapper_composite`. + +As is the case with :class:`.Mutable`, the user-defined composite class +subclasses :class:`.MutableComposite` as a mixin, and detects and delivers +change events to its parents via the :meth:`.MutableComposite.changed` method. +In the case of a composite class, the detection is usually via the usage of the +special Python method ``__setattr__()``. In the example below, we expand upon the ``Point`` +class introduced in :ref:`mapper_composite` to include +:class:`.MutableComposite` in its bases and to route attribute set events via +``__setattr__`` to the :meth:`.MutableComposite.changed` method:: + + import dataclasses + from sqlalchemy.ext.mutable import MutableComposite + + @dataclasses.dataclass + class Point(MutableComposite): + x: int + y: int + + def __setattr__(self, key, value): + "Intercept set events" + + # set the attribute + object.__setattr__(self, key, value) + + # alert all parents to the change + self.changed() + + +The :class:`.MutableComposite` class makes use of class mapping events to +automatically establish listeners for any usage of :func:`_orm.composite` that +specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` +class, listeners are established which will route change events from ``Point`` +objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: + + from sqlalchemy.orm import DeclarativeBase, Mapped + from sqlalchemy.orm import composite, mapped_column + + class Base(DeclarativeBase): + pass + + + class Vertex(Base): + __tablename__ = "vertices" + + id: Mapped[int] = mapped_column(primary_key=True) + + start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1")) + end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2")) + + def __repr__(self): + return f"Vertex(start={self.start}, end={self.end})" + +Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members +will flag the attribute as "dirty" on the parent object: + +.. sourcecode:: python+sql + + >>> from sqlalchemy.orm import Session + >>> sess = Session(engine) + >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15)) + >>> sess.add(v1) + {sql}>>> sess.flush() + BEGIN (implicit) + INSERT INTO vertices (x1, y1, x2, y2) VALUES (?, ?, ?, ?) + [...] (3, 4, 12, 15) + + {stop}>>> v1.end.x = 8 + >>> assert v1 in sess.dirty + True + {sql}>>> sess.commit() + UPDATE vertices SET x2=? WHERE vertices.id = ? + [...] (8, 1) + COMMIT + +Coercing Mutable Composites +--------------------------- + +The :meth:`.MutableBase.coerce` method is also supported on composite types. +In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce` +method is only called for attribute set operations, not load operations. +Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent +to using a :func:`.validates` validation routine for all attributes which +make use of the custom composite type:: + + @dataclasses.dataclass + class Point(MutableComposite): + # other Point methods + # ... + + def coerce(cls, key, value): + if isinstance(value, tuple): + value = Point(*value) + elif not isinstance(value, Point): + raise ValueError("tuple or Point expected") + return value + +Supporting Pickling +-------------------- + +As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper +class uses a ``weakref.WeakKeyDictionary`` available via the +:meth:`MutableBase._parents` attribute which isn't picklable. If we need to +pickle instances of ``Point`` or its owning class ``Vertex``, we at least need +to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary. +Below we define both a ``__getstate__`` and a ``__setstate__`` that package up +the minimal form of our ``Point`` class:: + + @dataclasses.dataclass + class Point(MutableComposite): + # ... + + def __getstate__(self): + return self.x, self.y + + def __setstate__(self, state): + self.x, self.y = state + +As with :class:`.Mutable`, the :class:`.MutableComposite` augments the +pickling process of the parent's object-relational state so that the +:meth:`MutableBase._parents` collection is restored to all ``Point`` objects. + +""" # noqa: E501 + +from __future__ import annotations + +from collections import defaultdict +from typing import AbstractSet +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref +from weakref import WeakKeyDictionary + +from .. import event +from .. import inspect +from .. import types +from .. import util +from ..orm import Mapper +from ..orm._typing import _ExternalEntityType +from ..orm._typing import _O +from ..orm._typing import _T +from ..orm.attributes import AttributeEventToken +from ..orm.attributes import flag_modified +from ..orm.attributes import InstrumentedAttribute +from ..orm.attributes import QueryableAttribute +from ..orm.context import QueryContext +from ..orm.decl_api import DeclarativeAttributeIntercept +from ..orm.state import InstanceState +from ..orm.unitofwork import UOWTransaction +from ..sql.base import SchemaEventTarget +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util import memoized_property +from ..util.typing import SupportsIndex +from ..util.typing import TypeGuard + +_KT = TypeVar("_KT") # Key type. +_VT = TypeVar("_VT") # Value type. + + +class MutableBase: + """Common base class to :class:`.Mutable` + and :class:`.MutableComposite`. + + """ + + @memoized_property + def _parents(self) -> WeakKeyDictionary[Any, Any]: + """Dictionary of parent object's :class:`.InstanceState`->attribute + name on the parent. + + This attribute is a so-called "memoized" property. It initializes + itself with a new ``weakref.WeakKeyDictionary`` the first time + it is accessed, returning the same object upon subsequent access. + + .. versionchanged:: 1.4 the :class:`.InstanceState` is now used + as the key in the weak dictionary rather than the instance + itself. + + """ + + return weakref.WeakKeyDictionary() + + @classmethod + def coerce(cls, key: str, value: Any) -> Optional[Any]: + """Given a value, coerce it into the target type. + + Can be overridden by custom subclasses to coerce incoming + data into a particular type. + + By default, raises ``ValueError``. + + This method is called in different scenarios depending on if + the parent class is of type :class:`.Mutable` or of type + :class:`.MutableComposite`. In the case of the former, it is called + for both attribute-set operations as well as during ORM loading + operations. For the latter, it is only called during attribute-set + operations; the mechanics of the :func:`.composite` construct + handle coercion during load operations. + + + :param key: string name of the ORM-mapped attribute being set. + :param value: the incoming value. + :return: the method should return the coerced value, or raise + ``ValueError`` if the coercion cannot be completed. + + """ + if value is None: + return None + msg = "Attribute '%s' does not accept objects of type %s" + raise ValueError(msg % (key, type(value))) + + @classmethod + def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]: + """Given a descriptor attribute, return a ``set()`` of the attribute + keys which indicate a change in the state of this attribute. + + This is normally just ``set([attribute.key])``, but can be overridden + to provide for additional keys. E.g. a :class:`.MutableComposite` + augments this set with the attribute keys associated with the columns + that comprise the composite value. + + This collection is consulted in the case of intercepting the + :meth:`.InstanceEvents.refresh` and + :meth:`.InstanceEvents.refresh_flush` events, which pass along a list + of attribute names that have been refreshed; the list is compared + against this set to determine if action needs to be taken. + + """ + return {attribute.key} + + @classmethod + def _listen_on_attribute( + cls, + attribute: QueryableAttribute[Any], + coerce: bool, + parent_cls: _ExternalEntityType[Any], + ) -> None: + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + key = attribute.key + if parent_cls is not attribute.class_: + return + + # rely on "propagate" here + parent_cls = attribute.class_ + + listen_keys = cls._get_listen_keys(attribute) + + def load(state: InstanceState[_O], *args: Any) -> None: + """Listen for objects loaded or refreshed. + + Wrap the target data member's value with + ``Mutable``. + + """ + val = state.dict.get(key, None) + if val is not None: + if coerce: + val = cls.coerce(key, val) + state.dict[key] = val + val._parents[state] = key + + def load_attrs( + state: InstanceState[_O], + ctx: Union[object, QueryContext, UOWTransaction], + attrs: Iterable[Any], + ) -> None: + if not attrs or listen_keys.intersection(attrs): + load(state) + + def set_( + target: InstanceState[_O], + value: MutableBase | None, + oldvalue: MutableBase | None, + initiator: AttributeEventToken, + ) -> MutableBase | None: + """Listen for set/replace events on the target + data member. + + Establish a weak reference to the parent object + on the incoming value, remove it for the one + outgoing. + + """ + if value is oldvalue: + return value + + if not isinstance(value, cls): + value = cls.coerce(key, value) + if value is not None: + value._parents[target] = key + if isinstance(oldvalue, cls): + oldvalue._parents.pop(inspect(target), None) + return value + + def pickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: + val = state.dict.get(key, None) + if val is not None: + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = defaultdict(list) + state_dict["ext.mutable.values"][key].append(val) + + def unpickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: + if "ext.mutable.values" in state_dict: + collection = state_dict["ext.mutable.values"] + if isinstance(collection, list): + # legacy format + for val in collection: + val._parents[state] = key + else: + for val in state_dict["ext.mutable.values"][key]: + val._parents[state] = key + + event.listen( + parent_cls, + "_sa_event_merge_wo_load", + load, + raw=True, + propagate=True, + ) + + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set_, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) + + +class Mutable(MutableBase): + """Mixin that defines transparent propagation of change + events to a parent object. + + See the example in :ref:`mutable_scalars` for usage information. + + """ + + def changed(self) -> None: + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + flag_modified(parent.obj(), key) + + @classmethod + def associate_with_attribute( + cls, attribute: InstrumentedAttribute[_O] + ) -> None: + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + cls._listen_on_attribute(attribute, True, attribute.class_) + + @classmethod + def associate_with(cls, sqltype: type) -> None: + """Associate this wrapper with all future mapped columns + of the given type. + + This is a convenience method that calls + ``associate_with_attribute`` automatically. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.associate_with` for types that are permanent to an + application, not with ad-hoc types else this will cause unbounded + growth in memory usage. + + """ + + def listen_for_type(mapper: Mapper[_O], class_: type) -> None: + if mapper.non_primary: + return + for prop in mapper.column_attrs: + if isinstance(prop.columns[0].type, sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(Mapper, "mapper_configured", listen_for_type) + + @classmethod + def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: + """Associate a SQL type with this mutable Python type. + + This establishes listeners that will detect ORM mappings against + the given type, adding mutation event trackers to those mappings. + + The type is returned, unconditionally as an instance, so that + :meth:`.as_mutable` can be used inline:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', MyMutableType.as_mutable(PickleType)) + ) + + Note that the returned type is always an instance, even if a class + is given, and that only columns which are declared specifically with + that type instance receive additional instrumentation. + + To associate a particular mutable type with all occurrences of a + particular type, use the :meth:`.Mutable.associate_with` classmethod + of the particular :class:`.Mutable` subclass to establish a global + association. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.as_mutable` for types that are permanent to an application, + not with ad-hoc types else this will cause unbounded growth + in memory usage. + + """ + sqltype = types.to_instance(sqltype) + + # a SchemaType will be copied when the Column is copied, + # and we'll lose our ability to link that type back to the original. + # so track our original type w/ columns + if isinstance(sqltype, SchemaEventTarget): + + @event.listens_for(sqltype, "before_parent_attach") + def _add_column_memo( + sqltyp: TypeEngine[Any], + parent: Column[_T], + ) -> None: + parent.info["_ext_mutable_orig_type"] = sqltyp + + schema_event_check = True + else: + schema_event_check = False + + def listen_for_type( + mapper: Mapper[_T], + class_: Union[DeclarativeAttributeIntercept, type], + ) -> None: + if mapper.non_primary: + return + _APPLIED_KEY = "_ext_mutable_listener_applied" + + for prop in mapper.column_attrs: + if ( + # all Mutable types refer to a Column that's mapped, + # since this is the only kind of Core target the ORM can + # "mutate" + isinstance(prop.expression, Column) + and ( + ( + schema_event_check + and prop.expression.info.get( + "_ext_mutable_orig_type" + ) + is sqltype + ) + or prop.expression.type is sqltype + ) + ): + if not prop.expression.info.get(_APPLIED_KEY, False): + prop.expression.info[_APPLIED_KEY] = True + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(Mapper, "mapper_configured", listen_for_type) + + return sqltype + + +class MutableComposite(MutableBase): + """Mixin that defines transparent propagation of change + events on a SQLAlchemy "composite" object to its + owning parent or parents. + + See the example in :ref:`mutable_composites` for usage information. + + """ + + @classmethod + def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]: + return {attribute.key}.union(attribute.property._attribute_keys) + + def changed(self) -> None: + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + prop = parent.mapper.get_property(key) + for value, attr_name in zip( + prop._composite_values_from_instance(self), + prop._attribute_keys, + ): + setattr(parent.obj(), attr_name, value) + + +def _setup_composite_listener() -> None: + def _listen_for_type(mapper: Mapper[_T], class_: type) -> None: + for prop in mapper.iterate_properties: + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): + prop.composite_class._listen_on_attribute( + getattr(class_, prop.key), False, class_ + ) + + if not event.contains(Mapper, "mapper_configured", _listen_for_type): + event.listen(Mapper, "mapper_configured", _listen_for_type) + + +_setup_composite_listener() + + +class MutableDict(Mutable, Dict[_KT, _VT]): + """A dictionary type that implements :class:`.Mutable`. + + The :class:`.MutableDict` object implements a dictionary that will + emit change events to the underlying mapping when the contents of + the dictionary are altered, including when values are added or removed. + + Note that :class:`.MutableDict` does **not** apply mutable tracking to the + *values themselves* inside the dictionary. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + dictionary structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableDict` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableList` + + :class:`.MutableSet` + + """ + + def __setitem__(self, key: _KT, value: _VT) -> None: + """Detect dictionary set events and emit change events.""" + super().__setitem__(key, value) + self.changed() + + if TYPE_CHECKING: + # from https://github.com/python/mypy/issues/14858 + + @overload + def setdefault( + self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None + ) -> Optional[_T]: ... + + @overload + def setdefault(self, key: _KT, value: _VT) -> _VT: ... + + def setdefault(self, key: _KT, value: object = None) -> object: ... + + else: + + def setdefault(self, *arg): # noqa: F811 + result = super().setdefault(*arg) + self.changed() + return result + + def __delitem__(self, key: _KT) -> None: + """Detect dictionary del events and emit change events.""" + super().__delitem__(key) + self.changed() + + def update(self, *a: Any, **kw: _VT) -> None: + super().update(*a, **kw) + self.changed() + + if TYPE_CHECKING: + + @overload + def pop(self, __key: _KT) -> _VT: ... + + @overload + def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: ... + + def pop( + self, __key: _KT, __default: _VT | _T | None = None + ) -> _VT | _T: ... + + else: + + def pop(self, *arg): # noqa: F811 + result = super().pop(*arg) + self.changed() + return result + + def popitem(self) -> Tuple[_KT, _VT]: + result = super().popitem() + self.changed() + return result + + def clear(self) -> None: + super().clear() + self.changed() + + @classmethod + def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None: + """Convert plain dictionary to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, dict): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + def __getstate__(self) -> Dict[_KT, _VT]: + return dict(self) + + def __setstate__( + self, state: Union[Dict[str, int], Dict[str, str]] + ) -> None: + self.update(state) + + +class MutableList(Mutable, List[_T]): + """A list type that implements :class:`.Mutable`. + + The :class:`.MutableList` object implements a list that will + emit change events to the underlying mapping when the contents of + the list are altered, including when values are added or removed. + + Note that :class:`.MutableList` does **not** apply mutable tracking to the + *values themselves* inside the list. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableList` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableDict` + + :class:`.MutableSet` + + """ + + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: + return (self.__class__, (list(self),)) + + # needed for backwards compatibility with + # older pickles + def __setstate__(self, state: Iterable[_T]) -> None: + self[:] = state + + def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: + return not util.is_non_string_iterable(value) + + def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: + return util.is_non_string_iterable(value) + + def __setitem__( + self, index: SupportsIndex | slice, value: _T | Iterable[_T] + ) -> None: + """Detect list set events and emit change events.""" + if isinstance(index, SupportsIndex) and self.is_scalar(value): + super().__setitem__(index, value) + elif isinstance(index, slice) and self.is_iterable(value): + super().__setitem__(index, value) + self.changed() + + def __delitem__(self, index: SupportsIndex | slice) -> None: + """Detect list del events and emit change events.""" + super().__delitem__(index) + self.changed() + + def pop(self, *arg: SupportsIndex) -> _T: + result = super().pop(*arg) + self.changed() + return result + + def append(self, x: _T) -> None: + super().append(x) + self.changed() + + def extend(self, x: Iterable[_T]) -> None: + super().extend(x) + self.changed() + + def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]: # type: ignore[override,misc] # noqa: E501 + self.extend(x) + return self + + def insert(self, i: SupportsIndex, x: _T) -> None: + super().insert(i, x) + self.changed() + + def remove(self, i: _T) -> None: + super().remove(i) + self.changed() + + def clear(self) -> None: + super().clear() + self.changed() + + def sort(self, **kw: Any) -> None: + super().sort(**kw) + self.changed() + + def reverse(self) -> None: + super().reverse() + self.changed() + + @classmethod + def coerce( + cls, key: str, value: MutableList[_T] | _T + ) -> Optional[MutableList[_T]]: + """Convert plain list to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, list): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + +class MutableSet(Mutable, Set[_T]): + """A set type that implements :class:`.Mutable`. + + The :class:`.MutableSet` object implements a set that will + emit change events to the underlying mapping when the contents of + the set are altered, including when values are added or removed. + + Note that :class:`.MutableSet` does **not** apply mutable tracking to the + *values themselves* inside the set. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure. To support this use case, + build a subclass of :class:`.MutableSet` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableDict` + + :class:`.MutableList` + + + """ + + def update(self, *arg: Iterable[_T]) -> None: + super().update(*arg) + self.changed() + + def intersection_update(self, *arg: Iterable[Any]) -> None: + super().intersection_update(*arg) + self.changed() + + def difference_update(self, *arg: Iterable[Any]) -> None: + super().difference_update(*arg) + self.changed() + + def symmetric_difference_update(self, *arg: Iterable[_T]) -> None: + super().symmetric_difference_update(*arg) + self.changed() + + def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 + self.update(other) + return self + + def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]: + self.intersection_update(other) + return self + + def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 + self.symmetric_difference_update(other) + return self + + def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]: # type: ignore[misc] # noqa: E501 + self.difference_update(other) + return self + + def add(self, elem: _T) -> None: + super().add(elem) + self.changed() + + def remove(self, elem: _T) -> None: + super().remove(elem) + self.changed() + + def discard(self, elem: _T) -> None: + super().discard(elem) + self.changed() + + def pop(self, *arg: Any) -> _T: + result = super().pop(*arg) + self.changed() + return result + + def clear(self) -> None: + super().clear() + self.changed() + + @classmethod + def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]: + """Convert plain set to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, set): + return cls(value) + return Mutable.coerce(index, value) + else: + return value + + def __getstate__(self) -> Set[_T]: + return set(self) + + def __setstate__(self, state: Iterable[_T]) -> None: + self.update(state) + + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: + return (self.__class__, (list(self),)) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py new file mode 100644 index 0000000..de2c02e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py @@ -0,0 +1,6 @@ +# ext/mypy/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7ad6efd --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6072e1d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0b6844d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..98231e9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..41c9ba3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..30fab74 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ee8ba78 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py new file mode 100644 index 0000000..eb90194 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py @@ -0,0 +1,320 @@ +# ext/mypy/apply.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import ARG_NAMED_OPT +from mypy.nodes import Argument +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import MDEF +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneTyp +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import infer +from . import util +from .names import expr_to_mapped_constructor +from .names import NAMED_TYPE_SQLA_MAPPED + + +def apply_mypy_mapped_attr( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + item: Union[NameExpr, StrExpr], + attributes: List[util.SQLAlchemyAttribute], +) -> None: + if isinstance(item, NameExpr): + name = item.name + elif isinstance(item, StrExpr): + name = item.value + else: + return None + + for stmt in cls.defs.body: + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name == name + ): + break + else: + util.fail(api, f"Can't find mapped attribute {name}", cls) + return None + + if stmt.type is None: + util.fail( + api, + "Statement linked from _mypy_mapped_attrs has no " + "typing information", + stmt, + ) + return None + + left_hand_explicit_type = get_proper_type(stmt.type) + assert isinstance( + left_hand_explicit_type, (Instance, UnionType, UnboundType) + ) + + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=item.line, + column=item.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + apply_type_to_mapped_statement( + api, stmt, stmt.lvalues[0], left_hand_explicit_type, None + ) + + +def re_apply_declarative_assignments( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """For multiple class passes, re-apply our left-hand side types as mypy + seems to reset them in place. + + """ + mapped_attr_lookup = {attr.name: attr for attr in attributes} + update_cls_metadata = False + + for stmt in cls.defs.body: + # for a re-apply, all of our statements are AssignmentStmt; + # @declared_attr calls will have been converted and this + # currently seems to be preserved by mypy (but who knows if this + # will change). + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name in mapped_attr_lookup + and isinstance(stmt.lvalues[0].node, Var) + ): + left_node = stmt.lvalues[0].node + + python_type_for_type = mapped_attr_lookup[ + stmt.lvalues[0].name + ].type + + left_node_proper_type = get_proper_type(left_node.type) + + # if we have scanned an UnboundType and now there's a more + # specific type than UnboundType, call the re-scan so we + # can get that set up correctly + if ( + isinstance(python_type_for_type, UnboundType) + and not isinstance(left_node_proper_type, UnboundType) + and ( + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, MemberExpr) + and isinstance(stmt.rvalue.callee.expr, NameExpr) + and stmt.rvalue.callee.expr.node is not None + and stmt.rvalue.callee.expr.node.fullname + == NAMED_TYPE_SQLA_MAPPED + and stmt.rvalue.callee.name == "_empty_constructor" + and isinstance(stmt.rvalue.args[0], CallExpr) + and isinstance(stmt.rvalue.args[0].callee, RefExpr) + ) + ): + new_python_type_for_type = ( + infer.infer_type_from_right_hand_nameexpr( + api, + stmt, + left_node, + left_node_proper_type, + stmt.rvalue.args[0].callee, + ) + ) + + if new_python_type_for_type is not None and not isinstance( + new_python_type_for_type, UnboundType + ): + python_type_for_type = new_python_type_for_type + + # update the SQLAlchemyAttribute with the better + # information + mapped_attr_lookup[stmt.lvalues[0].name].type = ( + python_type_for_type + ) + + update_cls_metadata = True + + if ( + not isinstance(left_node.type, Instance) + or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED + ): + assert python_type_for_type is not None + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] + ) + + if update_cls_metadata: + util.set_mapped_attributes(cls.info, attributes) + + +def apply_type_to_mapped_statement( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + lvalue: NameExpr, + left_hand_explicit_type: Optional[ProperType], + python_type_for_type: Optional[ProperType], +) -> None: + """Apply the Mapped[<type>] annotation and right hand object to a + declarative assignment statement. + + This converts a Python declarative class statement such as:: + + class User(Base): + # ... + + attrname = Column(Integer) + + To one that describes the final Python behavior to Mypy:: + + class User(Base): + # ... + + attrname : Mapped[Optional[int]] = <meaningless temp node> + + """ + left_node = lvalue.node + assert isinstance(left_node, Var) + + # to be completely honest I have no idea what the difference between + # left_node.type and stmt.type is, what it means if these are different + # vs. the same, why in order to get tests to pass I have to assign + # to stmt.type for the second case and not the first. this is complete + # trying every combination until it works stuff. + + if left_hand_explicit_type is not None: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + else: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, + ( + [AnyType(TypeOfAny.special_form)] + if python_type_for_type is None + else [python_type_for_type] + ), + ) + + # so to have it skip the right side totally, we can do this: + # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # however, if we instead manufacture a new node that uses the old + # one, then we can still get type checking for the call itself, + # e.g. the Column, relationship() call, etc. + + # rewrite the node as: + # <attr> : Mapped[<typ>] = + # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>) + # the original right-hand side is maintained so it gets type checked + # internally + stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue) + + if stmt.type is not None and python_type_for_type is not None: + stmt.type = python_type_for_type + + +def add_additional_orm_attributes( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Apply __init__, __table__ and other attributes to the mapped class.""" + + info = util.info_for_cls(cls, api) + + if info is None: + return + + is_base = util.get_is_base(info) + + if "__init__" not in info.names and not is_base: + mapped_attr_names = {attr.name: attr.type for attr in attributes} + + for base in info.mro[1:-1]: + if "sqlalchemy" not in info.metadata: + continue + + base_cls_attributes = util.get_mapped_attributes(base, api) + if base_cls_attributes is None: + continue + + for attr in base_cls_attributes: + mapped_attr_names.setdefault(attr.name, attr.type) + + arguments = [] + for name, typ in mapped_attr_names.items(): + if typ is None: + typ = AnyType(TypeOfAny.special_form) + arguments.append( + Argument( + variable=Var(name, typ), + type_annotation=typ, + initializer=TempNode(typ), + kind=ARG_NAMED_OPT, + ) + ) + + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) + + if "__table__" not in info.names and util.get_has_table(info): + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.sql.schema.Table", "__table__" + ) + if not is_base: + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" + ) + + +def _apply_placeholder_attr_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + qualified_name: str, + attrname: str, +) -> None: + sym = api.lookup_fully_qualified_or_none(qualified_name) + if sym: + assert isinstance(sym.node, TypeInfo) + type_: ProperType = Instance(sym.node, []) + else: + type_ = AnyType(TypeOfAny.special_form) + var = Var(attrname) + var._fullname = cls.fullname + "." + attrname + var.info = cls.info + var.type = type_ + cls.info.names[attrname] = SymbolTableNode(MDEF, var) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py new file mode 100644 index 0000000..3d578b3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py @@ -0,0 +1,515 @@ +# ext/mypy/decl_class.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Decorator +from mypy.nodes import LambdaExpr +from mypy.nodes import ListExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import PlaceholderNode +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolNode +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import apply +from . import infer +from . import names +from . import util + + +def scan_declarative_assignments_and_apply_types( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + is_mixin_scan: bool = False, +) -> Optional[List[util.SQLAlchemyAttribute]]: + info = util.info_for_cls(cls, api) + + if info is None: + # this can occur during cached passes + return None + elif cls.fullname.startswith("builtins"): + return None + + mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = ( + util.get_mapped_attributes(info, api) + ) + + # used by assign.add_additional_orm_attributes among others + util.establish_as_sqlalchemy(info) + + if mapped_attributes is not None: + # ensure that a class that's mapped is always picked up by + # its mapped() decorator or declarative metaclass before + # it would be detected as an unmapped mixin class + + if not is_mixin_scan: + # mypy can call us more than once. it then *may* have reset the + # left hand side of everything, but not the right that we removed, + # removing our ability to re-scan. but we have the types + # here, so lets re-apply them, or if we have an UnboundType, + # we can re-scan + + apply.re_apply_declarative_assignments(cls, api, mapped_attributes) + + return mapped_attributes + + mapped_attributes = [] + + if not cls.defs.body: + # when we get a mixin class from another file, the body is + # empty (!) but the names are in the symbol table. so use that. + + for sym_name, sym in info.names.items(): + _scan_symbol_table_entry( + cls, api, sym_name, sym, mapped_attributes + ) + else: + for stmt in util.flatten_typechecking(cls.defs.body): + if isinstance(stmt, AssignmentStmt): + _scan_declarative_assignment_stmt( + cls, api, stmt, mapped_attributes + ) + elif isinstance(stmt, Decorator): + _scan_declarative_decorator_stmt( + cls, api, stmt, mapped_attributes + ) + _scan_for_mapped_bases(cls, api) + + if not is_mixin_scan: + apply.add_additional_orm_attributes(cls, api, mapped_attributes) + + util.set_mapped_attributes(info, mapped_attributes) + + return mapped_attributes + + +def _scan_symbol_table_entry( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + name: str, + value: SymbolTableNode, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a SymbolTableNode that's in the + type.names dictionary. + + """ + value_type = get_proper_type(value.type) + if not isinstance(value_type, Instance): + return + + left_hand_explicit_type = None + type_id = names.type_id_for_named_node(value_type.type) + # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) + + err = False + + # TODO: this is nearly the same logic as that of + # _scan_declarative_decorator_stmt, likely can be merged + if type_id in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + }: + if value_type.args: + left_hand_explicit_type = get_proper_type(value_type.args[0]) + else: + err = True + elif type_id is names.COLUMN: + if not value_type.args: + err = True + else: + typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type( + value_type.args[0] + ) + if isinstance(typeengine_arg, Instance): + typeengine_arg = typeengine_arg.type + + if isinstance(typeengine_arg, (UnboundType, TypeInfo)): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + value_type, + ) + + if err: + msg = ( + "Can't infer type from attribute {} on class {}. " + "please specify a return type from this function that is " + "one of: Mapped[<python type>], relationship[<target class>], " + "Column[<TypeEngine>], MapperProperty[<python type>]" + ) + util.fail(api, msg.format(name, cls.name), cls) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + if left_hand_explicit_type is not None: + assert value.node is not None + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=value.node.line, + column=value.node.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + +def _scan_declarative_decorator_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: Decorator, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a @declared_attr in a declarative + class. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + @declared_attr + def updated_at(cls) -> Column[DateTime]: + return Column(DateTime) + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + updated_at: Mapped[Optional[datetime.datetime]] + + """ + for dec in stmt.decorators: + if ( + isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) + and names.type_id_for_named_node(dec) is names.DECLARED_ATTR + ): + break + else: + return + + dec_index = cls.defs.body.index(stmt) + + left_hand_explicit_type: Optional[ProperType] = None + + if util.name_is_dunder(stmt.name): + # for dunder names like __table_args__, __tablename__, + # __mapper_args__ etc., rewrite these as simple assignment + # statements; otherwise mypy doesn't like if the decorated + # function has an annotation like ``cls: Type[Foo]`` because + # it isn't @classmethod + any_ = AnyType(TypeOfAny.special_form) + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + new_stmt = AssignmentStmt([left_node], TempNode(any_)) + new_stmt.type = left_node.node.type + cls.defs.body[dec_index] = new_stmt + return + elif isinstance(stmt.func.type, CallableType): + func_type = stmt.func.type.ret_type + if isinstance(func_type, UnboundType): + type_id = names.type_id_for_unbound_type(func_type, cls, api) + else: + # this does not seem to occur unless the type argument is + # incorrect + return + + if ( + type_id + in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + } + and func_type.args + ): + left_hand_explicit_type = get_proper_type(func_type.args[0]) + elif type_id is names.COLUMN and func_type.args: + typeengine_arg = func_type.args[0] + if isinstance(typeengine_arg, UnboundType): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + func_type, + ) + + if left_hand_explicit_type is None: + # no type on the decorated function. our option here is to + # dig into the function body and get the return type, but they + # should just have an annotation. + msg = ( + "Can't infer type from @declared_attr on function '{}'; " + "please specify a return type from this function that is " + "one of: Mapped[<python type>], relationship[<target class>], " + "Column[<TypeEngine>], MapperProperty[<python type>]" + ) + util.fail(api, msg.format(stmt.var.name), stmt) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + + # totally feeling around in the dark here as I don't totally understand + # the significance of UnboundType. It seems to be something that is + # not going to do what's expected when it is applied as the type of + # an AssignmentStatement. So do a feeling-around-in-the-dark version + # of converting it to the regular Instance/TypeInfo/UnionType structures + # we see everywhere else. + if isinstance(left_hand_explicit_type, UnboundType): + left_hand_explicit_type = get_proper_type( + util.unbound_to_instance(api, left_hand_explicit_type) + ) + + left_node.node.type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + + # this will ignore the rvalue entirely + # rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # rewrite the node as: + # <attr> : Mapped[<typ>] = + # _sa_Mapped._empty_constructor(lambda: <function body>) + # the function body is maintained so it gets type checked internally + rvalue = names.expr_to_mapped_constructor( + LambdaExpr(stmt.func.arguments, stmt.func.body) + ) + + new_stmt = AssignmentStmt([left_node], rvalue) + new_stmt.type = left_node.node.type + + attributes.append( + util.SQLAlchemyAttribute( + name=left_node.name, + line=stmt.line, + column=stmt.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + cls.defs.body[dec_index] = new_stmt + + +def _scan_declarative_assignment_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from an assignment statement in a + declarative class. + + """ + lvalue = stmt.lvalues[0] + if not isinstance(lvalue, NameExpr): + return + + sym = cls.info.names.get(lvalue.name) + + # this establishes that semantic analysis has taken place, which + # means the nodes are populated and we are called from an appropriate + # hook. + assert sym is not None + node = sym.node + + if isinstance(node, PlaceholderNode): + return + + assert node is lvalue.node + assert isinstance(node, Var) + + if node.name == "__abstract__": + if api.parse_bool(stmt.rvalue) is True: + util.set_is_base(cls.info) + return + elif node.name == "__tablename__": + util.set_has_table(cls.info) + elif node.name.startswith("__"): + return + elif node.name == "_mypy_mapped_attrs": + if not isinstance(stmt.rvalue, ListExpr): + util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt) + else: + for item in stmt.rvalue.items: + if isinstance(item, (NameExpr, StrExpr)): + apply.apply_mypy_mapped_attr(cls, api, item, attributes) + + left_hand_mapped_type: Optional[Type] = None + left_hand_explicit_type: Optional[ProperType] = None + + if node.is_inferred or node.type is None: + if isinstance(stmt.type, UnboundType): + # look for an explicit Mapped[] type annotation on the left + # side with nothing on the right + + # print(stmt.type) + # Mapped?[Optional?[A?]] + + left_hand_explicit_type = stmt.type + + if stmt.type.name == "Mapped": + mapped_sym = api.lookup_qualified("Mapped", cls) + if ( + mapped_sym is not None + and mapped_sym.node is not None + and names.type_id_for_named_node(mapped_sym.node) + is names.MAPPED + ): + left_hand_explicit_type = get_proper_type( + stmt.type.args[0] + ) + left_hand_mapped_type = stmt.type + + # TODO: do we need to convert from unbound for this case? + # left_hand_explicit_type = util._unbound_to_instance( + # api, left_hand_explicit_type + # ) + else: + node_type = get_proper_type(node.type) + if ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.MAPPED + ): + # print(node.type) + # sqlalchemy.orm.attributes.Mapped[<python type>] + left_hand_explicit_type = get_proper_type(node_type.args[0]) + left_hand_mapped_type = node_type + else: + # print(node.type) + # <python type> + left_hand_explicit_type = node_type + left_hand_mapped_type = None + + if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None: + # annotation without assignment and Mapped is present + # as type annotation + # equivalent to using _infer_type_from_left_hand_type_only. + + python_type_for_type = left_hand_explicit_type + elif isinstance(stmt.rvalue, CallExpr) and isinstance( + stmt.rvalue.callee, RefExpr + ): + python_type_for_type = infer.infer_type_from_right_hand_nameexpr( + api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee + ) + + if python_type_for_type is None: + return + + else: + return + + assert python_type_for_type is not None + + attributes.append( + util.SQLAlchemyAttribute( + name=node.name, + line=stmt.line, + column=stmt.column, + typ=python_type_for_type, + info=cls.info, + ) + ) + + apply.apply_type_to_mapped_statement( + api, + stmt, + lvalue, + left_hand_explicit_type, + python_type_for_type, + ) + + +def _scan_for_mapped_bases( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, +) -> None: + """Given a class, iterate through its superclass hierarchy to find + all other classes that are considered as ORM-significant. + + Locates non-mapped mixins and scans them for mapped attributes to be + applied to subclasses. + + """ + + info = util.info_for_cls(cls, api) + + if info is None: + return + + for base_info in info.mro[1:-1]: + if base_info.fullname.startswith("builtins"): + continue + + # scan each base for mapped attributes. if they are not already + # scanned (but have all their type info), that means they are unmapped + # mixins + scan_declarative_assignments_and_apply_types( + base_info.defn, api, is_mixin_scan=True + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py new file mode 100644 index 0000000..09b3c44 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py @@ -0,0 +1,590 @@ +# ext/mypy/infer.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from mypy.maptype import map_instance_to_supertype +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import LambdaExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.subtypes import is_subtype +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnionType + +from . import names +from . import util + + +def infer_type_from_right_hand_nameexpr( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + type_id = names.type_id_for_callee(infer_from_right_side) + if type_id is None: + return None + elif type_id is names.MAPPED: + python_type_for_type = _infer_type_from_mapped( + api, stmt, node, left_hand_explicit_type, infer_from_right_side + ) + elif type_id is names.COLUMN: + python_type_for_type = _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.RELATIONSHIP: + python_type_for_type = _infer_type_from_relationship( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.COLUMN_PROPERTY: + python_type_for_type = _infer_type_from_decl_column_property( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.SYNONYM_PROPERTY: + python_type_for_type = infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif type_id is names.COMPOSITE_PROPERTY: + python_type_for_type = _infer_type_from_decl_composite_property( + api, stmt, node, left_hand_explicit_type + ) + else: + return None + + return python_type_for_type + + +def _infer_type_from_relationship( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a relationship. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + addresses = relationship(Address, uselist=True) + + order: Mapped["Order"] = relationship("Order") + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + addresses: Mapped[List[Address]] + + order: Mapped["Order"] + + """ + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type: Optional[ProperType] = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + # type + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + + # other cases not covered - an error message directs the user + # to set an explicit type annotation + # + # node.type == str, it's a string + # if isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, Var + # ) + # points to a type + # isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, TypeAlias + # ) + # string expression + # isinstance(target_cls_arg, StrExpr) + + uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist") + collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg( + stmt.rvalue, "collection_class" + ) + type_is_a_collection = False + + # this can be used to determine Optional for a many-to-one + # in the same way nullable=False could be used, if we start supporting + # that. + # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin") + + if ( + uselist_arg is not None + and api.parse_bool(uselist_arg) is True + and collection_cls_arg is None + ): + type_is_a_collection = True + if python_type_for_type is not None: + python_type_for_type = api.named_type( + names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type] + ) + elif ( + uselist_arg is None or api.parse_bool(uselist_arg) is True + ) and collection_cls_arg is not None: + type_is_a_collection = True + if isinstance(collection_cls_arg, CallExpr): + collection_cls_arg = collection_cls_arg.callee + + if isinstance(collection_cls_arg, NameExpr) and isinstance( + collection_cls_arg.node, TypeInfo + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + python_type_for_type = Instance( + collection_cls_arg.node, [python_type_for_type] + ) + elif ( + isinstance(collection_cls_arg, NameExpr) + and isinstance(collection_cls_arg.node, FuncDef) + and collection_cls_arg.node.type is not None + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + + # TODO: handle mypy.types.Overloaded + if isinstance(collection_cls_arg.node.type, CallableType): + rt = get_proper_type(collection_cls_arg.node.type.ret_type) + + if isinstance(rt, CallableType): + callable_ret_type = get_proper_type(rt.ret_type) + if isinstance(callable_ret_type, Instance): + python_type_for_type = Instance( + callable_ret_type.type, + [python_type_for_type], + ) + else: + util.fail( + api, + "Expected Python collection type for " + "collection_class parameter", + stmt.rvalue, + ) + python_type_for_type = None + elif uselist_arg is not None and api.parse_bool(uselist_arg) is False: + if collection_cls_arg is not None: + util.fail( + api, + "Sending uselist=False and collection_class at the same time " + "does not make sense", + stmt.rvalue, + ) + if python_type_for_type is not None: + python_type_for_type = UnionType( + [python_type_for_type, NoneType()] + ) + + else: + if left_hand_explicit_type is None: + msg = ( + "Can't infer scalar or collection for ORM mapped expression " + "assigned to attribute '{}' if both 'uselist' and " + "'collection_class' arguments are absent from the " + "relationship(); please specify a " + "type annotation on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + if type_is_a_collection: + assert isinstance(left_hand_explicit_type, Instance) + assert isinstance(python_type_for_type, Instance) + return _infer_collection_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_explicit_type, + python_type_for_type, + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_composite_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a Composite.""" + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + else: + python_type_for_type = None + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return python_type_for_type + + +def _infer_type_from_mapped( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + """Infer the type of mapping from a right side expression + that returns Mapped. + + + """ + assert isinstance(stmt.rvalue, CallExpr) + + # (Pdb) print(stmt.rvalue.callee) + # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501 + # (Pdb) stmt.rvalue.callee.node + # <mypy.nodes.FuncDef object at 0x7f8d92fb5940> + # (Pdb) stmt.rvalue.callee.node.type + # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501 + # sqlalchemy.orm.base.Mapped[_T`-1] + # the_mapped_type = stmt.rvalue.callee.node.type.ret_type + + # TODO: look at generic ref and either use that, + # or reconcile w/ what's present, etc. + the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a ColumnProperty. + + This includes mappings against ``column_property()`` as well as the + ``deferred()`` function. + + """ + assert isinstance(stmt.rvalue, CallExpr) + + if stmt.rvalue.args: + first_prop_arg = stmt.rvalue.args[0] + + if isinstance(first_prop_arg, CallExpr): + type_id = names.type_id_for_callee(first_prop_arg.callee) + + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + right_hand_expression=first_prop_arg, + ) + + if isinstance(stmt.rvalue, CallExpr): + type_id = names.type_id_for_callee(stmt.rvalue.callee) + # this is probably not strictly necessary as we have to use the left + # hand type for query expression in any case. any other no-arg + # column prop objects would go here also + if type_id is names.QUERY_EXPRESSION: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + ) + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + right_hand_expression: Optional[CallExpr] = None, +) -> Optional[ProperType]: + """Infer the type of mapping from a Column. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + a = Column(Integer) + + b = Column("b", String) + + c: Mapped[int] = Column(Integer) + + d: bool = Column(Boolean) + + Will resolve in MyPy as:: + + @reg.mapped + class MyClass: + # ... + + a : Mapped[int] + + b : Mapped[str] + + c: Mapped[int] + + d: Mapped[bool] + + """ + assert isinstance(node, Var) + + callee = None + + if right_hand_expression is None: + if not isinstance(stmt.rvalue, CallExpr): + return None + + right_hand_expression = stmt.rvalue + + for column_arg in right_hand_expression.args[0:2]: + if isinstance(column_arg, CallExpr): + if isinstance(column_arg.callee, RefExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args: Sequence[Expression] = column_arg.args + break + elif isinstance(column_arg, (NameExpr, MemberExpr)): + if isinstance(column_arg.node, TypeInfo): + # x = Column(String) + callee = column_arg + type_args = () + break + else: + # x = Column(some_name, String), go to next argument + continue + elif isinstance(column_arg, (StrExpr,)): + # x = Column("name", String), go to next argument + continue + elif isinstance(column_arg, (LambdaExpr,)): + # x = Column("name", String, default=lambda: uuid.uuid4()) + # go to next argument + continue + else: + assert False + + if callee is None: + return None + + if isinstance(callee.node, TypeInfo) and names.mro_has_id( + callee.node.mro, names.TYPEENGINE + ): + python_type_for_type = extract_python_type_from_typeengine( + api, callee.node, type_args + ) + + if left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + + else: + return UnionType([python_type_for_type, NoneType()]) + else: + # it's not TypeEngine, it's typically implicitly typed + # like ForeignKey. we can't infer from the right side. + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: ProperType, + python_type_for_type: ProperType, + orig_left_hand_type: Optional[ProperType] = None, + orig_python_type_for_type: Optional[ProperType] = None, +) -> Optional[ProperType]: + """Validate type when a left hand annotation is present and we also + could infer the right hand side:: + + attrname: SomeType = Column(SomeDBType) + + """ + + if orig_left_hand_type is None: + orig_left_hand_type = left_hand_explicit_type + if orig_python_type_for_type is None: + orig_python_type_for_type = python_type_for_type + + if not is_subtype(left_hand_explicit_type, python_type_for_type): + effective_type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type] + ) + + msg = ( + "Left hand assignment '{}: {}' not compatible " + "with ORM mapped expression of type {}" + ) + util.fail( + api, + msg.format( + node.name, + util.format_type(orig_left_hand_type, api.options), + util.format_type(effective_type, api.options), + ), + node, + ) + + return orig_left_hand_type + + +def _infer_collection_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Instance, + python_type_for_type: Instance, +) -> Optional[ProperType]: + orig_left_hand_type = left_hand_explicit_type + orig_python_type_for_type = python_type_for_type + + if left_hand_explicit_type.args: + left_hand_arg = get_proper_type(left_hand_explicit_type.args[0]) + python_type_arg = get_proper_type(python_type_for_type.args[0]) + else: + left_hand_arg = left_hand_explicit_type + python_type_arg = python_type_for_type + + assert isinstance(left_hand_arg, (Instance, UnionType)) + assert isinstance(python_type_arg, (Instance, UnionType)) + + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_arg, + python_type_arg, + orig_left_hand_type=orig_left_hand_type, + orig_python_type_for_type=orig_python_type_for_type, + ) + + +def infer_type_from_left_hand_type_only( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Determine the type based on explicit annotation only. + + if no annotation were present, note that we need one there to know + the type. + + """ + if left_hand_explicit_type is None: + msg = ( + "Can't infer type from ORM mapped expression " + "assigned to attribute '{}'; please specify a " + "Python type or " + "Mapped[<python type>] on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + return api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)] + ) + + else: + # use type from the left hand side + return left_hand_explicit_type + + +def extract_python_type_from_typeengine( + api: SemanticAnalyzerPluginInterface, + node: TypeInfo, + type_args: Sequence[Expression], +) -> ProperType: + if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: + first_arg = type_args[0] + if isinstance(first_arg, RefExpr) and isinstance( + first_arg.node, TypeInfo + ): + for base_ in first_arg.node.mro: + if base_.fullname == "enum.Enum": + return Instance(first_arg.node, []) + # TODO: support other pep-435 types here + else: + return api.named_type(names.NAMED_TYPE_BUILTINS_STR, []) + + assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), ( + "could not extract Python type from node: %s" % node + ) + + type_engine_sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.sql.type_api.TypeEngine" + ) + + assert type_engine_sym is not None and isinstance( + type_engine_sym.node, TypeInfo + ) + type_engine = map_instance_to_supertype( + Instance(node, []), + type_engine_sym.node, + ) + return get_proper_type(type_engine.args[-1]) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py new file mode 100644 index 0000000..fc3d708 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py @@ -0,0 +1,335 @@ +# ext/mypy/names.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Decorator +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef +from mypy.nodes import SymbolNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import UnboundType + +from ... import util + +COLUMN: int = util.symbol("COLUMN") +RELATIONSHIP: int = util.symbol("RELATIONSHIP") +REGISTRY: int = util.symbol("REGISTRY") +COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") +TYPEENGINE: int = util.symbol("TYPEENGNE") +MAPPED: int = util.symbol("MAPPED") +DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") +DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") +MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") +SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") +COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") +DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") +MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") +AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") +AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") +DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") +QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") + +# names that must succeed with mypy.api.named_type +NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" +NAMED_TYPE_BUILTINS_STR = "builtins.str" +NAMED_TYPE_BUILTINS_LIST = "builtins.list" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" + +_RelFullNames = { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.relationships._RelationshipDeclared", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", +} + +_lookup: Dict[str, Tuple[int, Set[str]]] = { + "Column": ( + COLUMN, + { + "sqlalchemy.sql.schema.Column", + "sqlalchemy.sql.Column", + }, + ), + "Relationship": (RELATIONSHIP, _RelFullNames), + "RelationshipProperty": (RELATIONSHIP, _RelFullNames), + "_RelationshipDeclared": (RELATIONSHIP, _RelFullNames), + "registry": ( + REGISTRY, + { + "sqlalchemy.orm.decl_api.registry", + "sqlalchemy.orm.registry", + }, + ), + "ColumnProperty": ( + COLUMN_PROPERTY, + { + "sqlalchemy.orm.properties.MappedSQLExpression", + "sqlalchemy.orm.MappedSQLExpression", + "sqlalchemy.orm.properties.ColumnProperty", + "sqlalchemy.orm.ColumnProperty", + }, + ), + "MappedSQLExpression": ( + COLUMN_PROPERTY, + { + "sqlalchemy.orm.properties.MappedSQLExpression", + "sqlalchemy.orm.MappedSQLExpression", + "sqlalchemy.orm.properties.ColumnProperty", + "sqlalchemy.orm.ColumnProperty", + }, + ), + "Synonym": ( + SYNONYM_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", + "sqlalchemy.orm.descriptor_props.SynonymProperty", + "sqlalchemy.orm.SynonymProperty", + }, + ), + "SynonymProperty": ( + SYNONYM_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", + "sqlalchemy.orm.descriptor_props.SynonymProperty", + "sqlalchemy.orm.SynonymProperty", + }, + ), + "Composite": ( + COMPOSITE_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", + "sqlalchemy.orm.descriptor_props.CompositeProperty", + "sqlalchemy.orm.CompositeProperty", + }, + ), + "CompositeProperty": ( + COMPOSITE_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", + "sqlalchemy.orm.descriptor_props.CompositeProperty", + "sqlalchemy.orm.CompositeProperty", + }, + ), + "MapperProperty": ( + MAPPER_PROPERTY, + { + "sqlalchemy.orm.interfaces.MapperProperty", + "sqlalchemy.orm.MapperProperty", + }, + ), + "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}), + "Mapped": (MAPPED, {NAMED_TYPE_SQLA_MAPPED}), + "declarative_base": ( + DECLARATIVE_BASE, + { + "sqlalchemy.ext.declarative.declarative_base", + "sqlalchemy.orm.declarative_base", + "sqlalchemy.orm.decl_api.declarative_base", + }, + ), + "DeclarativeMeta": ( + DECLARATIVE_META, + { + "sqlalchemy.ext.declarative.DeclarativeMeta", + "sqlalchemy.orm.DeclarativeMeta", + "sqlalchemy.orm.decl_api.DeclarativeMeta", + }, + ), + "mapped": ( + MAPPED_DECORATOR, + { + "sqlalchemy.orm.decl_api.registry.mapped", + "sqlalchemy.orm.registry.mapped", + }, + ), + "as_declarative": ( + AS_DECLARATIVE, + { + "sqlalchemy.ext.declarative.as_declarative", + "sqlalchemy.orm.decl_api.as_declarative", + "sqlalchemy.orm.as_declarative", + }, + ), + "as_declarative_base": ( + AS_DECLARATIVE_BASE, + { + "sqlalchemy.orm.decl_api.registry.as_declarative_base", + "sqlalchemy.orm.registry.as_declarative_base", + }, + ), + "declared_attr": ( + DECLARED_ATTR, + { + "sqlalchemy.orm.decl_api.declared_attr", + "sqlalchemy.orm.declared_attr", + }, + ), + "declarative_mixin": ( + DECLARATIVE_MIXIN, + { + "sqlalchemy.orm.decl_api.declarative_mixin", + "sqlalchemy.orm.declarative_mixin", + }, + ), + "query_expression": ( + QUERY_EXPRESSION, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, + ), +} + + +def has_base_type_id(info: TypeInfo, type_id: int) -> bool: + for mr in info.mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: + for mr in mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def type_id_for_unbound_type( + type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[int]: + sym = api.lookup_qualified(type_.name, type_) + if sym is not None: + if isinstance(sym.node, TypeAlias): + target_type = get_proper_type(sym.node.target) + if isinstance(target_type, Instance): + return type_id_for_named_node(target_type.type) + elif isinstance(sym.node, TypeInfo): + return type_id_for_named_node(sym.node) + + return None + + +def type_id_for_callee(callee: Expression) -> Optional[int]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, Decorator) and isinstance( + callee.node.func, FuncDef + ): + if callee.node.func.type and isinstance( + callee.node.func.type, CallableType + ): + ret_type = get_proper_type(callee.node.func.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + + elif isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return type_id_for_fullname(target_type.type.fullname) + elif isinstance(callee.node, TypeInfo): + return type_id_for_named_node(callee) + return None + + +def type_id_for_named_node( + node: Union[NameExpr, MemberExpr, SymbolNode] +) -> Optional[int]: + type_id, fullnames = _lookup.get(node.name, (None, None)) + + if type_id is None or fullnames is None: + return None + elif node.fullname in fullnames: + return type_id + else: + return None + + +def type_id_for_fullname(fullname: str) -> Optional[int]: + tokens = fullname.split(".") + immediate = tokens[-1] + + type_id, fullnames = _lookup.get(immediate, (None, None)) + + if type_id is None or fullnames is None: + return None + elif fullname in fullnames: + return type_id + else: + return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py new file mode 100644 index 0000000..00eb4d1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py @@ -0,0 +1,303 @@ +# ext/mypy/plugin.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +Mypy plugin for SQLAlchemy ORM. + +""" +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type as TypingType +from typing import Union + +from mypy import nodes +from mypy.mro import calculate_mro +from mypy.mro import MroError +from mypy.nodes import Block +from mypy.nodes import ClassDef +from mypy.nodes import GDEF +from mypy.nodes import MypyFile +from mypy.nodes import NameExpr +from mypy.nodes import SymbolTable +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo +from mypy.plugin import AttributeContext +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import Plugin +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import Type + +from . import decl_class +from . import names +from . import util + +try: + __import__("sqlalchemy-stubs") +except ImportError: + pass +else: + raise ImportError( + "The SQLAlchemy mypy plugin in SQLAlchemy " + "2.0 does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed, as well as with any other third party " + "SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs " + "packages." + ) + + +class SQLAlchemyPlugin(Plugin): + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: + return _dynamic_class_hook + return None + + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return _fill_in_decorators + + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if sym is not None and sym.node is not None: + type_id = names.type_id_for_named_node(sym.node) + if type_id is names.MAPPED_DECORATOR: + return _cls_decorator_hook + elif type_id in ( + names.AS_DECLARATIVE, + names.AS_DECLARATIVE_BASE, + ): + return _base_cls_decorator_hook + elif type_id is names.DECLARATIVE_MIXIN: + return _declarative_mixin_hook + + return None + + def get_metaclass_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META: + # Set any classes that explicitly have metaclass=DeclarativeMeta + # as declarative so the check in `get_base_class_hook()` works + return _metaclass_cls_hook + + return None + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if ( + sym + and isinstance(sym.node, TypeInfo) + and util.has_declarative_base(sym.node) + ): + return _base_cls_hook + + return None + + def get_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + if fullname.startswith( + "sqlalchemy.orm.attributes.QueryableAttribute." + ): + return _queryable_getattr_hook + + return None + + def get_additional_deps( + self, file: MypyFile + ) -> List[Tuple[int, str, int]]: + return [ + # + (10, "sqlalchemy.orm", -1), + (10, "sqlalchemy.orm.attributes", -1), + (10, "sqlalchemy.orm.decl_api", -1), + ] + + +def plugin(version: str) -> TypingType[SQLAlchemyPlugin]: + return SQLAlchemyPlugin + + +def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: + """Generate a declarative Base class when the declarative_base() function + is encountered.""" + + _add_globals(ctx) + + cls = ClassDef(ctx.name, Block([])) + cls.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id) + cls.info = info + _set_declarative_metaclass(ctx.api, cls) + + cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) + if cls_arg is not None and isinstance(cls_arg.node, TypeInfo): + util.set_is_base(cls_arg.node) + decl_class.scan_declarative_assignments_and_apply_types( + cls_arg.node.defn, ctx.api, is_mixin_scan=True + ) + info.bases = [Instance(cls_arg.node, [])] + else: + obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT) + + info.bases = [obj] + + try: + calculate_mro(info) + except MroError: + util.fail( + ctx.api, "Not able to calculate MRO for declarative base", ctx.call + ) + obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT) + info.bases = [obj] + info.fallback_to_any = True + + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + util.set_is_base(info) + + +def _fill_in_decorators(ctx: ClassDefContext) -> None: + for decorator in ctx.cls.decorators: + # set the ".fullname" attribute of a class decorator + # that is a MemberExpr. This causes the logic in + # semanal.py->apply_class_plugin_hooks to invoke the + # get_class_decorator_hook for our "registry.map_class()" + # and "registry.as_declarative_base()" methods. + # this seems like a bug in mypy that these decorators are otherwise + # skipped. + + if ( + isinstance(decorator, nodes.CallExpr) + and isinstance(decorator.callee, nodes.MemberExpr) + and decorator.callee.name == "as_declarative_base" + ): + target = decorator.callee + elif ( + isinstance(decorator, nodes.MemberExpr) + and decorator.name == "mapped" + ): + target = decorator + else: + continue + + if isinstance(target.expr, NameExpr): + sym = ctx.api.lookup_qualified( + target.expr.name, target, suppress_errors=True + ) + else: + continue + + if sym and sym.node: + sym_type = get_proper_type(sym.type) + if isinstance(sym_type, Instance): + target.fullname = f"{sym_type.type.fullname}.{target.name}" + else: + # if the registry is in the same file as where the + # decorator is used, it might not have semantic + # symbols applied and we can't get a fully qualified + # name or an inferred type, so we are actually going to + # flag an error in this case that they need to annotate + # it. The "registry" is declared just + # once (or few times), so they have to just not use + # type inference for its assignment in this one case. + util.fail( + ctx.api, + "Class decorator called %s(), but we can't " + "tell if it's from an ORM registry. Please " + "annotate the registry assignment, e.g. " + "my_registry: registry = registry()" % target.name, + sym.node, + ) + + +def _cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + assert isinstance(ctx.reason, nodes.MemberExpr) + expr = ctx.reason.expr + + assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var) + + node_type = get_proper_type(expr.node.type) + + assert ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.REGISTRY + ) + + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + + cls = ctx.cls + + _set_declarative_metaclass(ctx.api, cls) + + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + cls, ctx.api, is_mixin_scan=True + ) + + +def _declarative_mixin_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + ctx.cls, ctx.api, is_mixin_scan=True + ) + + +def _metaclass_cls_hook(ctx: ClassDefContext) -> None: + util.set_is_base(ctx.cls.info) + + +def _base_cls_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _queryable_getattr_hook(ctx: AttributeContext) -> Type: + # how do I....tell it it has no attribute of a certain name? + # can't find any Type that seems to match that + return ctx.default_attr_type + + +def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: + """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space + for all class defs + + """ + + util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped") + + +def _set_declarative_metaclass( + api: SemanticAnalyzerPluginInterface, target_cls: ClassDef +) -> None: + info = target_cls.info + sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.orm.decl_api.DeclarativeMeta" + ) + assert sym is not None and isinstance(sym.node, TypeInfo) + info.declared_metaclass = info.metaclass_type = Instance(sym.node, []) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py new file mode 100644 index 0000000..7f04c48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py @@ -0,0 +1,338 @@ +# ext/mypy/util.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import re +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type as TypingType +from typing import TypeVar +from typing import Union + +from mypy import version +from mypy.messages import format_type as _mypy_format_type +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import CLASSDEF_NO_INFO +from mypy.nodes import Context +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import IfStmt +from mypy.nodes import JsonDict +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import Statement +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.options import Options +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import deserialize_and_fixup_type +from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import Type +from mypy.types import TypeVarType +from mypy.types import UnboundType +from mypy.types import UnionType + +_vers = tuple( + [int(x) for x in version.__version__.split(".") if re.match(r"^\d+$", x)] +) +mypy_14 = _vers >= (1, 4) + + +_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) + + +class SQLAlchemyAttribute: + def __init__( + self, + name: str, + line: int, + column: int, + typ: Optional[Type], + info: TypeInfo, + ) -> None: + self.name = name + self.line = line + self.column = column + self.type = typ + self.info = info + + def serialize(self) -> JsonDict: + assert self.type + return { + "name": self.name, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), + } + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is + inherited from a generic super type. + """ + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) + + @classmethod + def deserialize( + cls, + info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface, + ) -> SQLAlchemyAttribute: + data = data.copy() + typ = deserialize_and_fixup_type(data.pop("type"), api) + return cls(typ=typ, info=info, **data) + + +def name_is_dunder(name: str) -> bool: + return bool(re.match(r"^__.+?__$", name)) + + +def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: + info.metadata.setdefault("sqlalchemy", {})[key] = data + + +def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]: + return info.metadata.get("sqlalchemy", {}).get(key, None) + + +def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]: + if info.mro: + for base in info.mro: + metadata = _get_info_metadata(base, key) + if metadata is not None: + return metadata + return None + + +def establish_as_sqlalchemy(info: TypeInfo) -> None: + info.metadata.setdefault("sqlalchemy", {}) + + +def set_is_base(info: TypeInfo) -> None: + _set_info_metadata(info, "is_base", True) + + +def get_is_base(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "is_base") + return is_base is True + + +def has_declarative_base(info: TypeInfo) -> bool: + is_base = _get_info_mro_metadata(info, "is_base") + return is_base is True + + +def set_has_table(info: TypeInfo) -> None: + _set_info_metadata(info, "has_table", True) + + +def get_has_table(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "has_table") + return is_base is True + + +def get_mapped_attributes( + info: TypeInfo, api: SemanticAnalyzerPluginInterface +) -> Optional[List[SQLAlchemyAttribute]]: + mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata( + info, "mapped_attributes" + ) + if mapped_attributes is None: + return None + + attributes: List[SQLAlchemyAttribute] = [] + + for data in mapped_attributes: + attr = SQLAlchemyAttribute.deserialize(info, data, api) + attr.expand_typevar_from_subtype(info) + attributes.append(attr) + + return attributes + + +def format_type(typ_: Type, options: Options) -> str: + if mypy_14: + return _mypy_format_type(typ_, options) + else: + return _mypy_format_type(typ_) # type: ignore + + +def set_mapped_attributes( + info: TypeInfo, attributes: List[SQLAlchemyAttribute] +) -> None: + _set_info_metadata( + info, + "mapped_attributes", + [attribute.serialize() for attribute in attributes], + ) + + +def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: + msg = "[SQLAlchemy Mypy plugin] %s" % msg + return api.fail(msg, ctx) + + +def add_global( + ctx: Union[ClassDefContext, DynamicClassDefContext], + module: str, + symbol_name: str, + asname: str, +) -> None: + module_globals = ctx.api.modules[ctx.api.cur_mod_id].names + + if asname not in module_globals: + lookup_sym: SymbolTableNode = ctx.api.modules[module].names[ + symbol_name + ] + + module_globals[asname] = lookup_sym + + +@overload +def get_callexpr_kwarg( + callexpr: CallExpr, name: str, *, expr_types: None = ... +) -> Optional[Union[CallExpr, NameExpr]]: ... + + +@overload +def get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Tuple[TypingType[_TArgType], ...], +) -> Optional[_TArgType]: ... + + +def get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Optional[Tuple[TypingType[Any], ...]] = None, +) -> Optional[Any]: + try: + arg_idx = callexpr.arg_names.index(name) + except ValueError: + return None + + kwarg = callexpr.args[arg_idx] + if isinstance( + kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr) + ): + return kwarg + + return None + + +def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: + for stmt in stmts: + if ( + isinstance(stmt, IfStmt) + and isinstance(stmt.expr[0], NameExpr) + and stmt.expr[0].fullname == "typing.TYPE_CHECKING" + ): + yield from stmt.body[0].body + else: + yield stmt + + +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + +def unbound_to_instance( + api: SemanticAnalyzerPluginInterface, typ: Type +) -> Type: + """Take the UnboundType that we seem to get as the ret_type from a FuncDef + and convert it into an Instance/TypeInfo kind of structure that seems + to work as the left-hand type of an AssignmentStatement. + + """ + + if not isinstance(typ, UnboundType): + return typ + + # TODO: figure out a more robust way to check this. The node is some + # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm, + # but I can't figure out how to get them to match up + if typ.name == "Optional": + # convert from "Optional?" to the more familiar + # UnionType[..., NoneType()] + return unbound_to_instance( + api, + UnionType( + [unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + + [NoneType()] + ), + ) + + node = api.lookup_qualified(typ.name, typ) + + if ( + node is not None + and isinstance(node, SymbolTableNode) + and isinstance(node.node, TypeInfo) + ): + bound_type = node.node + + return Instance( + bound_type, + [ + ( + unbound_to_instance(api, arg) + if isinstance(arg, UnboundType) + else arg + ) + for arg in typ.args + ], + ) + else: + return typ + + +def info_for_cls( + cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[TypeInfo]: + if cls.info is CLASSDEF_NO_INFO: + sym = api.lookup_qualified(cls.name, cls) + if sym is None: + return None + assert sym and isinstance(sym.node, TypeInfo) + return sym.node + + return cls.info diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py new file mode 100644 index 0000000..1a12cf3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py @@ -0,0 +1,416 @@ +# ext/orderinglist.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""A custom list that manages index/position information for contained +elements. + +:author: Jason Kirtland + +``orderinglist`` is a helper for mutable ordered relationships. It will +intercept list operations performed on a :func:`_orm.relationship`-managed +collection and +automatically synchronize changes in list position onto a target scalar +attribute. + +Example: A ``slide`` table, where each row refers to zero or more entries +in a related ``bullet`` table. The bullets within a slide are +displayed in order based on the value of the ``position`` column in the +``bullet`` table. As entries are reordered in memory, the value of the +``position`` attribute should be updated to reflect the new sort order:: + + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position") + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +The standard relationship mapping will produce a list-like attribute on each +``Slide`` containing all related ``Bullet`` objects, +but coping with changes in ordering is not handled automatically. +When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position`` +attribute will remain unset until manually assigned. When the ``Bullet`` +is inserted into the middle of the list, the following ``Bullet`` objects +will also need to be renumbered. + +The :class:`.OrderingList` object automates this task, managing the +``position`` attribute on all ``Bullet`` objects in the collection. It is +constructed using the :func:`.ordering_list` factory:: + + from sqlalchemy.ext.orderinglist import ordering_list + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +With the above mapping the ``Bullet.position`` attribute is managed:: + + s = Slide() + s.bullets.append(Bullet()) + s.bullets.append(Bullet()) + s.bullets[1].position + >>> 1 + s.bullets.insert(1, Bullet()) + s.bullets[2].position + >>> 2 + +The :class:`.OrderingList` construct only works with **changes** to a +collection, and not the initial load from the database, and requires that the +list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the +:func:`_orm.relationship` against the target ordering attribute, so that the +ordering is correct when first loaded. + +.. warning:: + + :class:`.OrderingList` only provides limited functionality when a primary + key column or unique column is the target of the sort. Operations + that are unsupported or are problematic include: + + * two entries must trade values. This is not supported directly in the + case of a primary key or unique constraint because it means at least + one row would need to be temporarily removed first, or changed to + a third, neutral value while the switch occurs. + + * an entry must be deleted in order to make room for a new entry. + SQLAlchemy's unit of work performs all INSERTs before DELETEs within a + single flush. In the case of a primary key, it will trade + an INSERT/DELETE of the same primary key for an UPDATE statement in order + to lessen the impact of this limitation, however this does not take place + for a UNIQUE column. + A future feature will allow the "DELETE before INSERT" behavior to be + possible, alleviating this limitation, though this feature will require + explicit configuration at the mapper level for sets of columns that + are to be handled in this way. + +:func:`.ordering_list` takes the name of the related object's ordering +attribute as an argument. By default, the zero-based integer index of the +object's position in the :func:`.ordering_list` is synchronized with the +ordering attribute: index 0 will get position 0, index 1 position 1, etc. To +start numbering at 1 or some other integer, provide ``count_from=1``. + + +""" +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + +from ..orm.collections import collection +from ..orm.collections import collection_adapter + +_T = TypeVar("_T") +OrderingFunc = Callable[[int, Sequence[_T]], int] + + +__all__ = ["ordering_list"] + + +def ordering_list( + attr: str, + count_from: Optional[int] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, +) -> Callable[[], OrderingList]: + """Prepares an :class:`OrderingList` factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper + relationship's ``collection_class`` option. e.g.:: + + from sqlalchemy.ext.orderinglist import ordering_list + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + :param attr: + Name of the mapped attribute to use for storage and retrieval of + ordering information + + :param count_from: + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. + + Additional arguments are passed to the :class:`.OrderingList` constructor. + + """ + + kw = _unsugar_count_from( + count_from=count_from, + ordering_func=ordering_func, + reorder_on_append=reorder_on_append, + ) + return lambda: OrderingList(attr, **kw) + + +# Ordering utility functions + + +def count_from_0(index, collection): + """Numbering function: consecutive integers starting at 0.""" + + return index + + +def count_from_1(index, collection): + """Numbering function: consecutive integers starting at 1.""" + + return index + 1 + + +def count_from_n_factory(start): + """Numbering function: consecutive integers starting at arbitrary start.""" + + def f(index, collection): + return index + start + + try: + f.__name__ = "count_from_%i" % start + except TypeError: + pass + return f + + +def _unsugar_count_from(**kw): + """Builds counting functions from keyword arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. + """ + + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: + if count_from == 0: + kw["ordering_func"] = count_from_0 + elif count_from == 1: + kw["ordering_func"] = count_from_1 + else: + kw["ordering_func"] = count_from_n_factory(count_from) + return kw + + +class OrderingList(List[_T]): + """A custom list that manages position information for its children. + + The :class:`.OrderingList` object is normally set up using the + :func:`.ordering_list` factory function, used in conjunction with + the :func:`_orm.relationship` function. + + """ + + ordering_attr: str + ordering_func: OrderingFunc + reorder_on_append: bool + + def __init__( + self, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, + ): + """A custom list that manages position information for its children. + + ``OrderingList`` is a ``collection_class`` list implementation that + syncs position in a Python list with a position attribute on the + mapped objects. + + This implementation relies on the list starting in the proper order, + so be **sure** to put an ``order_by`` on your relationship. + + :param ordering_attr: + Name of the attribute that stores the object's order in the + relationship. + + :param ordering_func: Optional. A function that maps the position in + the Python list to a value to store in the + ``ordering_attr``. Values returned are usually (but need not be!) + integers. + + An ``ordering_func`` is called with two positional parameters: the + index of the element in the list, and the list itself. + + If omitted, Python list indexes are used for the attribute values. + Two basic pre-built numbering functions are provided in this module: + ``count_from_0`` and ``count_from_1``. For more exotic examples + like stepped numbering, alphabetical and Fibonacci numbering, see + the unit tests. + + :param reorder_on_append: + Default False. When appending an object with an existing (non-None) + ordering value, that value will be left untouched unless + ``reorder_on_append`` is true. This is an optimization to avoid a + variety of dangerous unexpected database writes. + + SQLAlchemy will add instances to the list via append() when your + object loads. If for some reason the result set from the database + skips a step in the ordering (say, row '1' is missing but you get + '2', '3', and '4'), reorder_on_append=True would immediately + renumber the items to '1', '2', '3'. If you have multiple sessions + making changes, any of whom happen to load this collection even in + passing, all of the sessions would try to "clean up" the numbering + in their commits, possibly causing all but one to fail with a + concurrent modification error. + + Recommend leaving this with the default of False, and just call + ``reorder()`` if you're doing ``append()`` operations with + previously ordered instances or when doing some housekeeping after + manual sql operations. + + """ + self.ordering_attr = ordering_attr + if ordering_func is None: + ordering_func = count_from_0 + self.ordering_func = ordering_func + self.reorder_on_append = reorder_on_append + + # More complex serialization schemes (multi column, e.g.) are possible by + # subclassing and reimplementing these two methods. + def _get_order_value(self, entity): + return getattr(entity, self.ordering_attr) + + def _set_order_value(self, entity, value): + setattr(entity, self.ordering_attr, value) + + def reorder(self) -> None: + """Synchronize ordering for the entire collection. + + Sweeps through the list and ensures that each object has accurate + ordering information set. + + """ + for index, entity in enumerate(self): + self._order_entity(index, entity, True) + + # As of 0.5, _reorder is no longer semi-private + _reorder = reorder + + def _order_entity(self, index, entity, reorder=True): + have = self._get_order_value(entity) + + # Don't disturb existing ordering if reorder is False + if have is not None and not reorder: + return + + should_be = self.ordering_func(index, self) + if have != should_be: + self._set_order_value(entity, should_be) + + def append(self, entity): + super().append(entity) + self._order_entity(len(self) - 1, entity, self.reorder_on_append) + + def _raw_append(self, entity): + """Append without any ordering behavior.""" + + super().append(entity) + + _raw_append = collection.adds(1)(_raw_append) + + def insert(self, index, entity): + super().insert(index, entity) + self._reorder() + + def remove(self, entity): + super().remove(entity) + + adapter = collection_adapter(self) + if adapter and adapter._referenced_by_owner: + self._reorder() + + def pop(self, index=-1): + entity = super().pop(index) + self._reorder() + return entity + + def __setitem__(self, index, entity): + if isinstance(index, slice): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in range(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super().__setitem__(index, entity) + + def __delitem__(self, index): + super().__delitem__(index) + self._reorder() + + def __setslice__(self, start, end, values): + super().__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super().__delslice__(start, end) + self._reorder() + + def __reduce__(self): + return _reconstitute, (self.__class__, self.__dict__, list(self)) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +def _reconstitute(cls, dict_, items): + """Reconstitute an :class:`.OrderingList`. + + This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for + unpickling :class:`.OrderingList` objects. + + """ + obj = cls.__new__(cls) + obj.__dict__.update(dict_) + list.extend(obj, items) + return obj diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py new file mode 100644 index 0000000..f21e997 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py @@ -0,0 +1,185 @@ +# ext/serializer.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +allowing "contextual" deserialization. + +.. legacy:: + + The serializer extension is **legacy** and should not be used for + new development. + +Any SQLAlchemy query structure, either based on sqlalchemy.sql.* +or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session +etc. which are referenced by the structure are not persisted in serialized +form, but are instead re-associated with the query structure +when it is deserialized. + +.. warning:: The serializer extension uses pickle to serialize and + deserialize objects, so the same security consideration mentioned + in the `python documentation + <https://docs.python.org/3/library/pickle.html>`_ apply. + +Usage is nearly the same as that of the standard Python pickle module:: + + from sqlalchemy.ext.serializer import loads, dumps + metadata = MetaData(bind=some_engine) + Session = scoped_session(sessionmaker()) + + # ... define mappers + + query = Session.query(MyClass). + filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + + # pickle the query + serialized = dumps(query) + + # unpickle. Pass in metadata + scoped_session + query2 = loads(serialized, metadata, Session) + + print query2.all() + +Similar restrictions as when using raw pickle apply; mapped classes must be +themselves be pickleable, meaning they are importable from a module-level +namespace. + +The serializer module is only appropriate for query structures. It is not +needed for: + +* instances of user-defined classes. These contain no references to engines, + sessions or expression constructs in the typical case and can be serialized + directly. + +* Table metadata that is to be loaded entirely from the serialized structure + (i.e. is not already declared in the application). Regular + pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object, + typically one which was reflected from an existing database at some previous + point in time. The serializer module is specifically for the opposite case, + where the Table metadata is already present in memory. + +""" + +from io import BytesIO +import pickle +import re + +from .. import Column +from .. import Table +from ..engine import Engine +from ..orm import class_mapper +from ..orm.interfaces import MapperProperty +from ..orm.mapper import Mapper +from ..orm.session import Session +from ..util import b64decode +from ..util import b64encode + + +__all__ = ["Serializer", "Deserializer", "dumps", "loads"] + + +def Serializer(*args, **kw): + pickler = pickle.Pickler(*args, **kw) + + def persistent_id(obj): + # print "serializing:", repr(obj) + if isinstance(obj, Mapper) and not obj.non_primary: + id_ = "mapper:" + b64encode(pickle.dumps(obj.class_)) + elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: + id_ = ( + "mapperprop:" + + b64encode(pickle.dumps(obj.parent.class_)) + + ":" + + obj.key + ) + elif isinstance(obj, Table): + if "parententity" in obj._annotations: + id_ = "mapper_selectable:" + b64encode( + pickle.dumps(obj._annotations["parententity"].class_) + ) + else: + id_ = f"table:{obj.key}" + elif isinstance(obj, Column) and isinstance(obj.table, Table): + id_ = f"column:{obj.table.key}:{obj.key}" + elif isinstance(obj, Session): + id_ = "session:" + elif isinstance(obj, Engine): + id_ = "engine:" + else: + return None + return id_ + + pickler.persistent_id = persistent_id + return pickler + + +our_ids = re.compile( + r"(mapperprop|mapper|mapper_selectable|table|column|" + r"session|attribute|engine):(.*)" +) + + +def Deserializer(file, metadata=None, scoped_session=None, engine=None): + unpickler = pickle.Unpickler(file) + + def get_engine(): + if engine: + return engine + elif scoped_session and scoped_session().bind: + return scoped_session().bind + elif metadata and metadata.bind: + return metadata.bind + else: + return None + + def persistent_load(id_): + m = our_ids.match(str(id_)) + if not m: + return None + else: + type_, args = m.group(1, 2) + if type_ == "attribute": + key, clsarg = args.split(":") + cls = pickle.loads(b64decode(clsarg)) + return getattr(cls, key) + elif type_ == "mapper": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls) + elif type_ == "mapper_selectable": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls).__clause_element__() + elif type_ == "mapperprop": + mapper, keyname = args.split(":") + cls = pickle.loads(b64decode(mapper)) + return class_mapper(cls).attrs[keyname] + elif type_ == "table": + return metadata.tables[args] + elif type_ == "column": + table, colname = args.split(":") + return metadata.tables[table].c[colname] + elif type_ == "session": + return scoped_session() + elif type_ == "engine": + return get_engine() + else: + raise Exception("Unknown token: %s" % type_) + + unpickler.persistent_load = persistent_load + return unpickler + + +def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL): + buf = BytesIO() + pickler = Serializer(buf, protocol) + pickler.dump(obj) + return buf.getvalue() + + +def loads(data, metadata=None, scoped_session=None, engine=None): + buf = BytesIO(data) + unpickler = Deserializer(buf, metadata, scoped_session, engine) + return unpickler.load() |