diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/sql/lambdas.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/sqlalchemy/sql/lambdas.py | 1449 |
1 files changed, 0 insertions, 1449 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/lambdas.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/lambdas.py deleted file mode 100644 index 7a6b7b8..0000000 --- a/venv/lib/python3.11/site-packages/sqlalchemy/sql/lambdas.py +++ /dev/null @@ -1,1449 +0,0 @@ -# sql/lambdas.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# <see AUTHORS file> -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls - -from __future__ import annotations - -import collections.abc as collections_abc -import inspect -import itertools -import operator -import threading -import types -from types import CodeType -from typing import Any -from typing import Callable -from typing import cast -from typing import List -from typing import MutableMapping -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import TypeVar -from typing import Union -import weakref - -from . import cache_key as _cache_key -from . import coercions -from . import elements -from . import roles -from . import schema -from . import visitors -from .base import _clone -from .base import Executable -from .base import Options -from .cache_key import CacheConst -from .operators import ColumnOperators -from .. import exc -from .. import inspection -from .. import util -from ..util.typing import Literal - - -if TYPE_CHECKING: - from .elements import BindParameter - from .elements import ClauseElement - from .roles import SQLRole - from .visitors import _CloneCallableType - -_LambdaCacheType = MutableMapping[ - Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"] -] -_BoundParameterGetter = Callable[..., Any] - -_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000) - - -_LambdaType = Callable[[], Any] - -_AnyLambdaType = Callable[..., Any] - -_StmtLambdaType = Callable[[], Any] - -_E = TypeVar("_E", bound=Executable) -_StmtLambdaElementType = Callable[[_E], Any] - - -class LambdaOptions(Options): - enable_tracking = True - track_closure_variables = True - track_on: Optional[object] = None - global_track_bound_values = True - track_bound_values = True - lambda_cache: Optional[_LambdaCacheType] = None - - -def lambda_stmt( - lmb: _StmtLambdaType, - enable_tracking: bool = True, - track_closure_variables: bool = True, - track_on: Optional[object] = None, - global_track_bound_values: bool = True, - track_bound_values: bool = True, - lambda_cache: Optional[_LambdaCacheType] = None, -) -> StatementLambdaElement: - """Produce a SQL statement that is cached as a lambda. - - The Python code object within the lambda is scanned for both Python - literals that will become bound parameters as well as closure variables - that refer to Core or ORM constructs that may vary. The lambda itself - will be invoked only once per particular set of constructs detected. - - E.g.:: - - from sqlalchemy import lambda_stmt - - stmt = lambda_stmt(lambda: table.select()) - stmt += lambda s: s.where(table.c.id == 5) - - result = connection.execute(stmt) - - The object returned is an instance of :class:`_sql.StatementLambdaElement`. - - .. versionadded:: 1.4 - - :param lmb: a Python function, typically a lambda, which takes no arguments - and returns a SQL expression construct - :param enable_tracking: when False, all scanning of the given lambda for - changes in closure variables or bound parameters is disabled. Use for - a lambda that produces the identical results in all cases with no - parameterization. - :param track_closure_variables: when False, changes in closure variables - within the lambda will not be scanned. Use for a lambda where the - state of its closure variables will never change the SQL structure - returned by the lambda. - :param track_bound_values: when False, bound parameter tracking will - be disabled for the given lambda. Use for a lambda that either does - not produce any bound values, or where the initial bound values never - change. - :param global_track_bound_values: when False, bound parameter tracking - will be disabled for the entire statement including additional links - added via the :meth:`_sql.StatementLambdaElement.add_criteria` method. - :param lambda_cache: a dictionary or other mapping-like object where - information about the lambda's Python code as well as the tracked closure - variables in the lambda itself will be stored. Defaults - to a global LRU cache. This cache is independent of the "compiled_cache" - used by the :class:`_engine.Connection` object. - - .. seealso:: - - :ref:`engine_lambda_caching` - - - """ - - return StatementLambdaElement( - lmb, - roles.StatementRole, - LambdaOptions( - enable_tracking=enable_tracking, - track_on=track_on, - track_closure_variables=track_closure_variables, - global_track_bound_values=global_track_bound_values, - track_bound_values=track_bound_values, - lambda_cache=lambda_cache, - ), - ) - - -class LambdaElement(elements.ClauseElement): - """A SQL construct where the state is stored as an un-invoked lambda. - - The :class:`_sql.LambdaElement` is produced transparently whenever - passing lambda expressions into SQL constructs, such as:: - - stmt = select(table).where(lambda: table.c.col == parameter) - - The :class:`_sql.LambdaElement` is the base of the - :class:`_sql.StatementLambdaElement` which represents a full statement - within a lambda. - - .. versionadded:: 1.4 - - .. seealso:: - - :ref:`engine_lambda_caching` - - """ - - __visit_name__ = "lambda_element" - - _is_lambda_element = True - - _traverse_internals = [ - ("_resolved", visitors.InternalTraversal.dp_clauseelement) - ] - - _transforms: Tuple[_CloneCallableType, ...] = () - - _resolved_bindparams: List[BindParameter[Any]] - parent_lambda: Optional[StatementLambdaElement] = None - closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] - role: Type[SQLRole] - _rec: Union[AnalyzedFunction, NonAnalyzedFunction] - fn: _AnyLambdaType - tracker_key: Tuple[CodeType, ...] - - def __repr__(self): - return "%s(%r)" % ( - self.__class__.__name__, - self.fn.__code__, - ) - - def __init__( - self, - fn: _LambdaType, - role: Type[SQLRole], - opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, - apply_propagate_attrs: Optional[ClauseElement] = None, - ): - self.fn = fn - self.role = role - self.tracker_key = (fn.__code__,) - self.opts = opts - - if apply_propagate_attrs is None and (role is roles.StatementRole): - apply_propagate_attrs = self - - rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts) - - if apply_propagate_attrs is not None: - propagate_attrs = rec.propagate_attrs - if propagate_attrs: - apply_propagate_attrs._propagate_attrs = propagate_attrs - - def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): - lambda_cache = opts.lambda_cache - if lambda_cache is None: - lambda_cache = _closure_per_cache_key - - tracker_key = self.tracker_key - - fn = self.fn - closure = fn.__closure__ - tracker = AnalyzedCode.get( - fn, - self, - opts, - ) - - bindparams: List[BindParameter[Any]] - self._resolved_bindparams = bindparams = [] - - if self.parent_lambda is not None: - parent_closure_cache_key = self.parent_lambda.closure_cache_key - else: - parent_closure_cache_key = () - - cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] - - if parent_closure_cache_key is not _cache_key.NO_CACHE: - anon_map = visitors.anon_map() - cache_key = tuple( - [ - getter(closure, opts, anon_map, bindparams) - for getter in tracker.closure_trackers - ] - ) - - if _cache_key.NO_CACHE not in anon_map: - cache_key = parent_closure_cache_key + cache_key - - self.closure_cache_key = cache_key - - try: - rec = lambda_cache[tracker_key + cache_key] - except KeyError: - rec = None - else: - cache_key = _cache_key.NO_CACHE - rec = None - - else: - cache_key = _cache_key.NO_CACHE - rec = None - - self.closure_cache_key = cache_key - - if rec is None: - if cache_key is not _cache_key.NO_CACHE: - with AnalyzedCode._generation_mutex: - key = tracker_key + cache_key - if key not in lambda_cache: - rec = AnalyzedFunction( - tracker, self, apply_propagate_attrs, fn - ) - rec.closure_bindparams = bindparams - lambda_cache[key] = rec - else: - rec = lambda_cache[key] - else: - rec = NonAnalyzedFunction(self._invoke_user_fn(fn)) - - else: - bindparams[:] = [ - orig_bind._with_value(new_bind.value, maintain_key=True) - for orig_bind, new_bind in zip( - rec.closure_bindparams, bindparams - ) - ] - - self._rec = rec - - if cache_key is not _cache_key.NO_CACHE: - if self.parent_lambda is not None: - bindparams[:0] = self.parent_lambda._resolved_bindparams - - lambda_element: Optional[LambdaElement] = self - while lambda_element is not None: - rec = lambda_element._rec - if rec.bindparam_trackers: - tracker_instrumented_fn = rec.tracker_instrumented_fn - for tracker in rec.bindparam_trackers: - tracker( - lambda_element.fn, - tracker_instrumented_fn, - bindparams, - ) - lambda_element = lambda_element.parent_lambda - - return rec - - def __getattr__(self, key): - return getattr(self._rec.expected_expr, key) - - @property - def _is_sequence(self): - return self._rec.is_sequence - - @property - def _select_iterable(self): - if self._is_sequence: - return itertools.chain.from_iterable( - [element._select_iterable for element in self._resolved] - ) - - else: - return self._resolved._select_iterable - - @property - def _from_objects(self): - if self._is_sequence: - return itertools.chain.from_iterable( - [element._from_objects for element in self._resolved] - ) - - else: - return self._resolved._from_objects - - def _param_dict(self): - return {b.key: b.value for b in self._resolved_bindparams} - - def _setup_binds_for_tracked_expr(self, expr): - bindparam_lookup = {b.key: b for b in self._resolved_bindparams} - - def replace( - element: Optional[visitors.ExternallyTraversible], **kw: Any - ) -> Optional[visitors.ExternallyTraversible]: - if isinstance(element, elements.BindParameter): - if element.key in bindparam_lookup: - bind = bindparam_lookup[element.key] - if element.expanding: - bind.expanding = True - bind.expand_op = element.expand_op - bind.type = element.type - return bind - - return None - - if self._rec.is_sequence: - expr = [ - visitors.replacement_traverse(sub_expr, {}, replace) - for sub_expr in expr - ] - elif getattr(expr, "is_clause_element", False): - expr = visitors.replacement_traverse(expr, {}, replace) - - return expr - - def _copy_internals( - self, - clone: _CloneCallableType = _clone, - deferred_copy_internals: Optional[_CloneCallableType] = None, - **kw: Any, - ) -> None: - # TODO: this needs A LOT of tests - self._resolved = clone( - self._resolved, - deferred_copy_internals=deferred_copy_internals, - **kw, - ) - - @util.memoized_property - def _resolved(self): - expr = self._rec.expected_expr - - if self._resolved_bindparams: - expr = self._setup_binds_for_tracked_expr(expr) - - return expr - - def _gen_cache_key(self, anon_map, bindparams): - if self.closure_cache_key is _cache_key.NO_CACHE: - anon_map[_cache_key.NO_CACHE] = True - return None - - cache_key = ( - self.fn.__code__, - self.__class__, - ) + self.closure_cache_key - - parent = self.parent_lambda - - while parent is not None: - assert parent.closure_cache_key is not CacheConst.NO_CACHE - parent_closure_cache_key: Tuple[Any, ...] = ( - parent.closure_cache_key - ) - - cache_key = ( - (parent.fn.__code__,) + parent_closure_cache_key + cache_key - ) - - parent = parent.parent_lambda - - if self._resolved_bindparams: - bindparams.extend(self._resolved_bindparams) - return cache_key - - def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement: - return fn() # type: ignore[no-any-return] - - -class DeferredLambdaElement(LambdaElement): - """A LambdaElement where the lambda accepts arguments and is - invoked within the compile phase with special context. - - This lambda doesn't normally produce its real SQL expression outside of the - compile phase. It is passed a fixed set of initial arguments - so that it can generate a sample expression. - - """ - - def __init__( - self, - fn: _AnyLambdaType, - role: Type[roles.SQLRole], - opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, - lambda_args: Tuple[Any, ...] = (), - ): - self.lambda_args = lambda_args - super().__init__(fn, role, opts) - - def _invoke_user_fn(self, fn, *arg): - return fn(*self.lambda_args) - - def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement: - assert isinstance(self._rec, AnalyzedFunction) - tracker_fn = self._rec.tracker_instrumented_fn - expr = tracker_fn(*lambda_args) - - expr = coercions.expect(self.role, expr) - - expr = self._setup_binds_for_tracked_expr(expr) - - # this validation is getting very close, but not quite, to achieving - # #5767. The problem is if the base lambda uses an unnamed column - # as is very common with mixins, the parameter name is different - # and it produces a false positive; that is, for the documented case - # that is exactly what people will be doing, it doesn't work, so - # I'm not really sure how to handle this right now. - # expected_binds = [ - # b._orig_key - # for b in self._rec.expr._generate_cache_key()[1] - # if b.required - # ] - # got_binds = [ - # b._orig_key for b in expr._generate_cache_key()[1] if b.required - # ] - # if expected_binds != got_binds: - # raise exc.InvalidRequestError( - # "Lambda callable at %s produced a different set of bound " - # "parameters than its original run: %s" - # % (self.fn.__code__, ", ".join(got_binds)) - # ) - - # TODO: TEST TEST TEST, this is very out there - for deferred_copy_internals in self._transforms: - expr = deferred_copy_internals(expr) - - return expr # type: ignore - - def _copy_internals( - self, clone=_clone, deferred_copy_internals=None, **kw - ): - super()._copy_internals( - clone=clone, - deferred_copy_internals=deferred_copy_internals, # **kw - opts=kw, - ) - - # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know - # our expression yet. so hold onto the replacement - if deferred_copy_internals: - self._transforms += (deferred_copy_internals,) - - -class StatementLambdaElement( - roles.AllowsLambdaRole, LambdaElement, Executable -): - """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. - - The :class:`_sql.StatementLambdaElement` is constructed using the - :func:`_sql.lambda_stmt` function:: - - - from sqlalchemy import lambda_stmt - - stmt = lambda_stmt(lambda: select(table)) - - Once constructed, additional criteria can be built onto the statement - by adding subsequent lambdas, which accept the existing statement - object as a single parameter:: - - stmt += lambda s: s.where(table.c.col == parameter) - - - .. versionadded:: 1.4 - - .. seealso:: - - :ref:`engine_lambda_caching` - - """ - - if TYPE_CHECKING: - - def __init__( - self, - fn: _StmtLambdaType, - role: Type[SQLRole], - opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, - apply_propagate_attrs: Optional[ClauseElement] = None, - ): ... - - def __add__( - self, other: _StmtLambdaElementType[Any] - ) -> StatementLambdaElement: - return self.add_criteria(other) - - def add_criteria( - self, - other: _StmtLambdaElementType[Any], - enable_tracking: bool = True, - track_on: Optional[Any] = None, - track_closure_variables: bool = True, - track_bound_values: bool = True, - ) -> StatementLambdaElement: - """Add new criteria to this :class:`_sql.StatementLambdaElement`. - - E.g.:: - - >>> def my_stmt(parameter): - ... stmt = lambda_stmt( - ... lambda: select(table.c.x, table.c.y), - ... ) - ... stmt = stmt.add_criteria( - ... lambda: table.c.x > parameter - ... ) - ... return stmt - - The :meth:`_sql.StatementLambdaElement.add_criteria` method is - equivalent to using the Python addition operator to add a new - lambda, except that additional arguments may be added including - ``track_closure_values`` and ``track_on``:: - - >>> def my_stmt(self, foo): - ... stmt = lambda_stmt( - ... lambda: select(func.max(foo.x, foo.y)), - ... track_closure_variables=False - ... ) - ... stmt = stmt.add_criteria( - ... lambda: self.where_criteria, - ... track_on=[self] - ... ) - ... return stmt - - See :func:`_sql.lambda_stmt` for a description of the parameters - accepted. - - """ - - opts = self.opts + dict( - enable_tracking=enable_tracking, - track_closure_variables=track_closure_variables, - global_track_bound_values=self.opts.global_track_bound_values, - track_on=track_on, - track_bound_values=track_bound_values, - ) - - return LinkedLambdaElement(other, parent_lambda=self, opts=opts) - - def _execute_on_connection( - self, connection, distilled_params, execution_options - ): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, ClauseElement) - if self._rec.expected_expr.supports_execution: - return connection._execute_clauseelement( - self, distilled_params, execution_options - ) - else: - raise exc.ObjectNotExecutableError(self) - - @property - def _proxied(self) -> Any: - return self._rec_expected_expr - - @property - def _with_options(self): - return self._proxied._with_options - - @property - def _effective_plugin_target(self): - return self._proxied._effective_plugin_target - - @property - def _execution_options(self): - return self._proxied._execution_options - - @property - def _all_selected_columns(self): - return self._proxied._all_selected_columns - - @property - def is_select(self): - return self._proxied.is_select - - @property - def is_update(self): - return self._proxied.is_update - - @property - def is_insert(self): - return self._proxied.is_insert - - @property - def is_text(self): - return self._proxied.is_text - - @property - def is_delete(self): - return self._proxied.is_delete - - @property - def is_dml(self): - return self._proxied.is_dml - - def spoil(self) -> NullLambdaStatement: - """Return a new :class:`.StatementLambdaElement` that will run - all lambdas unconditionally each time. - - """ - return NullLambdaStatement(self.fn()) - - -class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): - """Provides the :class:`.StatementLambdaElement` API but does not - cache or analyze lambdas. - - the lambdas are instead invoked immediately. - - The intended use is to isolate issues that may arise when using - lambda statements. - - """ - - __visit_name__ = "lambda_element" - - _is_lambda_element = True - - _traverse_internals = [ - ("_resolved", visitors.InternalTraversal.dp_clauseelement) - ] - - def __init__(self, statement): - self._resolved = statement - self._propagate_attrs = statement._propagate_attrs - - def __getattr__(self, key): - return getattr(self._resolved, key) - - def __add__(self, other): - statement = other(self._resolved) - - return NullLambdaStatement(statement) - - def add_criteria(self, other, **kw): - statement = other(self._resolved) - - return NullLambdaStatement(statement) - - def _execute_on_connection( - self, connection, distilled_params, execution_options - ): - if self._resolved.supports_execution: - return connection._execute_clauseelement( - self, distilled_params, execution_options - ) - else: - raise exc.ObjectNotExecutableError(self) - - -class LinkedLambdaElement(StatementLambdaElement): - """Represent subsequent links of a :class:`.StatementLambdaElement`.""" - - parent_lambda: StatementLambdaElement - - def __init__( - self, - fn: _StmtLambdaElementType[Any], - parent_lambda: StatementLambdaElement, - opts: Union[Type[LambdaOptions], LambdaOptions], - ): - self.opts = opts - self.fn = fn - self.parent_lambda = parent_lambda - - self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) - self._retrieve_tracker_rec(fn, self, opts) - self._propagate_attrs = parent_lambda._propagate_attrs - - def _invoke_user_fn(self, fn, *arg): - return fn(self.parent_lambda._resolved) - - -class AnalyzedCode: - __slots__ = ( - "track_closure_variables", - "track_bound_values", - "bindparam_trackers", - "closure_trackers", - "build_py_wrappers", - ) - _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = ( - weakref.WeakKeyDictionary() - ) - - _generation_mutex = threading.RLock() - - @classmethod - def get(cls, fn, lambda_element, lambda_kw, **kw): - try: - # TODO: validate kw haven't changed? - return cls._fns[fn.__code__] - except KeyError: - pass - - with cls._generation_mutex: - # check for other thread already created object - if fn.__code__ in cls._fns: - return cls._fns[fn.__code__] - - analyzed: AnalyzedCode - cls._fns[fn.__code__] = analyzed = AnalyzedCode( - fn, lambda_element, lambda_kw, **kw - ) - return analyzed - - def __init__(self, fn, lambda_element, opts): - if inspect.ismethod(fn): - raise exc.ArgumentError( - "Method %s may not be passed as a SQL expression" % fn - ) - closure = fn.__closure__ - - self.track_bound_values = ( - opts.track_bound_values and opts.global_track_bound_values - ) - enable_tracking = opts.enable_tracking - track_on = opts.track_on - track_closure_variables = opts.track_closure_variables - - self.track_closure_variables = track_closure_variables and not track_on - - # a list of callables generated from _bound_parameter_getter_* - # functions. Each of these uses a PyWrapper object to retrieve - # a parameter value - self.bindparam_trackers = [] - - # a list of callables generated from _cache_key_getter_* functions - # these callables work to generate a cache key for the lambda - # based on what's inside its closure variables. - self.closure_trackers = [] - - self.build_py_wrappers = [] - - if enable_tracking: - if track_on: - self._init_track_on(track_on) - - self._init_globals(fn) - - if closure: - self._init_closure(fn) - - self._setup_additional_closure_trackers(fn, lambda_element, opts) - - def _init_track_on(self, track_on): - self.closure_trackers.extend( - self._cache_key_getter_track_on(idx, elem) - for idx, elem in enumerate(track_on) - ) - - def _init_globals(self, fn): - build_py_wrappers = self.build_py_wrappers - bindparam_trackers = self.bindparam_trackers - track_bound_values = self.track_bound_values - - for name in fn.__code__.co_names: - if name not in fn.__globals__: - continue - - _bound_value = self._roll_down_to_literal(fn.__globals__[name]) - - if coercions._deep_is_literal(_bound_value): - build_py_wrappers.append((name, None)) - if track_bound_values: - bindparam_trackers.append( - self._bound_parameter_getter_func_globals(name) - ) - - def _init_closure(self, fn): - build_py_wrappers = self.build_py_wrappers - closure = fn.__closure__ - - track_bound_values = self.track_bound_values - track_closure_variables = self.track_closure_variables - bindparam_trackers = self.bindparam_trackers - closure_trackers = self.closure_trackers - - for closure_index, (fv, cell) in enumerate( - zip(fn.__code__.co_freevars, closure) - ): - _bound_value = self._roll_down_to_literal(cell.cell_contents) - - if coercions._deep_is_literal(_bound_value): - build_py_wrappers.append((fv, closure_index)) - if track_bound_values: - bindparam_trackers.append( - self._bound_parameter_getter_func_closure( - fv, closure_index - ) - ) - else: - # for normal cell contents, add them to a list that - # we can compare later when we get new lambdas. if - # any identities have changed, then we will - # recalculate the whole lambda and run it again. - - if track_closure_variables: - closure_trackers.append( - self._cache_key_getter_closure_variable( - fn, fv, closure_index, cell.cell_contents - ) - ) - - def _setup_additional_closure_trackers(self, fn, lambda_element, opts): - # an additional step is to actually run the function, then - # go through the PyWrapper objects that were set up to catch a bound - # parameter. then if they *didn't* make a param, oh they're another - # object in the closure we have to track for our cache key. so - # create trackers to catch those. - - analyzed_function = AnalyzedFunction( - self, - lambda_element, - None, - fn, - ) - - closure_trackers = self.closure_trackers - - for pywrapper in analyzed_function.closure_pywrappers: - if not pywrapper._sa__has_param: - closure_trackers.append( - self._cache_key_getter_tracked_literal(fn, pywrapper) - ) - - @classmethod - def _roll_down_to_literal(cls, element): - is_clause_element = hasattr(element, "__clause_element__") - - if is_clause_element: - while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem, type) - ): - try: - element = element.__clause_element__() - except AttributeError: - break - - if not is_clause_element: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - try: - return insp.__clause_element__() - except AttributeError: - return insp - - # TODO: should we coerce consts None/True/False here? - return element - else: - return element - - def _bound_parameter_getter_func_globals(self, name): - """Return a getter that will extend a list of bound parameters - with new entries from the ``__globals__`` collection of a particular - lambda. - - """ - - def extract_parameter_value( - current_fn, tracker_instrumented_fn, result - ): - wrapper = tracker_instrumented_fn.__globals__[name] - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__globals__[name], result - ) - - return extract_parameter_value - - def _bound_parameter_getter_func_closure(self, name, closure_index): - """Return a getter that will extend a list of bound parameters - with new entries from the ``__closure__`` collection of a particular - lambda. - - """ - - def extract_parameter_value( - current_fn, tracker_instrumented_fn, result - ): - wrapper = tracker_instrumented_fn.__closure__[ - closure_index - ].cell_contents - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__closure__[closure_index].cell_contents, result - ) - - return extract_parameter_value - - def _cache_key_getter_track_on(self, idx, elem): - """Return a getter that will extend a cache key with new entries - from the "track_on" parameter passed to a :class:`.LambdaElement`. - - """ - - if isinstance(elem, tuple): - # tuple must contain hascachekey elements - def get(closure, opts, anon_map, bindparams): - return tuple( - tup_elem._gen_cache_key(anon_map, bindparams) - for tup_elem in opts.track_on[idx] - ) - - elif isinstance(elem, _cache_key.HasCacheKey): - - def get(closure, opts, anon_map, bindparams): - return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) - - else: - - def get(closure, opts, anon_map, bindparams): - return opts.track_on[idx] - - return get - - def _cache_key_getter_closure_variable( - self, - fn, - variable_name, - idx, - cell_contents, - use_clause_element=False, - use_inspect=False, - ): - """Return a getter that will extend a cache key with new entries - from the ``__closure__`` collection of a particular lambda. - - """ - - if isinstance(cell_contents, _cache_key.HasCacheKey): - - def get(closure, opts, anon_map, bindparams): - obj = closure[idx].cell_contents - if use_inspect: - obj = inspection.inspect(obj) - elif use_clause_element: - while hasattr(obj, "__clause_element__"): - if not getattr(obj, "is_clause_element", False): - obj = obj.__clause_element__() - - return obj._gen_cache_key(anon_map, bindparams) - - elif isinstance(cell_contents, types.FunctionType): - - def get(closure, opts, anon_map, bindparams): - return closure[idx].cell_contents.__code__ - - elif isinstance(cell_contents, collections_abc.Sequence): - - def get(closure, opts, anon_map, bindparams): - contents = closure[idx].cell_contents - - try: - return tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in contents - ) - except AttributeError as ae: - self._raise_for_uncacheable_closure_variable( - variable_name, fn, from_=ae - ) - - else: - # if the object is a mapped class or aliased class, or some - # other object in the ORM realm of things like that, imitate - # the logic used in coercions.expect() to roll it down to the - # SQL element - element = cell_contents - is_clause_element = False - while hasattr(element, "__clause_element__"): - is_clause_element = True - if not getattr(element, "is_clause_element", False): - element = element.__clause_element__() - else: - break - - if not is_clause_element: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - return self._cache_key_getter_closure_variable( - fn, variable_name, idx, insp, use_inspect=True - ) - else: - return self._cache_key_getter_closure_variable( - fn, variable_name, idx, element, use_clause_element=True - ) - - self._raise_for_uncacheable_closure_variable(variable_name, fn) - - return get - - def _raise_for_uncacheable_closure_variable( - self, variable_name, fn, from_=None - ): - raise exc.InvalidRequestError( - "Closure variable named '%s' inside of lambda callable %s " - "does not refer to a cacheable SQL element, and also does not " - "appear to be serving as a SQL literal bound value based on " - "the default " - "SQL expression returned by the function. This variable " - "needs to remain outside the scope of a SQL-generating lambda " - "so that a proper cache key may be generated from the " - "lambda's state. Evaluate this variable outside of the " - "lambda, set track_on=[<elements>] to explicitly select " - "closure elements to track, or set " - "track_closure_variables=False to exclude " - "closure variables from being part of the cache key." - % (variable_name, fn.__code__), - ) from from_ - - def _cache_key_getter_tracked_literal(self, fn, pytracker): - """Return a getter that will extend a cache key with new entries - from the ``__closure__`` collection of a particular lambda. - - this getter differs from _cache_key_getter_closure_variable - in that these are detected after the function is run, and PyWrapper - objects have recorded that a particular literal value is in fact - not being interpreted as a bound parameter. - - """ - - elem = pytracker._sa__to_evaluate - closure_index = pytracker._sa__closure_index - variable_name = pytracker._sa__name - - return self._cache_key_getter_closure_variable( - fn, variable_name, closure_index, elem - ) - - -class NonAnalyzedFunction: - __slots__ = ("expr",) - - closure_bindparams: Optional[List[BindParameter[Any]]] = None - bindparam_trackers: Optional[List[_BoundParameterGetter]] = None - - is_sequence = False - - expr: ClauseElement - - def __init__(self, expr: ClauseElement): - self.expr = expr - - @property - def expected_expr(self) -> ClauseElement: - return self.expr - - -class AnalyzedFunction: - __slots__ = ( - "analyzed_code", - "fn", - "closure_pywrappers", - "tracker_instrumented_fn", - "expr", - "bindparam_trackers", - "expected_expr", - "is_sequence", - "propagate_attrs", - "closure_bindparams", - ) - - closure_bindparams: Optional[List[BindParameter[Any]]] - expected_expr: Union[ClauseElement, List[ClauseElement]] - bindparam_trackers: Optional[List[_BoundParameterGetter]] - - def __init__( - self, - analyzed_code, - lambda_element, - apply_propagate_attrs, - fn, - ): - self.analyzed_code = analyzed_code - self.fn = fn - - self.bindparam_trackers = analyzed_code.bindparam_trackers - - self._instrument_and_run_function(lambda_element) - - self._coerce_expression(lambda_element, apply_propagate_attrs) - - def _instrument_and_run_function(self, lambda_element): - analyzed_code = self.analyzed_code - - fn = self.fn - self.closure_pywrappers = closure_pywrappers = [] - - build_py_wrappers = analyzed_code.build_py_wrappers - - if not build_py_wrappers: - self.tracker_instrumented_fn = tracker_instrumented_fn = fn - self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - else: - track_closure_variables = analyzed_code.track_closure_variables - closure = fn.__closure__ - - # will form the __closure__ of the function when we rebuild it - if closure: - new_closure = { - fv: cell.cell_contents - for fv, cell in zip(fn.__code__.co_freevars, closure) - } - else: - new_closure = {} - - # will form the __globals__ of the function when we rebuild it - new_globals = fn.__globals__.copy() - - for name, closure_index in build_py_wrappers: - if closure_index is not None: - value = closure[closure_index].cell_contents - new_closure[name] = bind = PyWrapper( - fn, - name, - value, - closure_index=closure_index, - track_bound_values=( - self.analyzed_code.track_bound_values - ), - ) - if track_closure_variables: - closure_pywrappers.append(bind) - else: - value = fn.__globals__[name] - new_globals[name] = bind = PyWrapper(fn, name, value) - - # rewrite the original fn. things that look like they will - # become bound parameters are wrapped in a PyWrapper. - self.tracker_instrumented_fn = tracker_instrumented_fn = ( - self._rewrite_code_obj( - fn, - [new_closure[name] for name in fn.__code__.co_freevars], - new_globals, - ) - ) - - # now invoke the function. This will give us a new SQL - # expression, but all the places that there would be a bound - # parameter, the PyWrapper in its place will give us a bind - # with a predictable name we can match up later. - - # additionally, each PyWrapper will log that it did in fact - # create a parameter, otherwise, it's some kind of Python - # object in the closure and we want to track that, to make - # sure it doesn't change to something else, or if it does, - # that we create a different tracked function with that - # variable. - self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - - def _coerce_expression(self, lambda_element, apply_propagate_attrs): - """Run the tracker-generated expression through coercion rules. - - After the user-defined lambda has been invoked to produce a statement - for re-use, run it through coercion rules to both check that it's the - correct type of object and also to coerce it to its useful form. - - """ - - parent_lambda = lambda_element.parent_lambda - expr = self.expr - - if parent_lambda is None: - if isinstance(expr, collections_abc.Sequence): - self.expected_expr = [ - cast( - "ClauseElement", - coercions.expect( - lambda_element.role, - sub_expr, - apply_propagate_attrs=apply_propagate_attrs, - ), - ) - for sub_expr in expr - ] - self.is_sequence = True - else: - self.expected_expr = cast( - "ClauseElement", - coercions.expect( - lambda_element.role, - expr, - apply_propagate_attrs=apply_propagate_attrs, - ), - ) - self.is_sequence = False - else: - self.expected_expr = expr - self.is_sequence = False - - if apply_propagate_attrs is not None: - self.propagate_attrs = apply_propagate_attrs._propagate_attrs - else: - self.propagate_attrs = util.EMPTY_DICT - - def _rewrite_code_obj(self, f, cell_values, globals_): - """Return a copy of f, with a new closure and new globals - - yes it works in pypy :P - - """ - - argrange = range(len(cell_values)) - - code = "def make_cells():\n" - if cell_values: - code += " (%s) = (%s)\n" % ( - ", ".join("i%d" % i for i in argrange), - ", ".join("o%d" % i for i in argrange), - ) - code += " def closure():\n" - code += " return %s\n" % ", ".join("i%d" % i for i in argrange) - code += " return closure.__closure__" - vars_ = {"o%d" % i: cell_values[i] for i in argrange} - exec(code, vars_, vars_) - closure = vars_["make_cells"]() - - func = type(f)( - f.__code__, globals_, f.__name__, f.__defaults__, closure - ) - func.__annotations__ = f.__annotations__ - func.__kwdefaults__ = f.__kwdefaults__ - func.__doc__ = f.__doc__ - func.__module__ = f.__module__ - - return func - - -class PyWrapper(ColumnOperators): - """A wrapper object that is injected into the ``__globals__`` and - ``__closure__`` of a Python function. - - When the function is instrumented with :class:`.PyWrapper` objects, it is - then invoked just once in order to set up the wrappers. We look through - all the :class:`.PyWrapper` objects we made to find the ones that generated - a :class:`.BindParameter` object, e.g. the expression system interpreted - something as a literal. Those positions in the globals/closure are then - ones that we will look at, each time a new lambda comes in that refers to - the same ``__code__`` object. In this way, we keep a single version of - the SQL expression that this lambda produced, without calling upon the - Python function that created it more than once, unless its other closure - variables have changed. The expression is then transformed to have the - new bound values embedded into it. - - """ - - def __init__( - self, - fn, - name, - to_evaluate, - closure_index=None, - getter=None, - track_bound_values=True, - ): - self.fn = fn - self._name = name - self._to_evaluate = to_evaluate - self._param = None - self._has_param = False - self._bind_paths = {} - self._getter = getter - self._closure_index = closure_index - self.track_bound_values = track_bound_values - - def __call__(self, *arg, **kw): - elem = object.__getattribute__(self, "_to_evaluate") - value = elem(*arg, **kw) - if ( - self._sa_track_bound_values - and coercions._deep_is_literal(value) - and not isinstance( - # TODO: coverage where an ORM option or similar is here - value, - _cache_key.HasCacheKey, - ) - ): - name = object.__getattribute__(self, "_name") - raise exc.InvalidRequestError( - "Can't invoke Python callable %s() inside of lambda " - "expression argument at %s; lambda SQL constructs should " - "not invoke functions from closure variables to produce " - "literal values since the " - "lambda SQL system normally extracts bound values without " - "actually " - "invoking the lambda or any functions within it. Call the " - "function outside of the " - "lambda and assign to a local variable that is used in the " - "lambda as a closure variable, or set " - "track_bound_values=False if the return value of this " - "function is used in some other way other than a SQL bound " - "value." % (name, self._sa_fn.__code__) - ) - else: - return value - - def operate(self, op, *other, **kwargs): - elem = object.__getattribute__(self, "_py_wrapper_literal")() - return op(elem, *other, **kwargs) - - def reverse_operate(self, op, other, **kwargs): - elem = object.__getattribute__(self, "_py_wrapper_literal")() - return op(other, elem, **kwargs) - - def _extract_bound_parameters(self, starting_point, result_list): - param = object.__getattribute__(self, "_param") - if param is not None: - param = param._with_value(starting_point, maintain_key=True) - result_list.append(param) - for pywrapper in object.__getattribute__(self, "_bind_paths").values(): - getter = object.__getattribute__(pywrapper, "_getter") - element = getter(starting_point) - pywrapper._sa__extract_bound_parameters(element, result_list) - - def _py_wrapper_literal(self, expr=None, operator=None, **kw): - param = object.__getattribute__(self, "_param") - to_evaluate = object.__getattribute__(self, "_to_evaluate") - if param is None: - name = object.__getattribute__(self, "_name") - self._param = param = elements.BindParameter( - name, - required=False, - unique=True, - _compared_to_operator=operator, - _compared_to_type=expr.type if expr is not None else None, - ) - self._has_param = True - return param._with_value(to_evaluate, maintain_key=True) - - def __bool__(self): - to_evaluate = object.__getattribute__(self, "_to_evaluate") - return bool(to_evaluate) - - def __getattribute__(self, key): - if key.startswith("_sa_"): - return object.__getattribute__(self, key[4:]) - elif key in ( - "__clause_element__", - "operate", - "reverse_operate", - "_py_wrapper_literal", - "__class__", - "__dict__", - ): - return object.__getattribute__(self, key) - - if key.startswith("__"): - elem = object.__getattribute__(self, "_to_evaluate") - return getattr(elem, key) - else: - return self._sa__add_getter(key, operator.attrgetter) - - def __iter__(self): - elem = object.__getattribute__(self, "_to_evaluate") - return iter(elem) - - def __getitem__(self, key): - elem = object.__getattribute__(self, "_to_evaluate") - if not hasattr(elem, "__getitem__"): - raise AttributeError("__getitem__") - - if isinstance(key, PyWrapper): - # TODO: coverage - raise exc.InvalidRequestError( - "Dictionary keys / list indexes inside of a cached " - "lambda must be Python literals only" - ) - return self._sa__add_getter(key, operator.itemgetter) - - def _add_getter(self, key, getter_fn): - bind_paths = object.__getattribute__(self, "_bind_paths") - - bind_path_key = (key, getter_fn) - if bind_path_key in bind_paths: - return bind_paths[bind_path_key] - - getter = getter_fn(key) - elem = object.__getattribute__(self, "_to_evaluate") - value = getter(elem) - - rolled_down_value = AnalyzedCode._roll_down_to_literal(value) - - if coercions._deep_is_literal(rolled_down_value): - wrapper = PyWrapper(self._sa_fn, key, value, getter=getter) - bind_paths[bind_path_key] = wrapper - return wrapper - else: - return value - - -@inspection._inspects(LambdaElement) -def insp(lmb): - return inspection.inspect(lmb._resolved) |