diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py | 368 |
1 files changed, 368 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py new file mode 100644 index 0000000..f264454 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py @@ -0,0 +1,368 @@ +# orm/evaluator.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Evaluation functions used **INTERNALLY** by ORM DML use cases. + + +This module is **private, for internal use by SQLAlchemy**. + +.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to + ``_EvaluatorCompiler``. + +""" + + +from __future__ import annotations + +from typing import Type + +from . import exc as orm_exc +from .base import LoaderCallableStatus +from .base import PassiveFlag +from .. import exc +from .. import inspect +from ..sql import and_ +from ..sql import operators +from ..sql.sqltypes import Integer +from ..sql.sqltypes import Numeric +from ..util import warn_deprecated + + +class UnevaluatableError(exc.InvalidRequestError): + pass + + +class _NoObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return None + + def reverse_operate(self, *arg, **kw): + return None + + +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + +_NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() + + +class _EvaluatorCompiler: + def __init__(self, target_cls=None): + self.target_cls = target_cls + + def process(self, clause, *clauses): + if clauses: + clause = and_(clause, *clauses) + + meth = getattr(self, f"visit_{clause.__visit_name__}", None) + if not meth: + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__}" + ) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_false(self, clause): + return lambda obj: False + + def visit_true(self, clause): + return lambda obj: True + + def visit_column(self, clause): + try: + parentmapper = clause._annotations["parentmapper"] + except KeyError as ke: + raise UnevaluatableError( + f"Cannot evaluate column: {clause}" + ) from ke + + if self.target_cls and not issubclass( + self.target_cls, parentmapper.class_ + ): + raise UnevaluatableError( + "Can't evaluate criteria against " + f"alternate class {parentmapper.class_}" + ) + + parentmapper._check_configure() + + # we'd like to use "proxy_key" annotation to get the "key", however + # in relationship primaryjoin cases proxy_key is sometimes deannotated + # and sometimes apparently not present in the first place (?). + # While I can stop it from being deannotated (though need to see if + # this breaks other things), not sure right now about cases where it's + # not there in the first place. can fix at some later point. + # key = clause._annotations["proxy_key"] + + # for now, use the old way + try: + key = parentmapper._columntoproperty[clause].key + except orm_exc.UnmappedColumnError as err: + raise UnevaluatableError( + f"Cannot evaluate expression: {err}" + ) from err + + # note this used to fall back to a simple `getattr(obj, key)` evaluator + # if impl was None; as of #8656, we ensure mappers are configured + # so that impl is available + impl = parentmapper.class_manager[key].impl + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr + + def visit_tuple(self, clause): + return self.visit_clauselist(clause) + + def visit_expression_clauselist(self, clause): + return self.visit_clauselist(clause) + + def visit_clauselist(self, clause): + evaluators = [self.process(clause) for clause in clause.clauses] + + dispatch = ( + f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op" + ) + meth = getattr(self, dispatch, None) + if meth: + return meth(clause.operator, evaluators, clause) + else: + raise UnevaluatableError( + f"Cannot evaluate clauselist with operator {clause.operator}" + ) + + def visit_binary(self, clause): + eval_left = self.process(clause.left) + eval_right = self.process(clause.right) + + dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op" + meth = getattr(self, dispatch, None) + if meth: + return meth(clause.operator, eval_left, eval_right, clause) + else: + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__} with " + f"operator {clause.operator}" + ) + + def visit_or_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + + return evaluate + + def visit_and_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + + if not value: + if value is None or value is _NO_OBJECT: + return None + return False + return True + + return evaluate + + def visit_comma_op_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + values = [] + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: + return None + values.append(value) + return tuple(values) + + return evaluate + + def visit_custom_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + if operator.python_impl: + return self._straight_evaluate( + operator, eval_left, eval_right, clause + ) + else: + raise UnevaluatableError( + f"Custom operator {operator.opstring!r} can't be evaluated " + "in Python unless it specifies a callable using " + "`.python_impl`." + ) + + def visit_is_binary_op(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val + + return evaluate + + def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val + + return evaluate + + def _straight_evaluate(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: + return None + + return operator(eval_left(obj), eval_right(obj)) + + return evaluate + + def _straight_evaluate_numeric_only( + self, operator, eval_left, eval_right, clause + ): + if clause.left.type._type_affinity not in ( + Numeric, + Integer, + ) or clause.right.type._type_affinity not in (Numeric, Integer): + raise UnevaluatableError( + f'Cannot evaluate math operator "{operator.__name__}" for ' + f"datatypes {clause.left.type}, {clause.right.type}" + ) + + return self._straight_evaluate(operator, eval_left, eval_right, clause) + + visit_add_binary_op = _straight_evaluate_numeric_only + visit_mul_binary_op = _straight_evaluate_numeric_only + visit_sub_binary_op = _straight_evaluate_numeric_only + visit_mod_binary_op = _straight_evaluate_numeric_only + visit_truediv_binary_op = _straight_evaluate_numeric_only + visit_lt_binary_op = _straight_evaluate + visit_le_binary_op = _straight_evaluate + visit_ne_binary_op = _straight_evaluate + visit_gt_binary_op = _straight_evaluate + visit_ge_binary_op = _straight_evaluate + visit_eq_binary_op = _straight_evaluate + + def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause): + return self._straight_evaluate( + lambda a, b: a in b if a is not _NO_OBJECT else None, + eval_left, + eval_right, + clause, + ) + + def visit_not_in_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a not in b if a is not _NO_OBJECT else None, + eval_left, + eval_right, + clause, + ) + + def visit_concat_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a + b, eval_left, eval_right, clause + ) + + def visit_startswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a.startswith(b), eval_left, eval_right, clause + ) + + def visit_endswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a.endswith(b), eval_left, eval_right, clause + ) + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + + def evaluate(obj): + value = eval_inner(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None: + return None + return not value + + return evaluate + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__} " + f"with operator {clause.operator}" + ) + + def visit_bindparam(self, clause): + if clause.callable: + val = clause.callable() + else: + val = clause.value + return lambda obj: val + + +def __getattr__(name: str) -> Type[_EvaluatorCompiler]: + if name == "EvaluatorCompiler": + warn_deprecated( + "Direct use of 'EvaluatorCompiler' is not supported, and this " + "name will be removed in a future release. " + "'_EvaluatorCompiler' is for internal use only", + "2.0", + ) + return _EvaluatorCompiler + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") |