diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/testing | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing')
76 files changed, 20469 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__init__.py new file mode 100644 index 0000000..d3a6f32 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__init__.py @@ -0,0 +1,95 @@ +# testing/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from unittest import mock + +from . import config +from .assertions import assert_raises +from .assertions import assert_raises_context_ok +from .assertions import assert_raises_message +from .assertions import assert_raises_message_context_ok +from .assertions import assert_warns +from .assertions import assert_warns_message +from .assertions import AssertsCompiledSQL +from .assertions import AssertsExecutionResults +from .assertions import ComparesIndexes +from .assertions import ComparesTables +from .assertions import emits_warning +from .assertions import emits_warning_on +from .assertions import eq_ +from .assertions import eq_ignore_whitespace +from .assertions import eq_regex +from .assertions import expect_deprecated +from .assertions import expect_deprecated_20 +from .assertions import expect_raises +from .assertions import expect_raises_message +from .assertions import expect_warnings +from .assertions import in_ +from .assertions import int_within_variance +from .assertions import is_ +from .assertions import is_false +from .assertions import is_instance_of +from .assertions import is_none +from .assertions import is_not +from .assertions import is_not_ +from .assertions import is_not_none +from .assertions import is_true +from .assertions import le_ +from .assertions import ne_ +from .assertions import not_in +from .assertions import not_in_ +from .assertions import startswith_ +from .assertions import uses_deprecated +from .config import add_to_marker +from .config import async_test +from .config import combinations +from .config import combinations_list +from .config import db +from .config import fixture +from .config import requirements as requires +from .config import skip_test +from .config import Variation +from .config import variation +from .config import variation_fixture +from .exclusions import _is_excluded +from .exclusions import _server_version +from .exclusions import against as _against +from .exclusions import db_spec +from .exclusions import exclude +from .exclusions import fails +from .exclusions import fails_if +from .exclusions import fails_on +from .exclusions import fails_on_everything_except +from .exclusions import future +from .exclusions import only_if +from .exclusions import only_on +from .exclusions import skip +from .exclusions import skip_if +from .schema import eq_clause_element +from .schema import eq_type_affinity +from .util import adict +from .util import fail +from .util import flag_combinations +from .util import force_drop_names +from .util import lambda_combinations +from .util import metadata_fixture +from .util import provide_metadata +from .util import resolve_lambda +from .util import rowset +from .util import run_as_contextmanager +from .util import teardown_events +from .warnings import assert_warnings +from .warnings import warn_test_suite + + +def against(*queries): + return _against(config._current, *queries) + + +crashes = skip diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..7f6a336 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertions.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..22f91d1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertsql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertsql.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c435ea0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertsql.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/asyncio.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/asyncio.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f460eeb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/asyncio.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/config.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/config.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d420c76 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/config.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/engines.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/engines.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e21f79d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/engines.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/entities.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/entities.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..611fd62 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/entities.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/exclusions.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/exclusions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..cfa8c09 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/exclusions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/pickleable.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/pickleable.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ee832ea --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/pickleable.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/profiling.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/profiling.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..cabd6ee --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/profiling.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/provision.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/provision.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5da4c30 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/provision.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/requirements.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/requirements.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b87b700 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/requirements.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/schema.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b64caf4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/schema.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/util.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..33f6afa --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/util.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/warnings.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/warnings.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..67682c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/warnings.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py new file mode 100644 index 0000000..baef79d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py @@ -0,0 +1,989 @@ +# testing/assertions.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +from collections import defaultdict +import contextlib +from copy import copy +from itertools import filterfalse +import re +import sys +import warnings + +from . import assertsql +from . import config +from . import engines +from . import mock +from .exclusions import db_spec +from .util import fail +from .. import exc as sa_exc +from .. import schema +from .. import sql +from .. import types as sqltypes +from .. import util +from ..engine import default +from ..engine import url +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import decorator + + +def expect_warnings(*messages, **kw): + """Context manager which expects one or more warnings. + + With no arguments, squelches all SAWarning emitted via + sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise + pass string expressions that will match selected warnings via regex; + all non-matching warnings are sent through. + + The expect version **asserts** that the warnings were in fact seen. + + Note that the test suite sets SAWarning warnings to raise exceptions. + + """ # noqa + return _expect_warnings_sqla_only(sa_exc.SAWarning, messages, **kw) + + +@contextlib.contextmanager +def expect_warnings_on(db, *messages, **kw): + """Context manager which expects one or more warnings on specific + dialects. + + The expect version **asserts** that the warnings were in fact seen. + + """ + spec = db_spec(db) + + if isinstance(db, str) and not spec(config._current): + yield + else: + with expect_warnings(*messages, **kw): + yield + + +def emits_warning(*messages): + """Decorator form of expect_warnings(). + + Note that emits_warning does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_warnings(assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def expect_deprecated(*messages, **kw): + return _expect_warnings_sqla_only( + sa_exc.SADeprecationWarning, messages, **kw + ) + + +def expect_deprecated_20(*messages, **kw): + return _expect_warnings_sqla_only( + sa_exc.Base20DeprecationWarning, messages, **kw + ) + + +def emits_warning_on(db, *messages): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + + Note that emits_warning_on does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_warnings_on(db, assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + + Note that uses_deprecated does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_deprecated(*messages, assert_=False): + return fn(*args, **kw) + + return decorate + + +_FILTERS = None +_SEEN = None +_EXC_CLS = None + + +def _expect_warnings_sqla_only( + exc_cls, + messages, + regex=True, + search_msg=False, + assert_=True, +): + """SQLAlchemy internal use only _expect_warnings(). + + Alembic is using _expect_warnings() directly, and should be updated + to use this new interface. + + """ + return _expect_warnings( + exc_cls, + messages, + regex=regex, + search_msg=search_msg, + assert_=assert_, + raise_on_any_unexpected=True, + ) + + +@contextlib.contextmanager +def _expect_warnings( + exc_cls, + messages, + regex=True, + search_msg=False, + assert_=True, + raise_on_any_unexpected=False, + squelch_other_warnings=False, +): + global _FILTERS, _SEEN, _EXC_CLS + + if regex or search_msg: + filters = [re.compile(msg, re.I | re.S) for msg in messages] + else: + filters = list(messages) + + if _FILTERS is not None: + # nested call; update _FILTERS and _SEEN, return. outer + # block will assert our messages + assert _SEEN is not None + assert _EXC_CLS is not None + _FILTERS.extend(filters) + _SEEN.update(filters) + _EXC_CLS += (exc_cls,) + yield + else: + seen = _SEEN = set(filters) + _FILTERS = filters + _EXC_CLS = (exc_cls,) + + if raise_on_any_unexpected: + + def real_warn(msg, *arg, **kw): + raise AssertionError("Got unexpected warning: %r" % msg) + + else: + real_warn = warnings.warn + + def our_warn(msg, *arg, **kw): + if isinstance(msg, _EXC_CLS): + exception = type(msg) + msg = str(msg) + elif arg: + exception = arg[0] + else: + exception = None + + if not exception or not issubclass(exception, _EXC_CLS): + if not squelch_other_warnings: + return real_warn(msg, *arg, **kw) + else: + return + + if not filters and not raise_on_any_unexpected: + return + + for filter_ in filters: + if ( + (search_msg and filter_.search(msg)) + or (regex and filter_.match(msg)) + or (not regex and filter_ == msg) + ): + seen.discard(filter_) + break + else: + if not squelch_other_warnings: + real_warn(msg, *arg, **kw) + + with mock.patch("warnings.warn", our_warn): + try: + yield + finally: + _SEEN = _FILTERS = _EXC_CLS = None + + if assert_: + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) + + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + _assert_no_stray_pool_connections() + + +def _assert_no_stray_pool_connections(): + engines.testing_reaper.assert_all_closed() + + +def int_within_variance(expected, received, variance): + deviance = int(expected * variance) + assert ( + abs(received - expected) < deviance + ), "Given int value %s is not within %d%% of expected value %s" % ( + received, + variance * 100, + expected, + ) + + +def eq_regex(a, b, msg=None): + assert re.match(b, a), msg or "%r !~ %r" % (a, b) + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + + +def le_(a, b, msg=None): + """Assert a <= b, with repr messaging on failure.""" + assert a <= b, msg or "%r != %r" % (a, b) + + +def is_instance_of(a, b, msg=None): + assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b) + + +def is_none(a, msg=None): + is_(a, None, msg=msg) + + +def is_not_none(a, msg=None): + is_not(a, None, msg=msg) + + +def is_true(a, msg=None): + is_(bool(a), True, msg=msg) + + +def is_false(a, msg=None): + is_(bool(a), False, msg=msg) + + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + + +def is_not(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + + +# deprecated. See #5429 +is_not_ = is_not + + +def in_(a, b, msg=None): + """Assert a in b, with repr messaging on failure.""" + assert a in b, msg or "%r not in %r" % (a, b) + + +def not_in(a, b, msg=None): + """Assert a in not b, with repr messaging on failure.""" + assert a not in b, msg or "%r is in %r" % (a, b) + + +# deprecated. See #5429 +not_in_ = not_in + + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, + fragment, + ) + + +def eq_ignore_whitespace(a, b, msg=None): + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + a = re.sub(r"\t", "", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) + b = re.sub(r"\t", "", b) + + assert a == b, msg or "%r != %r" % (a, b) + + +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + +def assert_raises(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_warns(except_cls, callable_, *args, **kwargs): + """legacy adapter function for functions that were previously using + assert_raises with SAWarning or similar. + + has some workarounds to accommodate the fact that the callable completes + with this approach rather than stopping at the exception raise. + + + """ + with _expect_warnings_sqla_only(except_cls, [".*"]): + return callable_(*args, **kwargs) + + +def assert_warns_message(except_cls, msg, callable_, *args, **kwargs): + """legacy adapter function for functions that were previously using + assert_raises with SAWarning or similar. + + has some workarounds to accommodate the fact that the callable completes + with this approach rather than stopping at the exception raise. + + Also uses regex.search() to match the given message to the error string + rather than regex.match(). + + """ + with _expect_warnings_sqla_only( + except_cls, + [msg], + search_msg=True, + regex=False, + ): + return callable_(*args, **kwargs) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer: + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + if ( + isinstance(except_cls, type) + and issubclass(except_cls, Warning) + or isinstance(except_cls, Warning) + ): + raise TypeError( + "Use expect_warnings for warnings, not " + "expect_raises / assert_raises" + ) + ec = _ErrorContainer() + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] + try: + yield ec + success = False + except except_cls as err: + ec.error = err + success = True + if msg is not None: + # I'm often pdbing here, and "err" above isn't + # in scope, so assign the string explicitly + error_as_string = str(err) + assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % ( + msg, + error_as_string, + ) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(str(err).encode("utf-8")) + + # it's generally a good idea to not carry traceback objects outside + # of the except: block, but in this case especially we seem to have + # hit some bug in either python 3.10.0b2 or greenlet or both which + # this seems to fix: + # https://github.com/python-greenlet/greenlet/issues/242 + del ec + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + + +def expect_raises(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message(except_cls, msg, check_context=True): + return _expect_raises(except_cls, msg=msg, check_context=check_context) + + +class AssertsCompiledSQL: + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + for_executemany=False, + check_literal_execute=None, + check_post_param=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + supports_default_values=True, + supports_default_metavalue=True, + literal_binds=False, + render_postcompile=False, + schema_translate_map=None, + render_schema_translate=False, + default_schema_name=None, + from_linting=False, + check_param_order=True, + use_literal_execute_for_simple_int=False, + ): + if use_default_dialect: + dialect = default.DefaultDialect() + dialect.supports_default_values = supports_default_values + dialect.supports_default_metavalue = supports_default_metavalue + elif allow_dialect_select: + dialect = None + else: + if dialect is None: + dialect = getattr(self, "__dialect__", None) + + if dialect is None: + dialect = config.db.dialect + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect("qmark") + dialect.supports_default_values = supports_default_values + dialect.supports_default_metavalue = supports_default_metavalue + elif dialect == "default_enhanced": + dialect = default.StrCompileDialect() + elif isinstance(dialect, str): + dialect = url.URL.create(dialect).get_dialect()() + + if default_schema_name: + dialect.default_schema_name = default_schema_name + + kw = {} + compile_kwargs = {} + + if schema_translate_map: + kw["schema_translate_map"] = schema_translate_map + + if params is not None: + kw["column_keys"] = list(params) + + if literal_binds: + compile_kwargs["literal_binds"] = True + + if render_postcompile: + compile_kwargs["render_postcompile"] = True + + if use_literal_execute_for_simple_int: + compile_kwargs["use_literal_execute_for_simple_int"] = True + + if for_executemany: + kw["for_executemany"] = True + + if render_schema_translate: + kw["render_schema_translate"] = True + + if from_linting or getattr(self, "assert_from_linting", False): + kw["linting"] = sql.FROM_LINTING + + from sqlalchemy import orm + + if isinstance(clause, orm.Query): + stmt = clause._statement_20() + stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL + clause = stmt + + if compile_kwargs: + kw["compile_kwargs"] = compile_kwargs + + class DontAccess: + def __getattribute__(self, key): + raise NotImplementedError( + "compiler accessed .statement; use " + "compiler.current_executable" + ) + + class CheckCompilerAccess: + def __init__(self, test_statement): + self.test_statement = test_statement + self._annotations = {} + self.supports_execution = getattr( + test_statement, "supports_execution", False + ) + + if self.supports_execution: + self._execution_options = test_statement._execution_options + + if hasattr(test_statement, "_returning"): + self._returning = test_statement._returning + if hasattr(test_statement, "_inline"): + self._inline = test_statement._inline + if hasattr(test_statement, "_return_defaults"): + self._return_defaults = test_statement._return_defaults + + @property + def _variant_mapping(self): + return self.test_statement._variant_mapping + + def _default_dialect(self): + return self.test_statement._default_dialect() + + def compile(self, dialect, **kw): + return self.test_statement.compile.__func__( + self, dialect=dialect, **kw + ) + + def _compiler(self, dialect, **kw): + return self.test_statement._compiler.__func__( + self, dialect, **kw + ) + + def _compiler_dispatch(self, compiler, **kwargs): + if hasattr(compiler, "statement"): + with mock.patch.object( + compiler, "statement", DontAccess() + ): + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + else: + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + + # no construct can assume it's the "top level" construct in all cases + # as anything can be nested. ensure constructs don't assume they + # are the "self.statement" element + c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) + + if isinstance(clause, sqltypes.TypeEngine): + cache_key_no_warnings = clause._static_cache_key + if cache_key_no_warnings: + hash(cache_key_no_warnings) + else: + cache_key_no_warnings = clause._generate_cache_key() + if cache_key_no_warnings: + hash(cache_key_no_warnings[0]) + + param_str = repr(getattr(c, "params", {})) + param_str = param_str.encode("utf-8").decode("ascii", "ignore") + print(("\nSQL String:\n" + str(c) + param_str).encode("utf-8")) + + cc = re.sub(r"[\n\t]", "", str(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + if render_postcompile: + expanded_state = c.construct_expanded_state( + params, escape_names=False + ) + eq_(expanded_state.parameters, checkparams) + else: + eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + if render_postcompile: + expanded_state = c.construct_expanded_state( + params, escape_names=False + ) + eq_( + tuple( + [ + expanded_state.parameters[x] + for x in expanded_state.positiontup + ] + ), + checkpositional, + ) + else: + p = c.construct_params(params, escape_names=False) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + if check_prefetch is not None: + eq_(c.prefetch, check_prefetch) + if check_literal_execute is not None: + eq_( + { + c.bind_names[b]: b.effective_value + for b in c.literal_execute_params + }, + check_literal_execute, + ) + if check_post_param is not None: + eq_( + { + c.bind_names[b]: b.effective_value + for b in c.post_compile_params + }, + check_post_param, + ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) + + +class ComparesTables: + def assert_tables_equal( + self, + table, + reflected_table, + strict_types=False, + strict_constraints=True, + ): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + + if strict_constraints: + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + msg = "Type '%s' doesn't correspond to type '%s'" + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + if strict_constraints: + eq_( + {f.column.name for f in c.foreign_keys}, + {f.column.name for f in reflected_c.foreign_keys}, + ) + if c.server_default: + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) + + if strict_constraints: + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity( + c2.type + ), "On column %r, type '%s' doesn't correspond to type '%s'" % ( + c1.name, + c1.type, + c2.type, + ) + + +class AssertsExecutionResults: + def assert_result(self, result, class_, *objects): + result = list(result) + print(repr(result)) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list_): + self.assert_( + len(result) == len(list_), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) + for i in range(0, len(list_)): + self.assert_row(class_, result[i], list_[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) + for key, value in desc.items(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class immutabledict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = {immutabledict(e) for e in expected} + + for wrong in filterfalse(lambda o: isinstance(o, cls), found): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) + + if len(found) != len(expected): + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) + + NOVALUE = object() + + def _compare_item(obj, spec): + for key, value in spec.items(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1] + ) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) + return True + + def sql_execution_asserter(self, db=None): + if db is None: + from . import db as db + + return assertsql.assert_engine(db) + + def assert_sql_execution(self, db, callable_, *rules): + with self.sql_execution_asserter(db) as asserter: + result = callable_() + asserter.assert_(*rules) + return result + + def assert_sql(self, db, callable_, rules): + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) + else: + newrule = assertsql.CompiledSQL(*rule) + newrules.append(newrule) + + return self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + return self.assert_sql_execution( + db, callable_, assertsql.CountStatements(count) + ) + + @contextlib.contextmanager + def assert_execution(self, db, *rules): + with self.sql_execution_asserter(db) as asserter: + yield + asserter.assert_(*rules) + + def assert_statement_count(self, db, count): + return self.assert_execution(db, assertsql.CountStatements(count)) + + @contextlib.contextmanager + def assert_statement_count_multi_db(self, dbs, counts): + recs = [ + (self.sql_execution_asserter(db), db, count) + for (db, count) in zip(dbs, counts) + ] + asserters = [] + for ctx, db, count in recs: + asserters.append(ctx.__enter__()) + try: + yield + finally: + for asserter, (ctx, db, count) in zip(asserters, recs): + ctx.__exit__(None, None, None) + asserter.assert_(assertsql.CountStatements(count)) + + +class ComparesIndexes: + def compare_table_index_with_expected( + self, table: schema.Table, expected: list, dialect_name: str + ): + eq_(len(table.indexes), len(expected)) + idx_dict = {idx.name: idx for idx in table.indexes} + for exp in expected: + idx = idx_dict[exp["name"]] + eq_(idx.unique, exp["unique"]) + cols = [c for c in exp["column_names"] if c is not None] + eq_(len(idx.columns), len(cols)) + for c in cols: + is_true(c in idx.columns) + exprs = exp.get("expressions") + if exprs: + eq_(len(idx.expressions), len(exprs)) + for idx_exp, expr, col in zip( + idx.expressions, exprs, exp["column_names"] + ): + if col is None: + eq_(idx_exp.text, expr) + if ( + exp.get("dialect_options") + and f"{dialect_name}_include" in exp["dialect_options"] + ): + eq_( + idx.dialect_options[dialect_name]["include"], + exp["dialect_options"][f"{dialect_name}_include"], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py new file mode 100644 index 0000000..ae4d335 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py @@ -0,0 +1,516 @@ +# testing/assertsql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import collections +import contextlib +import itertools +import re + +from .. import event +from ..engine import url +from ..engine.default import DefaultDialect +from ..schema import BaseDDLElement + + +class AssertRule: + is_consumed = False + errormessage = None + consume_statement = True + + def process_statement(self, execute_observed): + pass + + def no_more_statements(self): + assert False, ( + "All statements are complete, but pending " + "assertion rules remain" + ) + + +class SQLMatchRule(AssertRule): + pass + + +class CursorSQL(SQLMatchRule): + def __init__(self, statement, params=None, consume_statement=True): + self.statement = statement + self.params = params + self.consume_statement = consume_statement + + def process_statement(self, execute_observed): + stmt = execute_observed.statements[0] + if self.statement != stmt.statement or ( + self.params is not None and self.params != stmt.parameters + ): + self.consume_statement = True + self.errormessage = ( + "Testing for exact SQL %s parameters %s received %s %s" + % ( + self.statement, + self.params, + stmt.statement, + stmt.parameters, + ) + ) + else: + execute_observed.statements.pop(0) + self.is_consumed = True + if not execute_observed.statements: + self.consume_statement = True + + +class CompiledSQL(SQLMatchRule): + def __init__( + self, statement, params=None, dialect="default", enable_returning=True + ): + self.statement = statement + self.params = params + self.dialect = dialect + self.enable_returning = enable_returning + + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r"[\n\t]", "", self.statement) + return received_statement == stmt + + def _compile_dialect(self, execute_observed): + if self.dialect == "default": + dialect = DefaultDialect() + # this is currently what tests are expecting + # dialect.supports_default_values = True + dialect.supports_default_metavalue = True + + if self.enable_returning: + dialect.insert_returning = dialect.update_returning = ( + dialect.delete_returning + ) = True + dialect.use_insertmanyvalues = True + dialect.supports_multivalues_insert = True + dialect.update_returning_multifrom = True + dialect.delete_returning_multifrom = True + # dialect.favor_returning_over_lastrowid = True + # dialect.insert_null_pk_still_autoincrements = True + + # this is calculated but we need it to be True for this + # to look like all the current RETURNING dialects + assert dialect.insert_executemany_returning + + return dialect + else: + return url.URL.create(self.dialect).get_dialect()() + + def _received_statement(self, execute_observed): + """reconstruct the statement and params in terms + of a target dialect, which for CompiledSQL is just DefaultDialect.""" + + context = execute_observed.context + compare_dialect = self._compile_dialect(execute_observed) + + # received_statement runs a full compile(). we should not need to + # consider extracted_parameters; if we do this indicates some state + # is being sent from a previous cached query, which some misbehaviors + # in the ORM can cause, see #6881 + cache_key = None # execute_observed.context.compiled.cache_key + extracted_parameters = ( + None # execute_observed.context.extracted_parameters + ) + + if "schema_translate_map" in context.execution_options: + map_ = context.execution_options["schema_translate_map"] + else: + map_ = None + + if isinstance(execute_observed.clauseelement, BaseDDLElement): + compiled = execute_observed.clauseelement.compile( + dialect=compare_dialect, + schema_translate_map=map_, + ) + else: + compiled = execute_observed.clauseelement.compile( + cache_key=cache_key, + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + for_executemany=context.compiled.for_executemany, + schema_translate_map=map_, + ) + _received_statement = re.sub(r"[\n\t]", "", str(compiled)) + parameters = execute_observed.parameters + + if not parameters: + _received_parameters = [ + compiled.construct_params( + extracted_parameters=extracted_parameters + ) + ] + else: + _received_parameters = [ + compiled.construct_params( + m, extracted_parameters=extracted_parameters + ) + for m in parameters + ] + + return _received_statement, _received_parameters + + def process_statement(self, execute_observed): + context = execute_observed.context + + _received_statement, _received_parameters = self._received_statement( + execute_observed + ) + params = self._all_params(context) + + equivalent = self._compare_sql(execute_observed, _received_statement) + + if equivalent: + if params is not None: + all_params = list(params) + all_received = list(_received_parameters) + while all_params and all_received: + param = dict(all_params.pop(0)) + + for idx, received in enumerate(list(all_received)): + # do a positive compare only + for param_key in param: + # a key in param did not match current + # 'received' + if ( + param_key not in received + or received[param_key] != param[param_key] + ): + break + else: + # all keys in param matched 'received'; + # onto next param + del all_received[idx] + break + else: + # param did not match any entry + # in all_received + equivalent = False + break + if all_params or all_received: + equivalent = False + + if equivalent: + self.is_consumed = True + self.errormessage = None + else: + self.errormessage = self._failure_message( + execute_observed, params + ) % { + "received_statement": _received_statement, + "received_parameters": _received_parameters, + } + + def _all_params(self, context): + if self.params: + if callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + return params + else: + return None + + def _failure_message(self, execute_observed, expected_params): + return ( + "Testing for compiled statement\n%r partial params %s, " + "received\n%%(received_statement)r with params " + "%%(received_parameters)r" + % ( + self.statement.replace("%", "%%"), + repr(expected_params).replace("%", "%%"), + ) + ) + + +class RegexSQL(CompiledSQL): + def __init__( + self, regex, params=None, dialect="default", enable_returning=False + ): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + self.dialect = dialect + self.enable_returning = enable_returning + + def _failure_message(self, execute_observed, expected_params): + return ( + "Testing for compiled statement ~%r partial params %s, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" + % ( + self.orig_regex.replace("%", "%%"), + repr(expected_params).replace("%", "%%"), + ) + ) + + def _compare_sql(self, execute_observed, received_statement): + return bool(self.regex.match(received_statement)) + + +class DialectSQL(CompiledSQL): + def _compile_dialect(self, execute_observed): + return execute_observed.context.dialect + + def _compare_no_space(self, real_stmt, received_stmt): + stmt = re.sub(r"[\n\t]", "", real_stmt) + return received_stmt == stmt + + def _received_statement(self, execute_observed): + received_stmt, received_params = super()._received_statement( + execute_observed + ) + + # TODO: why do we need this part? + for real_stmt in execute_observed.statements: + if self._compare_no_space(real_stmt.statement, received_stmt): + break + else: + raise AssertionError( + "Can't locate compiled statement %r in list of " + "statements actually invoked" % received_stmt + ) + + return received_stmt, execute_observed.context.compiled_parameters + + def _dialect_adjusted_statement(self, dialect): + paramstyle = dialect.paramstyle + stmt = re.sub(r"[\n\t]", "", self.statement) + + # temporarily escape out PG double colons + stmt = stmt.replace("::", "!!") + + if paramstyle == "pyformat": + stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) + else: + # positional params + repl = None + if paramstyle == "qmark": + repl = "?" + elif paramstyle == "format": + repl = r"%s" + elif paramstyle.startswith("numeric"): + counter = itertools.count(1) + + num_identifier = "$" if paramstyle == "numeric_dollar" else ":" + + def repl(m): + return f"{num_identifier}{next(counter)}" + + stmt = re.sub(r":([\w_]+)", repl, stmt) + + # put them back + stmt = stmt.replace("!!", "::") + + return stmt + + def _compare_sql(self, execute_observed, received_statement): + stmt = self._dialect_adjusted_statement( + execute_observed.context.dialect + ) + return received_statement == stmt + + def _failure_message(self, execute_observed, expected_params): + return ( + "Testing for compiled statement\n%r partial params %s, " + "received\n%%(received_statement)r with params " + "%%(received_parameters)r" + % ( + self._dialect_adjusted_statement( + execute_observed.context.dialect + ).replace("%", "%%"), + repr(expected_params).replace("%", "%%"), + ) + ) + + +class CountStatements(AssertRule): + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_statement(self, execute_observed): + self._statement_count += 1 + + def no_more_statements(self): + if self.count != self._statement_count: + assert False, "desired statement count %d does not match %d" % ( + self.count, + self._statement_count, + ) + + +class AllOf(AssertRule): + def __init__(self, *rules): + self.rules = set(rules) + + def process_statement(self, execute_observed): + for rule in list(self.rules): + rule.errormessage = None + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.discard(rule) + if not self.rules: + self.is_consumed = True + break + elif not rule.errormessage: + # rule is not done yet + self.errormessage = None + break + else: + self.errormessage = list(self.rules)[0].errormessage + + +class EachOf(AssertRule): + def __init__(self, *rules): + self.rules = list(rules) + + def process_statement(self, execute_observed): + if not self.rules: + self.is_consumed = True + self.consume_statement = False + + while self.rules: + rule = self.rules[0] + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.pop(0) + elif rule.errormessage: + self.errormessage = rule.errormessage + if rule.consume_statement: + break + + if not self.rules: + self.is_consumed = True + + def no_more_statements(self): + if self.rules and not self.rules[0].is_consumed: + self.rules[0].no_more_statements() + elif self.rules: + super().no_more_statements() + + +class Conditional(EachOf): + def __init__(self, condition, rules, else_rules): + if condition: + super().__init__(*rules) + else: + super().__init__(*else_rules) + + +class Or(AllOf): + def process_statement(self, execute_observed): + for rule in self.rules: + rule.process_statement(execute_observed) + if rule.is_consumed: + self.is_consumed = True + break + else: + self.errormessage = list(self.rules)[0].errormessage + + +class SQLExecuteObserved: + def __init__(self, context, clauseelement, multiparams, params): + self.context = context + self.clauseelement = clauseelement + + if multiparams: + self.parameters = multiparams + elif params: + self.parameters = [params] + else: + self.parameters = [] + self.statements = [] + + def __repr__(self): + return str(self.statements) + + +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"], + ) +): + pass + + +class SQLAsserter: + def __init__(self): + self.accumulated = [] + + def _close(self): + self._final = self.accumulated + del self.accumulated + + def assert_(self, *rules): + rule = EachOf(*rules) + + observed = list(self._final) + while observed: + statement = observed.pop(0) + rule.process_statement(statement) + if rule.is_consumed: + break + elif rule.errormessage: + assert False, rule.errormessage + if observed: + assert False, "Additional SQL statements remain:\n%s" % observed + elif not rule.is_consumed: + rule.no_more_statements() + + +@contextlib.contextmanager +def assert_engine(engine): + asserter = SQLAsserter() + + orig = [] + + @event.listens_for(engine, "before_execute") + def connection_execute( + conn, clauseelement, multiparams, params, execution_options + ): + # grab the original statement + params before any cursor + # execution + orig[:] = clauseelement, multiparams, params + + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + if not context: + return + # then grab real cursor statements and associate them all + # around a single context + if ( + asserter.accumulated + and asserter.accumulated[-1].context is context + ): + obs = asserter.accumulated[-1] + else: + obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) + asserter.accumulated.append(obs) + obs.statements.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany + ) + ) + + try: + yield asserter + finally: + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "before_execute", connection_execute) + asserter._close() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/asyncio.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/asyncio.py new file mode 100644 index 0000000..f71ca57 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/asyncio.py @@ -0,0 +1,135 @@ +# testing/asyncio.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +# functions and wrappers to run tests, fixtures, provisioning and +# setup/teardown in an asyncio event loop, conditionally based on the +# current DB driver being used for a test. + +# note that SQLAlchemy's asyncio integration also supports a method +# of running individual asyncio functions inside of separate event loops +# using "async_fallback" mode; however running whole functions in the event +# loop is a more accurate test for how SQLAlchemy's asyncio features +# would run in the real world. + + +from __future__ import annotations + +from functools import wraps +import inspect + +from . import config +from ..util.concurrency import _AsyncUtil + +# may be set to False if the +# --disable-asyncio flag is passed to the test runner. +ENABLE_ASYNCIO = True +_async_util = _AsyncUtil() # it has lazy init so just always create one + + +def _shutdown(): + """called when the test finishes""" + _async_util.close() + + +def _run_coroutine_function(fn, *args, **kwargs): + return _async_util.run(fn, *args, **kwargs) + + +def _assume_async(fn, *args, **kwargs): + """Run a function in an asyncio loop unconditionally. + + This function is used for provisioning features like + testing a database connection for server info. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + return _async_util.run_in_greenlet(fn, *args, **kwargs) + + +def _maybe_async_provisioning(fn, *args, **kwargs): + """Run a function in an asyncio loop if any current drivers might need it. + + This function is used for provisioning features that take + place outside of a specific database driver being selected, so if the + current driver that happens to be used for the provisioning operation + is an async driver, it will run in asyncio and not fail. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + if config.any_async: + return _async_util.run_in_greenlet(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async(fn, *args, **kwargs): + """Run a function in an asyncio loop if the current selected driver is + async. + + This function is used for test setup/teardown and tests themselves + where the current DB driver is known. + + + """ + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + is_async = config._current.is_async + + if is_async: + return _async_util.run_in_greenlet(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async_wrapper(fn): + """Apply the _maybe_async function to an existing function and return + as a wrapped callable, supporting generator functions as well. + + This is currently used for pytest fixtures that support generator use. + + """ + + if inspect.isgeneratorfunction(fn): + _stop = object() + + def call_next(gen): + try: + return next(gen) + # can't raise StopIteration in an awaitable. + except StopIteration: + return _stop + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + gen = fn(*args, **kwargs) + while True: + value = _maybe_async(call_next, gen) + if value is _stop: + break + yield value + + else: + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + return _maybe_async(fn, *args, **kwargs) + + return wrap_fixture diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/config.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/config.py new file mode 100644 index 0000000..e2623ea --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/config.py @@ -0,0 +1,427 @@ +# testing/config.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +from argparse import Namespace +import collections +import inspect +import typing +from typing import Any +from typing import Callable +from typing import Iterable +from typing import NoReturn +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union + +from . import mock +from . import requirements as _requirements +from .util import fail +from .. import util + +# default requirements; this is replaced by plugin_base when pytest +# is run +requirements = _requirements.SuiteRequirements() + +db = None +db_url = None +db_opts = None +file_config = None +test_schema = None +test_schema_2 = None +any_async = False +_current = None +ident = "main" +options: Namespace = None # type: ignore + +if typing.TYPE_CHECKING: + from .plugin.plugin_base import FixtureFunctions + + _fixture_functions: FixtureFunctions +else: + + class _NullFixtureFunctions: + def _null_decorator(self): + def go(fn): + return fn + + return go + + def skip_test_exception(self, *arg, **kw): + return Exception() + + @property + def add_to_marker(self): + return mock.Mock() + + def mark_base_test_class(self): + return self._null_decorator() + + def combinations(self, *arg_sets, **kw): + return self._null_decorator() + + def param_ident(self, *parameters): + return self._null_decorator() + + def fixture(self, *arg, **kw): + return self._null_decorator() + + def get_current_test_name(self): + return None + + def async_test(self, fn): + return fn + + # default fixture functions; these are replaced by plugin_base when + # pytest runs + _fixture_functions = _NullFixtureFunctions() + + +_FN = TypeVar("_FN", bound=Callable[..., Any]) + + +def combinations( + *comb: Union[Any, Tuple[Any, ...]], + argnames: Optional[str] = None, + id_: Optional[str] = None, + **kw: str, +) -> Callable[[_FN], _FN]: + r"""Deliver multiple versions of a test based on positional combinations. + + This is a facade over pytest.mark.parametrize. + + + :param \*comb: argument combinations. These are tuples that will be passed + positionally to the decorated function. + + :param argnames: optional list of argument names. These are the names + of the arguments in the test function that correspond to the entries + in each argument tuple. pytest.mark.parametrize requires this, however + the combinations function will derive it automatically if not present + by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the + first argument is "self" which is discarded. + + :param id\_: optional id template. This is a string template that + describes how the "id" for each parameter set should be defined, if any. + The number of characters in the template should match the number of + entries in each argument tuple. Each character describes how the + corresponding entry in the argument tuple should be handled, as far as + whether or not it is included in the arguments passed to the function, as + well as if it is included in the tokens used to create the id of the + parameter set. + + If omitted, the argument combinations are passed to parametrize as is. If + passed, each argument combination is turned into a pytest.param() object, + mapping the elements of the argument tuple to produce an id based on a + character value in the same position within the string template using the + following scheme:: + + i - the given argument is a string that is part of the id only, don't + pass it as an argument + + n - the given argument should be passed and it should be added to the + id by calling the .__name__ attribute + + r - the given argument should be passed and it should be added to the + id by calling repr() + + s - the given argument should be passed and it should be added to the + id by calling str() + + a - (argument) the given argument should be passed and it should not + be used to generated the id + + e.g.:: + + @testing.combinations( + (operator.eq, "eq"), + (operator.ne, "ne"), + (operator.gt, "gt"), + (operator.lt, "lt"), + id_="na" + ) + def test_operator(self, opfunc, name): + pass + + The above combination will call ``.__name__`` on the first member of + each tuple and use that as the "id" to pytest.param(). + + + """ + return _fixture_functions.combinations( + *comb, id_=id_, argnames=argnames, **kw + ) + + +def combinations_list(arg_iterable: Iterable[Tuple[Any, ...]], **kw): + "As combination, but takes a single iterable" + return combinations(*arg_iterable, **kw) + + +class Variation: + __slots__ = ("_name", "_argname") + + def __init__(self, case, argname, case_names): + self._name = case + self._argname = argname + for casename in case_names: + setattr(self, casename, casename == case) + + if typing.TYPE_CHECKING: + + def __getattr__(self, key: str) -> bool: ... + + @property + def name(self): + return self._name + + def __bool__(self): + return self._name == self._argname + + def __nonzero__(self): + return not self.__bool__() + + def __str__(self): + return f"{self._argname}={self._name!r}" + + def __repr__(self): + return str(self) + + def fail(self) -> NoReturn: + fail(f"Unknown {self}") + + @classmethod + def idfn(cls, variation): + return variation.name + + @classmethod + def generate_cases(cls, argname, cases): + case_names = [ + argname if c is True else "not_" + argname if c is False else c + for c in cases + ] + + typ = type( + argname, + (Variation,), + { + "__slots__": tuple(case_names), + }, + ) + + return [typ(casename, argname, case_names) for casename in case_names] + + +def variation(argname_or_fn, cases=None): + """a helper around testing.combinations that provides a single namespace + that can be used as a switch. + + e.g.:: + + @testing.variation("querytyp", ["select", "subquery", "legacy_query"]) + @testing.variation("lazy", ["select", "raise", "raise_on_sql"]) + def test_thing( + self, + querytyp, + lazy, + decl_base + ): + class Thing(decl_base): + __tablename__ = 'thing' + + # use name directly + rel = relationship("Rel", lazy=lazy.name) + + # use as a switch + if querytyp.select: + stmt = select(Thing) + elif querytyp.subquery: + stmt = select(Thing).subquery() + elif querytyp.legacy_query: + stmt = Session.query(Thing) + else: + querytyp.fail() + + + The variable provided is a slots object of boolean variables, as well + as the name of the case itself under the attribute ".name" + + """ + + if inspect.isfunction(argname_or_fn): + argname = argname_or_fn.__name__ + cases = argname_or_fn(None) + + @variation_fixture(argname, cases) + def go(self, request): + yield request.param + + return go + else: + argname = argname_or_fn + cases_plus_limitations = [ + ( + entry + if (isinstance(entry, tuple) and len(entry) == 2) + else (entry, None) + ) + for entry in cases + ] + + variations = Variation.generate_cases( + argname, [c for c, l in cases_plus_limitations] + ) + return combinations( + *[ + ( + (variation._name, variation, limitation) + if limitation is not None + else (variation._name, variation) + ) + for variation, (case, limitation) in zip( + variations, cases_plus_limitations + ) + ], + id_="ia", + argnames=argname, + ) + + +def variation_fixture(argname, cases, scope="function"): + return fixture( + params=Variation.generate_cases(argname, cases), + ids=Variation.idfn, + scope=scope, + ) + + +def fixture(*arg: Any, **kw: Any) -> Any: + return _fixture_functions.fixture(*arg, **kw) + + +def get_current_test_name() -> str: + return _fixture_functions.get_current_test_name() + + +def mark_base_test_class() -> Any: + return _fixture_functions.mark_base_test_class() + + +class _AddToMarker: + def __getattr__(self, attr: str) -> Any: + return getattr(_fixture_functions.add_to_marker, attr) + + +add_to_marker = _AddToMarker() + + +class Config: + def __init__(self, db, db_opts, options, file_config): + self._set_name(db) + self.db = db + self.db_opts = db_opts + self.options = options + self.file_config = file_config + self.test_schema = "test_schema" + self.test_schema_2 = "test_schema_2" + + self.is_async = db.dialect.is_async and not util.asbool( + db.url.query.get("async_fallback", False) + ) + + _stack = collections.deque() + _configs = set() + + def _set_name(self, db): + suffix = "_async" if db.dialect.is_async else "" + if db.dialect.server_version_info: + svi = ".".join(str(tok) for tok in db.dialect.server_version_info) + self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi) + else: + self.name = "%s+%s%s" % (db.name, db.driver, suffix) + + @classmethod + def register(cls, db, db_opts, options, file_config): + """add a config as one of the global configs. + + If there are no configs set up yet, this config also + gets set as the "_current". + """ + global any_async + + cfg = Config(db, db_opts, options, file_config) + + # if any backends include an async driver, then ensure + # all setup/teardown and tests are wrapped in the maybe_async() + # decorator that will set up a greenlet context for async drivers. + any_async = any_async or cfg.is_async + + cls._configs.add(cfg) + return cfg + + @classmethod + def set_as_current(cls, config, namespace): + global db, _current, db_url, test_schema, test_schema_2, db_opts + _current = config + db_url = config.db.url + db_opts = config.db_opts + test_schema = config.test_schema + test_schema_2 = config.test_schema_2 + namespace.db = db = config.db + + @classmethod + def push_engine(cls, db, namespace): + assert _current, "Can't push without a default Config set up" + cls.push( + Config( + db, _current.db_opts, _current.options, _current.file_config + ), + namespace, + ) + + @classmethod + def push(cls, config, namespace): + cls._stack.append(_current) + cls.set_as_current(config, namespace) + + @classmethod + def pop(cls, namespace): + if cls._stack: + # a failed test w/ -x option can call reset() ahead of time + _current = cls._stack[-1] + del cls._stack[-1] + cls.set_as_current(_current, namespace) + + @classmethod + def reset(cls, namespace): + if cls._stack: + cls.set_as_current(cls._stack[0], namespace) + cls._stack.clear() + + @classmethod + def all_configs(cls): + return cls._configs + + @classmethod + def all_dbs(cls): + for cfg in cls.all_configs(): + yield cfg.db + + def skip_test(self, msg): + skip_test(msg) + + +def skip_test(msg): + raise _fixture_functions.skip_test_exception(msg) + + +def async_test(fn): + return _fixture_functions.async_test(fn) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py new file mode 100644 index 0000000..7cae807 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py @@ -0,0 +1,472 @@ +# testing/engines.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import collections +import re +import typing +from typing import Any +from typing import Dict +from typing import Optional +import warnings +import weakref + +from . import config +from .util import decorator +from .util import gc_collect +from .. import event +from .. import pool +from ..util import await_only +from ..util.typing import Literal + + +if typing.TYPE_CHECKING: + from ..engine import Engine + from ..engine.url import URL + from ..ext.asyncio import AsyncEngine + + +class ConnectionKiller: + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = collections.defaultdict(set) + self.dbapi_connections = set() + + def add_pool(self, pool): + event.listen(pool, "checkout", self._add_conn) + event.listen(pool, "checkin", self._remove_conn) + event.listen(pool, "close", self._remove_conn) + event.listen(pool, "close_detached", self._remove_conn) + # note we are keeping "invalidated" here, as those are still + # opened connections we would like to roll back + + def _add_conn(self, dbapi_con, con_record, con_proxy): + self.dbapi_connections.add(dbapi_con) + self.proxy_refs[con_proxy] = True + + def _remove_conn(self, dbapi_conn, *arg): + self.dbapi_connections.discard(dbapi_conn) + + def add_engine(self, engine, scope): + self.add_pool(engine.pool) + + assert scope in ("class", "global", "function", "fixture") + self.testing_engines[scope].add(engine) + + def _safe(self, fn): + try: + fn() + except Exception as e: + warnings.warn( + "testing_reaper couldn't rollback/close connection: %s" % e + ) + + def rollback_all(self): + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self._safe(rec.rollback) + + def checkin_all(self): + # run pool.checkin() for all ConnectionFairy instances we have + # tracked. + + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self.dbapi_connections.discard(rec.dbapi_connection) + self._safe(rec._checkin) + + # for fairy refs that were GCed and could not close the connection, + # such as asyncio, roll back those remaining connections + for con in self.dbapi_connections: + self._safe(con.rollback) + self.dbapi_connections.clear() + + def close_all(self): + self.checkin_all() + + def prepare_for_drop_tables(self, connection): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + from . import provision + + provision.prepare_for_drop_tables(connection.engine.url, connection) + + def _drop_testing_engines(self, scope): + eng = self.testing_engines[scope] + for rec in list(eng): + for proxy_ref in list(self.proxy_refs): + if proxy_ref is not None and proxy_ref.is_valid: + if ( + proxy_ref._pool is not None + and proxy_ref._pool is rec.pool + ): + self._safe(proxy_ref._checkin) + + if hasattr(rec, "sync_engine"): + await_only(rec.dispose()) + else: + rec.dispose() + eng.clear() + + def after_test(self): + self._drop_testing_engines("function") + + def after_test_outside_fixtures(self, test): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + if test.__class__.__leave_connections_for_teardown__: + return + + self.checkin_all() + + # on PostgreSQL, this will test for any "idle in transaction" + # connections. useful to identify tests with unusual patterns + # that can't be cleaned up correctly. + from . import provision + + with config.db.connect() as conn: + provision.prepare_for_drop_tables(conn.engine.url, conn) + + def stop_test_class_inside_fixtures(self): + self.checkin_all() + self._drop_testing_engines("function") + self._drop_testing_engines("class") + + def stop_test_class_outside_fixtures(self): + # ensure no refs to checked out connections at all. + + if pool.base._strong_ref_connection_records: + gc_collect() + + if pool.base._strong_ref_connection_records: + ln = len(pool.base._strong_ref_connection_records) + pool.base._strong_ref_connection_records.clear() + assert ( + False + ), "%d connection recs not cleared after test suite" % (ln) + + def final_cleanup(self): + self.checkin_all() + for scope in self.testing_engines: + self._drop_testing_engines(scope) + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + + +testing_reaper = ConnectionKiller() + + +@decorator +def assert_conns_closed(fn, *args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + + +@decorator +def rollback_open_connections(fn, *args, **kw): + """Decorator that rolls back all open connections after fn execution.""" + + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + + +@decorator +def close_first(fn, *args, **kw): + """Decorator that closes all connections before fn execution.""" + + testing_reaper.checkin_all() + fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): + """Decorator that closes all connections after fn execution.""" + try: + fn(*args, **kw) + finally: + testing_reaper.checkin_all() + + +def all_dialects(exclude=None): + import sqlalchemy.dialects as d + + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr( + __import__("sqlalchemy.dialects.%s" % name).dialects, name + ) + yield mod.dialect() + + +class ReconnectFixture: + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + self.is_stopped = False + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + if self.is_stopped: + self._safe(conn.close) + curs = conn.cursor() # should fail on Oracle etc. + # should fail for everything that didn't fail + # above, connection is closed + curs.execute("select 1") + assert False, "simulated connect failure didn't work" + else: + self.connections.append(conn) + return conn + + def _safe(self, fn): + try: + fn() + except Exception as e: + warnings.warn("ReconnectFixture couldn't close connection: %s" % e) + + def shutdown(self, stop=False): + # TODO: this doesn't cover all cases + # as nicely as we'd like, namely MySQLdb. + # would need to implement R. Brewer's + # proxy server idea to get better + # coverage. + self.is_stopped = stop + for c in list(self.connections): + self._safe(c.close) + self.connections = [] + + def restart(self): + self.is_stopped = False + + +def reconnecting_engine(url=None, options=None): + url = url or config.db.url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options["module"] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + _dispose = engine.dispose + + def dispose(): + engine.dialect.dbapi.shutdown() + engine.dialect.dbapi.is_stopped = False + _dispose() + + engine.test_shutdown = engine.dialect.dbapi.shutdown + engine.test_restart = engine.dialect.dbapi.restart + engine.dispose = dispose + return engine + + +@typing.overload +def testing_engine( + url: Optional[URL] = None, + options: Optional[Dict[str, Any]] = None, + asyncio: Literal[False] = False, + transfer_staticpool: bool = False, +) -> Engine: ... + + +@typing.overload +def testing_engine( + url: Optional[URL] = None, + options: Optional[Dict[str, Any]] = None, + asyncio: Literal[True] = True, + transfer_staticpool: bool = False, +) -> AsyncEngine: ... + + +def testing_engine( + url=None, + options=None, + asyncio=False, + transfer_staticpool=False, + share_pool=False, + _sqlite_savepoint=False, +): + if asyncio: + assert not _sqlite_savepoint + from sqlalchemy.ext.asyncio import ( + create_async_engine as create_engine, + ) + else: + from sqlalchemy import create_engine + from sqlalchemy.engine.url import make_url + + if not options: + use_reaper = True + scope = "function" + sqlite_savepoint = False + else: + use_reaper = options.pop("use_reaper", True) + scope = options.pop("scope", "function") + sqlite_savepoint = options.pop("sqlite_savepoint", False) + + url = url or config.db.url + + url = make_url(url) + if options is None: + if config.db is None or url.drivername == config.db.url.drivername: + options = config.db_opts + else: + options = {} + elif config.db is not None and url.drivername == config.db.url.drivername: + default_opt = config.db_opts.copy() + default_opt.update(options) + + engine = create_engine(url, **options) + + if sqlite_savepoint and engine.name == "sqlite": + # apply SQLite savepoint workaround + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN") + + if transfer_staticpool: + from sqlalchemy.pool import StaticPool + + if config.db is not None and isinstance(config.db.pool, StaticPool): + use_reaper = False + engine.pool._transfer_from(config.db.pool) + elif share_pool: + engine.pool = config.db.pool + + if scope == "global": + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = ( + True # enable event blocks, helps with profiling + ) + + if ( + isinstance(engine.pool, pool.QueuePool) + and "pool" not in options + and "pool_timeout" not in options + and "max_overflow" not in options + ): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 + if use_reaper: + testing_reaper.add_engine(engine, scope) + + return engine + + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_mock_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + + def executor(sql, *a, **kw): + buffer.append(sql) + + def assert_sql(stmts): + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] + assert recv == stmts, recv + + def print_sql(): + d = engine.dialect + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_mock_engine(dialect_name + "://", executor) + assert not hasattr(engine, "mock") + engine.mock = buffer + engine.assert_sql = assert_sql + engine.print_sql = print_sql + return engine + + +class DBAPIProxyCursor: + """Proxy a DBAPI cursor. + + Tests can provide subclasses of this to intercept + DBAPI-level cursor operations. + + """ + + def __init__(self, engine, conn, *args, **kwargs): + self.engine = engine + self.connection = conn + self.cursor = conn.cursor(*args, **kwargs) + + def execute(self, stmt, parameters=None, **kw): + if parameters: + return self.cursor.execute(stmt, parameters, **kw) + else: + return self.cursor.execute(stmt, **kw) + + def executemany(self, stmt, params, **kw): + return self.cursor.executemany(stmt, params, **kw) + + def __iter__(self): + return iter(self.cursor) + + def __getattr__(self, key): + return getattr(self.cursor, key) + + +class DBAPIProxyConnection: + """Proxy a DBAPI connection. + + Tests can provide subclasses of this to intercept + DBAPI-level connection operations. + + """ + + def __init__(self, engine, conn, cursor_cls): + self.conn = conn + self.engine = engine + self.cursor_cls = cursor_cls + + def cursor(self, *args, **kwargs): + return self.cursor_cls(self.engine, self.conn, *args, **kwargs) + + def close(self): + self.conn.close() + + def __getattr__(self, key): + return getattr(self.conn, key) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/entities.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/entities.py new file mode 100644 index 0000000..8f0f36b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/entities.py @@ -0,0 +1,117 @@ +# testing/entities.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import sqlalchemy as sa +from .. import exc as sa_exc +from ..orm.writeonly import WriteOnlyCollection + +_repr_stack = set() + + +class BasicEntity: + def __init__(self, **kw): + for key, value in kw.items(): + setattr(self, key, value) + + def __repr__(self): + if id(self) in _repr_stack: + return object.__repr__(self) + _repr_stack.add(id(self)) + try: + return "%s(%s)" % ( + (self.__class__.__name__), + ", ".join( + [ + "%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith("_") + ] + ), + ) + finally: + _repr_stack.remove(id(self)) + + +_recursion_stack = set() + + +class ComparableMixin: + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'Deep, sparse compare. + + Deeply compare two entities, following the non-None attributes of the + non-persisted object, if possible. + + """ + if other is self: + return True + elif not self.__class__ == other.__class__: + return False + + if id(self) in _recursion_stack: + return True + _recursion_stack.add(id(self)) + + try: + # pick the entity that's not SA persisted as the source + try: + self_key = sa.orm.attributes.instance_state(self).key + except sa.orm.exc.NO_STATE: + self_key = None + + if other is None: + a = self + b = other + elif self_key is not None: + a = other + b = self + else: + a = self + b = other + + for attr in list(a.__dict__): + if attr.startswith("_"): + continue + + value = getattr(a, attr) + + if isinstance(value, WriteOnlyCollection): + continue + + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, "__iter__") and not isinstance(value, str): + if hasattr(value, "__getitem__") and not hasattr( + value, "keys" + ): + if list(value) != list(battr): + return False + else: + if set(value) != set(battr): + return False + else: + if value is not None and value != battr: + return False + return True + finally: + _recursion_stack.remove(id(self)) + + +class ComparableEntity(ComparableMixin, BasicEntity): + def __hash__(self): + return hash(self.__class__) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py new file mode 100644 index 0000000..addc4b7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py @@ -0,0 +1,435 @@ +# testing/exclusions.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import contextlib +import operator +import re +import sys + +from . import config +from .. import util +from ..util import decorator +from ..util.compat import inspect_getfullargspec + + +def skip_if(predicate, reason=None): + rule = compound() + pred = _as_predicate(predicate, reason) + rule.skips.add(pred) + return rule + + +def fails_if(predicate, reason=None): + rule = compound() + pred = _as_predicate(predicate, reason) + rule.fails.add(pred) + return rule + + +class compound: + def __init__(self): + self.fails = set() + self.skips = set() + + def __add__(self, other): + return self.add(other) + + def as_skips(self): + rule = compound() + rule.skips.update(self.skips) + rule.skips.update(self.fails) + return rule + + def add(self, *others): + copy = compound() + copy.fails.update(self.fails) + copy.skips.update(self.skips) + + for other in others: + copy.fails.update(other.fails) + copy.skips.update(other.skips) + return copy + + def not_(self): + copy = compound() + copy.fails.update(NotPredicate(fail) for fail in self.fails) + copy.skips.update(NotPredicate(skip) for skip in self.skips) + return copy + + @property + def enabled(self): + return self.enabled_for_config(config._current) + + def enabled_for_config(self, config): + for predicate in self.skips.union(self.fails): + if predicate(config): + return False + else: + return True + + def matching_config_reasons(self, config): + return [ + predicate._as_string(config) + for predicate in self.skips.union(self.fails) + if predicate(config) + ] + + def _extend(self, other): + self.skips.update(other.skips) + self.fails.update(other.fails) + + def __call__(self, fn): + if hasattr(fn, "_sa_exclusion_extend"): + fn._sa_exclusion_extend._extend(self) + return fn + + @decorator + def decorate(fn, *args, **kw): + return self._do(config._current, fn, *args, **kw) + + decorated = decorate(fn) + decorated._sa_exclusion_extend = self + return decorated + + @contextlib.contextmanager + def fail_if(self): + all_fails = compound() + all_fails.fails.update(self.skips.union(self.fails)) + + try: + yield + except Exception as ex: + all_fails._expect_failure(config._current, ex) + else: + all_fails._expect_success(config._current) + + def _do(self, cfg, fn, *args, **kw): + for skip in self.skips: + if skip(cfg): + msg = "'%s' : %s" % ( + config.get_current_test_name(), + skip._as_string(cfg), + ) + config.skip_test(msg) + + try: + return_value = fn(*args, **kw) + except Exception as ex: + self._expect_failure(cfg, ex, name=fn.__name__) + else: + self._expect_success(cfg, name=fn.__name__) + return return_value + + def _expect_failure(self, config, ex, name="block"): + for fail in self.fails: + if fail(config): + print( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), ex) + ) + break + else: + raise ex.with_traceback(sys.exc_info()[2]) + + def _expect_success(self, config, name="block"): + if not self.fails: + return + + for fail in self.fails: + if fail(config): + raise AssertionError( + "Unexpected success for '%s' (%s)" + % ( + name, + " and ".join( + fail._as_string(config) for fail in self.fails + ), + ) + ) + + +def only_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return skip_if(NotPredicate(predicate), reason) + + +def succeeds_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return fails_if(NotPredicate(predicate), reason) + + +class Predicate: + @classmethod + def as_predicate(cls, predicate, description=None): + if isinstance(predicate, compound): + return cls.as_predicate(predicate.enabled_for_config, description) + elif isinstance(predicate, Predicate): + if description and predicate.description is None: + predicate.description = description + return predicate + elif isinstance(predicate, (list, set)): + return OrPredicate( + [cls.as_predicate(pred) for pred in predicate], description + ) + elif isinstance(predicate, tuple): + return SpecPredicate(*predicate) + elif isinstance(predicate, str): + tokens = re.match( + r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate + ) + if not tokens: + raise ValueError( + "Couldn't locate DB name in predicate: %r" % predicate + ) + db = tokens.group(1) + op = tokens.group(2) + spec = ( + tuple(int(d) for d in tokens.group(3).split(".")) + if tokens.group(3) + else None + ) + + return SpecPredicate(db, op, spec, description=description) + elif callable(predicate): + return LambdaPredicate(predicate, description) + else: + assert False, "unknown predicate type: %s" % predicate + + def _format_description(self, config, negate=False): + bool_ = self(config) + if negate: + bool_ = not negate + return self.description % { + "driver": ( + config.db.url.get_driver_name() if config else "<no driver>" + ), + "database": ( + config.db.url.get_backend_name() if config else "<no database>" + ), + "doesnt_support": "doesn't support" if bool_ else "does support", + "does_support": "does support" if bool_ else "doesn't support", + } + + def _as_string(self, config=None, negate=False): + raise NotImplementedError() + + +class BooleanPredicate(Predicate): + def __init__(self, value, description=None): + self.value = value + self.description = description or "boolean %s" % value + + def __call__(self, config): + return self.value + + def _as_string(self, config, negate=False): + return self._format_description(config, negate=negate) + + +class SpecPredicate(Predicate): + def __init__(self, db, op=None, spec=None, description=None): + self.db = db + self.op = op + self.spec = spec + self.description = description + + _ops = { + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], + } + + def __call__(self, config): + if config is None: + return False + + engine = config.db + + if "+" in self.db: + dialect, driver = self.db.split("+") + else: + dialect, driver = self.db, None + + if dialect and engine.name != dialect: + return False + if driver is not None and engine.driver != driver: + return False + + if self.op is not None: + assert driver is None, "DBAPI version specs not supported yet" + + version = _server_version(engine) + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) + return oper(version, self.spec) + else: + return True + + def _as_string(self, config, negate=False): + if self.description is not None: + return self._format_description(config) + elif self.op is None: + if negate: + return "not %s" % self.db + else: + return "%s" % self.db + else: + if negate: + return "not %s %s %s" % (self.db, self.op, self.spec) + else: + return "%s %s %s" % (self.db, self.op, self.spec) + + +class LambdaPredicate(Predicate): + def __init__(self, lambda_, description=None, args=None, kw=None): + spec = inspect_getfullargspec(lambda_) + if not spec[0]: + self.lambda_ = lambda db: lambda_() + else: + self.lambda_ = lambda_ + self.args = args or () + self.kw = kw or {} + if description: + self.description = description + elif lambda_.__doc__: + self.description = lambda_.__doc__ + else: + self.description = "custom function" + + def __call__(self, config): + return self.lambda_(config) + + def _as_string(self, config, negate=False): + return self._format_description(config) + + +class NotPredicate(Predicate): + def __init__(self, predicate, description=None): + self.predicate = predicate + self.description = description + + def __call__(self, config): + return not self.predicate(config) + + def _as_string(self, config, negate=False): + if self.description: + return self._format_description(config, not negate) + else: + return self.predicate._as_string(config, not negate) + + +class OrPredicate(Predicate): + def __init__(self, predicates, description=None): + self.predicates = predicates + self.description = description + + def __call__(self, config): + for pred in self.predicates: + if pred(config): + return True + return False + + def _eval_str(self, config, negate=False): + if negate: + conjunction = " and " + else: + conjunction = " or " + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) + + def _negation_str(self, config): + if self.description is not None: + return "Not " + self._format_description(config) + else: + return self._eval_str(config, negate=True) + + def _as_string(self, config, negate=False): + if negate: + return self._negation_str(config) + else: + if self.description is not None: + return self._format_description(config) + else: + return self._eval_str(config) + + +_as_predicate = Predicate.as_predicate + + +def _is_excluded(db, op, spec): + return SpecPredicate(db, op, spec)(config._current) + + +def _server_version(engine): + """Return a server_version_info tuple.""" + + # force metadata to be retrieved + conn = engine.connect() + version = getattr(engine.dialect, "server_version_info", None) + if version is None: + version = () + conn.close() + return version + + +def db_spec(*dbs): + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) + + +def open(): # noqa + return skip_if(BooleanPredicate(False, "mark as execute")) + + +def closed(): + return skip_if(BooleanPredicate(True, "marked as skip")) + + +def fails(reason=None): + return fails_if(BooleanPredicate(True, reason or "expected to fail")) + + +def future(): + return fails_if(BooleanPredicate(True, "Future feature")) + + +def fails_on(db, reason=None): + return fails_if(db, reason) + + +def fails_on_everything_except(*dbs): + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) + + +def skip(db, reason=None): + return skip_if(db, reason) + + +def only_on(dbs, reason=None): + return only_if( + OrPredicate( + [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] + ) + ) + + +def exclude(db, op, spec, reason=None): + return skip_if(SpecPredicate(db, op, spec), reason) + + +def against(config, *queries): + assert queries, "no queries sent!" + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py new file mode 100644 index 0000000..5981fb5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py @@ -0,0 +1,28 @@ +# testing/fixtures/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from .base import FutureEngineMixin as FutureEngineMixin +from .base import TestBase as TestBase +from .mypy import MypyTest as MypyTest +from .orm import after_test as after_test +from .orm import close_all_sessions as close_all_sessions +from .orm import DeclarativeMappedTest as DeclarativeMappedTest +from .orm import fixture_session as fixture_session +from .orm import MappedTest as MappedTest +from .orm import ORMTest as ORMTest +from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally +from .orm import ( + stop_test_class_inside_fixtures as stop_test_class_inside_fixtures, +) +from .sql import CacheKeyFixture as CacheKeyFixture +from .sql import ( + ComputedReflectionFixtureTest as ComputedReflectionFixtureTest, +) +from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture +from .sql import NoCache as NoCache +from .sql import RemovesEvents as RemovesEvents +from .sql import TablesTest as TablesTest diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0f95b6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ccff61f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..278a5f6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..71d2d9a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bc870cb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py new file mode 100644 index 0000000..0697f49 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py @@ -0,0 +1,366 @@ +# testing/fixtures/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import sqlalchemy as sa +from .. import assertions +from .. import config +from ..assertions import eq_ +from ..util import drop_all_tables_from_metadata +from ... import Column +from ... import func +from ... import Integer +from ... import select +from ... import Table +from ...orm import DeclarativeBase +from ...orm import MappedAsDataclass +from ...orm import registry + + +@config.mark_base_test_class() +class TestBase: + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + + def assert_(self, val, msg=None): + assert val, msg + + @config.fixture() + def nocache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + @config.fixture() + def connection_no_trans(self): + eng = getattr(self, "bind", None) or config.db + + with eng.connect() as conn: + yield conn + + @config.fixture() + def connection(self): + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db + + conn = eng.connect() + trans = conn.begin() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() + + @config.fixture() + def close_result_when_finished(self): + to_close = [] + to_consume = [] + + def go(result, consume=False): + to_close.append(result) + if consume: + to_consume.append(result) + + yield go + for r in to_consume: + try: + r.all() + except: + pass + for r in to_close: + try: + r.close() + except: + pass + + @config.fixture() + def registry(self, metadata): + reg = registry( + metadata=metadata, + type_annotation_map={ + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + }, + ) + yield reg + reg.dispose() + + @config.fixture + def decl_base(self, metadata): + _md = metadata + + class Base(DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } + + yield Base + Base.registry.dispose() + + @config.fixture + def dc_decl_base(self, metadata): + _md = metadata + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + + @config.fixture() + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection + + @config.fixture() + def future_engine(self): + yield + + @config.fixture() + def testing_engine(self): + from .. import engines + + def gen_testing_engine( + url=None, + options=None, + future=None, + asyncio=False, + transfer_staticpool=False, + share_pool=False, + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, + options=options, + asyncio=asyncio, + transfer_staticpool=transfer_staticpool, + share_pool=share_pool, + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") + + @config.fixture() + def async_testing_engine(self, testing_engine): + def go(**kw): + kw["asyncio"] = True + return testing_engine(**kw) + + return go + + @config.fixture() + def metadata(self, request): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from ...sql import schema + + metadata = schema.MetaData() + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) + + @config.fixture( + params=[ + (rollback, second_operation, begin_nested) + for rollback in (True, False) + for second_operation in ("none", "execute", "begin") + for begin_nested in ( + True, + False, + ) + ] + ) + def trans_ctx_manager_fixture(self, request, metadata): + rollback, second_operation, begin_nested = request.param + + t = Table("test", metadata, Column("data", Integer)) + eng = getattr(self, "bind", None) or config.db + + t.create(eng) + + def run_test(subject, trans_on_subject, execute_on_subject): + with subject.begin() as trans: + if begin_nested: + if not config.requirements.savepoints.enabled: + config.skip_test("savepoints not enabled") + if execute_on_subject: + nested_trans = subject.begin_nested() + else: + nested_trans = trans.begin_nested() + + with nested_trans: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + # for nested trans, we always commit/rollback on the + # "nested trans" object itself. + # only Session(future=False) will affect savepoint + # transaction for session.commit/rollback + + if rollback: + nested_trans.rollback() + else: + nested_trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction " + "inside context " + "manager. Please complete the context " + "manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute( + t.insert(), {"data": 12} + ) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + # outside the nested trans block, but still inside the + # transaction block, we can run SQL, and it will be + # committed + if execute_on_subject: + subject.execute(t.insert(), {"data": 14}) + else: + trans.execute(t.insert(), {"data": 14}) + + else: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + if trans_on_subject: + if rollback: + subject.rollback() + else: + subject.commit() + else: + if rollback: + trans.rollback() + else: + trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction inside " + "context " + "manager. Please complete the context manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute(t.insert(), {"data": 12}) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if hasattr(trans, "begin"): + trans.begin() + else: + subject.begin() + elif second_operation == "begin_nested": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + expected_committed = 0 + if begin_nested: + # begin_nested variant, we inserted a row after the nested + # block + expected_committed += 1 + if not rollback: + # not rollback variant, our row inserted in the target + # block itself would be committed + expected_committed += 1 + + if execute_on_subject: + eq_( + subject.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + else: + with subject.connect() as conn: + eq_( + conn.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + + return run_test + + +_connection_fixture_connection = None + + +class FutureEngineMixin: + """alembic's suite still using this""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py new file mode 100644 index 0000000..149df9f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py @@ -0,0 +1,312 @@ +# testing/fixtures/mypy.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import inspect +import os +from pathlib import Path +import re +import shutil +import sys +import tempfile + +from .base import TestBase +from .. import config +from ..assertions import eq_ +from ... import util + + +@config.add_to_marker.mypy +class MypyTest(TestBase): + __requires__ = ("no_sqlalchemy2_stubs",) + + @config.fixture(scope="function") + def per_func_cachedir(self): + yield from self._cachedir() + + @config.fixture(scope="class") + def cachedir(self): + yield from self._cachedir() + + def _cachedir(self): + # as of mypy 0.971 i think we need to keep mypy_path empty + mypy_path = "" + + with tempfile.TemporaryDirectory() as cachedir: + with open( + Path(cachedir) / "sqla_mypy_config.cfg", "w" + ) as config_file: + config_file.write( + f""" + [mypy]\n + plugins = sqlalchemy.ext.mypy.plugin\n + show_error_codes = True\n + {mypy_path} + disable_error_code = no-untyped-call + + [mypy-sqlalchemy.*] + ignore_errors = True + + """ + ) + with open( + Path(cachedir) / "plain_mypy_config.cfg", "w" + ) as config_file: + config_file.write( + f""" + [mypy]\n + show_error_codes = True\n + {mypy_path} + disable_error_code = var-annotated,no-untyped-call + [mypy-sqlalchemy.*] + ignore_errors = True + + """ + ) + yield cachedir + + @config.fixture() + def mypy_runner(self, cachedir): + from mypy import api + + def run(path, use_plugin=False, use_cachedir=None): + if use_cachedir is None: + use_cachedir = cachedir + args = [ + "--strict", + "--raise-exceptions", + "--cache-dir", + use_cachedir, + "--config-file", + os.path.join( + use_cachedir, + ( + "sqla_mypy_config.cfg" + if use_plugin + else "plain_mypy_config.cfg" + ), + ), + ] + + # mypy as of 0.990 is more aggressively blocking messaging + # for paths that are in sys.path, and as pytest puts currdir, + # test/ etc in sys.path, just copy the source file to the + # tempdir we are working in so that we don't have to try to + # manipulate sys.path and/or guess what mypy is doing + filename = os.path.basename(path) + test_program = os.path.join(use_cachedir, filename) + if path != test_program: + shutil.copyfile(path, test_program) + args.append(test_program) + + # I set this locally but for the suite here needs to be + # disabled + os.environ.pop("MYPY_FORCE_COLOR", None) + + stdout, stderr, exitcode = api.run(args) + return stdout, stderr, exitcode + + return run + + @config.fixture + def mypy_typecheck_file(self, mypy_runner): + def run(path, use_plugin=False): + expected_messages = self._collect_messages(path) + stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin) + self._check_output( + path, expected_messages, stdout, stderr, exitcode + ) + + return run + + @staticmethod + def file_combinations(dirname): + if os.path.isabs(dirname): + path = dirname + else: + caller_path = inspect.stack()[1].filename + path = os.path.join(os.path.dirname(caller_path), dirname) + files = list(Path(path).glob("**/*.py")) + + for extra_dir in config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + files.extend((Path(extra_dir) / dirname).glob("**/*.py")) + return files + + def _collect_messages(self, path): + from sqlalchemy.ext.mypy.util import mypy_14 + + expected_messages = [] + expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") + py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") + with open(path) as file_: + current_assert_messages = [] + for num, line in enumerate(file_, 1): + m = py_ver_re.match(line) + if m: + major, _, minor = m.group(1).partition(".") + if sys.version_info < (int(major), int(minor)): + config.skip_test( + "Requires python >= %s" % (m.group(1)) + ) + continue + + m = expected_re.match(line) + if m: + is_mypy = bool(m.group(1)) + is_re = bool(m.group(2)) + is_type = bool(m.group(3)) + + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) + if is_type: + if not is_re: + # the goal here is that we can cut-and-paste + # from vscode -> pylance into the + # EXPECTED_TYPE: line, then the test suite will + # validate that line against what mypy produces + expected_msg = re.sub( + r"([\[\]])", + lambda m: rf"\{m.group(0)}", + expected_msg, + ) + + # note making sure preceding text matches + # with a dot, so that an expect for "Select" + # does not match "TypedSelect" + expected_msg = re.sub( + r"([\w_]+)", + lambda m: rf"(?:.*\.)?{m.group(1)}\*?", + expected_msg, + ) + + expected_msg = re.sub( + "List", "builtins.list", expected_msg + ) + + expected_msg = re.sub( + r"\b(int|str|float|bool)\b", + lambda m: rf"builtins.{m.group(0)}\*?", + expected_msg, + ) + # expected_msg = re.sub( + # r"(Sequence|Tuple|List|Union)", + # lambda m: fr"typing.{m.group(0)}\*?", + # expected_msg, + # ) + + is_mypy = is_re = True + expected_msg = f'Revealed type is "{expected_msg}"' + + if mypy_14 and util.py39: + # use_lowercase_names, py39 and above + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501 + + # skip first character which could be capitalized + # "List item x not found" type of message + expected_msg = expected_msg[0] + re.sub( + ( + r"\b(List|Tuple|Dict|Set)\b" + if is_type + else r"\b(List|Tuple|Dict|Set|Type)\b" + ), + lambda m: m.group(1).lower(), + expected_msg[1:], + ) + + if mypy_14 and util.py310: + # use_or_syntax, py310 and above + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501 + expected_msg = re.sub( + r"Optional\[(.*?)\]", + lambda m: f"{m.group(1)} | None", + expected_msg, + ) + current_assert_messages.append( + (is_mypy, is_re, expected_msg.strip()) + ) + elif current_assert_messages: + expected_messages.extend( + (num, is_mypy, is_re, expected_msg) + for ( + is_mypy, + is_re, + expected_msg, + ) in current_assert_messages + ) + current_assert_messages[:] = [] + + return expected_messages + + def _check_output(self, path, expected_messages, stdout, stderr, exitcode): + not_located = [] + filename = os.path.basename(path) + if expected_messages: + # mypy 0.990 changed how return codes work, so don't assume a + # 1 or a 0 return code here, could be either depending on if + # errors were generated or not + + output = [] + + raw_lines = stdout.split("\n") + while raw_lines: + e = raw_lines.pop(0) + if re.match(r".+\.py:\d+: error: .*", e): + output.append(("error", e)) + elif re.match( + r".+\.py:\d+: note: +(?:Possible overload|def ).*", e + ): + while raw_lines: + ol = raw_lines.pop(0) + if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + break + elif re.match( + r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I + ): + pass + elif re.match(r".+\.py:\d+: note: .*", e): + output.append(("note", e)) + + for num, is_mypy, is_re, msg in expected_messages: + msg = msg.replace("'", '"') + prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" + for idx, (typ, errmsg) in enumerate(output): + if is_re: + if re.match( + rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}", + errmsg, + ): + break + elif ( + f"{filename}:{num}: {typ}: {prefix}{msg}" + in errmsg.replace("'", '"') + ): + break + else: + not_located.append(msg) + continue + del output[idx] + + if not_located: + missing = "\n".join(not_located) + print("Couldn't locate expected messages:", missing, sep="\n") + if output: + extra = "\n".join(msg for _, msg in output) + print("Remaining messages:", extra, sep="\n") + assert False, "expected messages not found, see stdout" + + if output: + print(f"{len(output)} messages from mypy were not consumed:") + print("\n".join(msg for _, msg in output)) + assert False, "errors and/or notes remain, see stdout" + + else: + if exitcode != 0: + print(stdout, stderr, sep="\n") + + eq_(exitcode, 0, msg=stdout) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py new file mode 100644 index 0000000..5ddd21e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py @@ -0,0 +1,227 @@ +# testing/fixtures/orm.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any + +import sqlalchemy as sa +from .base import TestBase +from .sql import TablesTest +from .. import assertions +from .. import config +from .. import schema +from ..entities import BasicEntity +from ..entities import ComparableEntity +from ..util import adict +from ... import orm +from ...orm import DeclarativeBase +from ...orm import events as orm_events +from ...orm import registry + + +class ORMTest(TestBase): + @config.fixture + def fixture_session(self): + return fixture_session() + + +class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = "once" + + # 'once', 'each', None + run_setup_mappers = "each" + + classes: Any = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + yield + + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_classes() + self._setup_each_mappers() + self._setup_each_inserts() + + yield + + orm.session.close_all_sessions() + self._teardown_each_mappers() + self._teardown_each_classes() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == "once": + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == "once": + cls.mapper_registry, cls.mapper = cls._generate_registry() + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers != "once": + ( + self.__class__.mapper_registry, + self.__class__.mapper, + ) = self._generate_registry() + + if self.run_setup_mappers == "each": + self._with_register_classes(self.setup_mappers) + + def _setup_each_classes(self): + if self.run_setup_classes == "each": + self._with_register_classes(self.setup_classes) + + @classmethod + def _generate_registry(cls): + decl = registry(metadata=cls._tables_metadata) + return decl, decl.map_imperatively + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + + class _Base: + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() + + class Basic(BasicEntity, _Base): + pass + + class Comparable(ComparableEntity, _Base): + pass + + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != "once": + orm.clear_mappers() + + def _teardown_each_classes(self): + if self.run_setup_classes != "once": + self.classes.clear() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = "once" + run_setup_mappers = "once" + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + + class _DeclBase(DeclarativeBase): + __table_cls__ = schema.Table + metadata = cls._tables_metadata + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } + + def __init_subclass__(cls, **kw) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__(**kw) + + cls.DeclarativeBasic = _DeclBase + + # sets up cls.Basic which is helpful for things like composite + # classes + super()._with_register_classes(fn) + + if cls._tables_metadata.tables and cls.run_create_tables: + cls._tables_metadata.create_all(config.db) + + +class RemoveORMEventsGlobally: + @config.fixture(autouse=True) + def _remove_listeners(self): + yield + orm_events.MapperEvents._clear() + orm_events.InstanceEvents._clear() + orm_events.SessionEvents._clear() + orm_events.InstrumentationEvents._clear() + orm_events.QueryEvents._clear() + + +_fixture_sessions = set() + + +def fixture_session(**kw): + kw.setdefault("autoflush", True) + kw.setdefault("expire_on_commit", True) + + bind = kw.pop("bind", config.db) + + sess = orm.Session(bind, **kw) + _fixture_sessions.add(sess) + return sess + + +def close_all_sessions(): + # will close all still-referenced sessions + orm.close_all_sessions() + _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): + close_all_sessions() + orm.clear_mappers() + + +def after_test(): + if _fixture_sessions: + close_all_sessions() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py new file mode 100644 index 0000000..830fa27 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py @@ -0,0 +1,493 @@ +# testing/fixtures/sql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +import itertools +import random +import re +import sys + +import sqlalchemy as sa +from .base import TestBase +from .. import config +from .. import mock +from ..assertions import eq_ +from ..assertions import ne_ +from ..util import adict +from ..util import drop_all_tables_from_metadata +from ... import event +from ... import util +from ...schema import sort_tables_and_constraints +from ...sql import visitors +from ...sql.elements import ClauseElement + + +class TablesTest(TestBase): + # 'once', None + run_setup_bind = "once" + + # 'once', 'each', None + run_define_tables = "once" + + # 'once', 'each', None + run_create_tables = "once" + + # 'once', 'each', None + run_inserts = "each" + + # 'each', None + run_deletes = "each" + + # 'once', None + run_dispose_bind = None + + bind = None + _tables_metadata = None + tables = None + other = None + sequences = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + + @classmethod + def _init_class(cls): + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) + + cls.other = adict() + cls.tables = adict() + cls.sequences = adict() + + cls.bind = cls.setup_bind() + cls._tables_metadata = sa.MetaData() + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == "once": + cls._load_fixtures() + with cls.bind.begin() as conn: + cls.insert_data(conn) + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == "once": + cls.define_tables(cls._tables_metadata) + if cls.run_create_tables == "once": + cls._tables_metadata.create_all(cls.bind) + cls.tables.update(cls._tables_metadata.tables) + cls.sequences.update(cls._tables_metadata._sequences) + + def _setup_each_tables(self): + if self.run_define_tables == "each": + self.define_tables(self._tables_metadata) + if self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + self.tables.update(self._tables_metadata.tables) + self.sequences.update(self._tables_metadata._sequences) + elif self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == "each": + self._load_fixtures() + with self.bind.begin() as conn: + self.insert_data(conn) + + def _teardown_each_tables(self): + if self.run_define_tables == "each": + self.tables.clear() + if self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + self._tables_metadata.clear() + elif self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + + savepoints = getattr(config.requirements, "savepoints", False) + if savepoints: + savepoints = savepoints.enabled + + # no need to run deletes if tables are recreated on setup + if ( + self.run_define_tables != "each" + and self.run_create_tables != "each" + and self.run_deletes == "each" + ): + with self.bind.begin() as conn: + for table in reversed( + [ + t + for (t, fks) in sort_tables_and_constraints( + self._tables_metadata.tables.values() + ) + if t is not None + ] + ): + try: + if savepoints: + with conn.begin_nested(): + conn.execute(table.delete()) + else: + conn.execute(table.delete()) + except sa.exc.DBAPIError as ex: + print( + ("Error emptying table %s: %r" % (table, ex)), + file=sys.stderr, + ) + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) + + if cls.run_dispose_bind == "once": + cls.dispose_bind(cls.bind) + + cls._tables_metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, "dispose"): + bind.dispose() + elif hasattr(bind, "close"): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls, connection): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements): + self.assert_sql(self.bind, callable_, statements) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().items(): + if len(data) < 2: + continue + if isinstance(table, str): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table, fks in sort_tables_and_constraints( + cls._tables_metadata.tables.values() + ): + if table is None: + continue + if table not in headers: + continue + with cls.bind.begin() as conn: + conn.execute( + table.insert(), + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + + +class NoCache: + @config.fixture(autouse=True, scope="function") + def _disable_cache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + +class RemovesEvents: + @util.memoized_property + def _event_fns(self): + return set() + + def event_listen(self, target, name, fn, **kw): + self._event_fns.add((target, name, fn)) + event.listen(target, name, fn, **kw) + + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield + for key in self._event_fns: + event.remove(*key) + + +class ComputedReflectionFixtureTest(TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("computed_columns", "table_reflection") + + regexp = re.compile(r"[\[\]\(\)\s`'\"]*") + + def normalize(self, text): + return self.regexp.sub("", text).lower() + + @classmethod + def define_tables(cls, metadata): + from ... import Integer + from ... import testing + from ...schema import Column + from ...schema import Computed + from ...schema import Table + + Table( + "computed_default_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_col", Integer, Computed("normal + 42")), + Column("with_default", Integer, server_default="42"), + ) + + t = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal + 42")), + ) + + if testing.requires.schemas.enabled: + t2 = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal / 42")), + schema=config.test_schema, + ) + + if testing.requires.computed_columns_virtual.enabled: + t.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal + 2", persisted=False), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal / 2", persisted=False), + ) + ) + if testing.requires.computed_columns_stored.enabled: + t.append_column( + Column( + "computed_stored", + Integer, + Computed("normal - 42", persisted=True), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_stored", + Integer, + Computed("normal * 42", persisted=True), + ) + ) + + +class CacheKeyFixture: + def _compare_equal(self, a, b, compare_values): + a_key = a._generate_cache_key() + b_key = b._generate_cache_key() + + if a_key is None: + assert a._annotations.get("nocache") + + assert b_key is None + else: + eq_(a_key.key, b_key.key) + eq_(hash(a_key.key), hash(b_key.key)) + + for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): + assert a_param.compare(b_param, compare_values=compare_values) + return a_key, b_key + + def _run_cache_key_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key, b_key = self._compare_equal( + case_a[a], case_b[b], compare_values + ) + if a_key is None: + continue + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key is None or b_key is None: + if a_key is None: + assert case_a[a]._annotations.get("nocache") + if b_key is None: + assert case_b[b]._annotations.get("nocache") + continue + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare( + b_param, compare_values=compare_values + ): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. + ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters that aren't marked + # as "literal execute" + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + + for elem in visitors.iterate(case_a[a]): + if elem.__visit_name__ == "bindparam": + assert_a_params.append(elem) + + for elem in visitors.iterate(case_b[b]): + if elem.__visit_name__ == "bindparam": + assert_b_params.append(elem) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + def _run_cache_key_equal_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + self._compare_equal(case_a[a], case_b[b], compare_values) + + +def insertmanyvalues_fixture( + connection, randomize_rows=False, warn_on_downgraded=False +): + dialect = connection.dialect + orig_dialect = dialect._deliver_insertmanyvalues_batches + orig_conn = connection._exec_insertmany_context + + class RandomCursor: + __slots__ = ("cursor",) + + def __init__(self, cursor): + self.cursor = cursor + + # only this method is called by the deliver method. + # by not having the other methods we assert that those aren't being + # used + + @property + def description(self): + return self.cursor.description + + def fetchall(self): + rows = self.cursor.fetchall() + rows = list(rows) + random.shuffle(rows) + return rows + + def _deliver_insertmanyvalues_batches( + cursor, statement, parameters, generic_setinputsizes, context + ): + if randomize_rows: + cursor = RandomCursor(cursor) + for batch in orig_dialect( + cursor, statement, parameters, generic_setinputsizes, context + ): + if warn_on_downgraded and batch.is_downgraded: + util.warn("Batches were downgraded for sorted INSERT") + + yield batch + + def _exec_insertmany_context(dialect, context): + with mock.patch.object( + dialect, + "_deliver_insertmanyvalues_batches", + new=_deliver_insertmanyvalues_batches, + ): + return orig_conn(dialect, context) + + connection._exec_insertmany_context = _exec_insertmany_context diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/pickleable.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/pickleable.py new file mode 100644 index 0000000..761891a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/pickleable.py @@ -0,0 +1,155 @@ +# testing/pickleable.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Classes used in pickling tests, need to be at the module level for +unpickling. +""" + +from __future__ import annotations + +from .entities import ComparableEntity +from ..schema import Column +from ..types import String + + +class User(ComparableEntity): + pass + + +class Order(ComparableEntity): + pass + + +class Dingaling(ComparableEntity): + pass + + +class EmailUser(User): + pass + + +class Address(ComparableEntity): + pass + + +# TODO: these are kind of arbitrary.... +class Child1(ComparableEntity): + pass + + +class Child2(ComparableEntity): + pass + + +class Parent(ComparableEntity): + pass + + +class Screen: + def __init__(self, obj, parent=None): + self.obj = obj + self.parent = parent + + +class Mixin: + email_address = Column(String) + + +class AddressWMixin(Mixin, ComparableEntity): + pass + + +class Foo: + def __init__(self, moredata, stuff="im stuff"): + self.data = "im data" + self.stuff = stuff + self.moredata = moredata + + __hash__ = object.__hash__ + + def __eq__(self, other): + return ( + other.data == self.data + and other.stuff == self.stuff + and other.moredata == self.moredata + ) + + +class Bar: + def __init__(self, x, y): + self.x = x + self.y = y + + __hash__ = object.__hash__ + + def __eq__(self, other): + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) + + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) + + +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + + +class BarWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable: + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable: + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py new file mode 100644 index 0000000..0f98777 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py @@ -0,0 +1,6 @@ +# testing/plugin/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..163dc80 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9b70281 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c65e39e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6eb39ab --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py new file mode 100644 index 0000000..d0d3754 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py @@ -0,0 +1,51 @@ +# testing/plugin/bootstrap.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +""" +Bootstrapper for test framework plugins. + +The entire rationale for this system is to get the modules in plugin/ +imported without importing all of the supporting library, so that we can +set up things for testing before coverage starts. + +The rationale for all of plugin/ being *in* the supporting library in the +first place is so that the testing and plugin suite is available to other +libraries, mainly external SQLAlchemy and Alembic dialects, to make use +of the same test environment and standard suites available to +SQLAlchemy/Alembic themselves without the need to ship/install a separate +package outside of SQLAlchemy. + + +""" + +import importlib.util +import os +import sys + + +bootstrap_file = locals()["bootstrap_file"] +to_bootstrap = locals()["to_bootstrap"] + + +def load_file_as_module(name): + path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) + + spec = importlib.util.spec_from_file_location(name, path) + assert spec is not None + assert spec.loader is not None + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +if to_bootstrap == "pytest": + sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True + sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") +else: + raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py new file mode 100644 index 0000000..a642668 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py @@ -0,0 +1,779 @@ +# testing/plugin/plugin_base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import abc +from argparse import Namespace +import configparser +import logging +import os +from pathlib import Path +import re +import sys +from typing import Any + +from sqlalchemy.testing import asyncio + +"""Testing extensions. + +this module is designed to work as a testing-framework-agnostic library, +created so that multiple test frameworks can be supported at once +(mostly so that we can migrate to new ones). The current target +is pytest. + +""" + +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False + +log = logging.getLogger("sqlalchemy.testing.plugin_base") + +# late imports +fixtures = None +engines = None +exclusions = None +warnings = None +profiling = None +provision = None +assertions = None +requirements = None +config = None +testing = None +util = None +file_config = None + +logging = None +include_tags = set() +exclude_tags = set() +options: Namespace = None # type: ignore + + +def setup_options(make_option): + make_option( + "--log-info", + action="callback", + type=str, + callback=_log, + help="turn on info logging for <LOG> (multiple OK)", + ) + make_option( + "--log-debug", + action="callback", + type=str, + callback=_log, + help="turn on debug logging for <LOG> (multiple OK)", + ) + make_option( + "--db", + action="append", + type=str, + dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.", + ) + make_option( + "--dbs", + action="callback", + zeroarg_callback=_list_dbs, + help="List available prefab dbs", + ) + make_option( + "--dburi", + action="append", + type=str, + dest="dburi", + help="Database uri. Multiple OK, first one is run by default.", + ) + make_option( + "--dbdriver", + action="append", + type=str, + dest="dbdriver", + help="Additional database drivers to include in tests. " + "These are linked to the existing database URLs by the " + "provisioning system.", + ) + make_option( + "--dropfirst", + action="store_true", + dest="dropfirst", + help="Drop all tables in the target database first", + ) + make_option( + "--disable-asyncio", + action="store_true", + help="disable test / fixtures / provisoning running in asyncio", + ) + make_option( + "--backend-only", + action="callback", + zeroarg_callback=_set_tag_include("backend"), + help=( + "Run only tests marked with __backend__ or __sparse_backend__; " + "this is now equivalent to the pytest -m backend mark expression" + ), + ) + make_option( + "--nomemory", + action="callback", + zeroarg_callback=_set_tag_exclude("memory_intensive"), + help="Don't run memory profiling tests; " + "this is now equivalent to the pytest -m 'not memory_intensive' " + "mark expression", + ) + make_option( + "--notimingintensive", + action="callback", + zeroarg_callback=_set_tag_exclude("timing_intensive"), + help="Don't run timing intensive tests; " + "this is now equivalent to the pytest -m 'not timing_intensive' " + "mark expression", + ) + make_option( + "--nomypy", + action="callback", + zeroarg_callback=_set_tag_exclude("mypy"), + help="Don't run mypy typing tests; " + "this is now equivalent to the pytest -m 'not mypy' mark expression", + ) + make_option( + "--profile-sort", + type=str, + default="cumulative", + dest="profilesort", + help="Type of sort for profiling standard output", + ) + make_option( + "--profile-dump", + type=str, + dest="profiledump", + help="Filename where a single profile run will be dumped", + ) + make_option( + "--low-connections", + action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS", + ) + make_option( + "--write-idents", + type=str, + dest="write_idents", + help="write out generated follower idents to <file>, " + "when -n<num> is used", + ) + make_option( + "--requirements", + action="callback", + type=str, + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg", + ) + make_option( + "--include-tag", + action="callback", + callback=_include_tag, + type=str, + help="Include tests with tag <tag>; " + "legacy, use pytest -m 'tag' instead", + ) + make_option( + "--exclude-tag", + action="callback", + callback=_exclude_tag, + type=str, + help="Exclude tests with tag <tag>; " + "legacy, use pytest -m 'not tag' instead", + ) + make_option( + "--write-profiles", + action="store_true", + dest="write_profiles", + default=False, + help="Write/update failing profiling data.", + ) + make_option( + "--force-write-profiles", + action="store_true", + dest="force_write_profiles", + default=False, + help="Unconditionally write/update profiling data.", + ) + make_option( + "--dump-pyannotate", + type=str, + dest="dump_pyannotate", + help="Run pyannotate and dump json info to given file", + ) + make_option( + "--mypy-extra-test-path", + type=str, + action="append", + default=[], + dest="mypy_extra_test_paths", + help="Additional test directories to add to the mypy tests. " + "This is used only when running mypy tests. Multiple OK", + ) + # db specific options + make_option( + "--postgresql-templatedb", + type=str, + help="name of template database to use for PostgreSQL " + "CREATE DATABASE (defaults to current database)", + ) + make_option( + "--oracledb-thick-mode", + action="store_true", + help="enables the 'thick mode' when testing with oracle+oracledb", + ) + + +def configure_follower(follower_ident): + """Configure required state for a follower. + + This invokes in the parent process and typically includes + database creation. + + """ + from sqlalchemy.testing import provision + + provision.FOLLOWER_IDENT = follower_ident + + +def memoize_important_follower_config(dict_): + """Store important configuration we will need to send to a follower. + + This invokes in the parent process after normal config is set up. + + Hook is currently not used. + + """ + + +def restore_important_follower_config(dict_): + """Restore important configuration needed by a follower. + + This invokes in the follower process. + + Hook is currently not used. + + """ + + +def read_config(root_path): + global file_config + file_config = configparser.ConfigParser() + file_config.read( + [str(root_path / "setup.cfg"), str(root_path / "test.cfg")] + ) + + +def pre_begin(opt): + """things to set up early, before coverage might be setup.""" + global options + options = opt + for fn in pre_configure: + fn(options, file_config) + + +def set_coverage_flag(value): + options.has_coverage = value + + +def post_begin(): + """things to set up later, once we know coverage is running.""" + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + # late imports, has to happen after config. + global util, fixtures, engines, exclusions, assertions, provision + global warnings, profiling, config, testing + from sqlalchemy import testing # noqa + from sqlalchemy.testing import fixtures, engines, exclusions # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import config, provision # noqa + from sqlalchemy import util # noqa + + warnings.setup_filters() + + +def _log(opt_str, value, parser): + global logging + if not logging: + import logging + + logging.basicConfig() + + if opt_str.endswith("-info"): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith("-debug"): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + if file_config is None: + # assume the current working directory is the one containing the + # setup file + read_config(Path.cwd()) + print("Available --db options (use --dburi to override)") + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) + sys.exit(0) + + +def _requirements_opt(opt_str, value, parser): + _setup_requirements(value) + + +def _set_tag_include(tag): + def _do_include_tag(opt_str, value, parser): + _include_tag(opt_str, tag, parser) + + return _do_include_tag + + +def _set_tag_exclude(tag): + def _do_exclude_tag(opt_str, value, parser): + _exclude_tag(opt_str, tag, parser) + + return _do_exclude_tag + + +def _exclude_tag(opt_str, value, parser): + exclude_tags.add(value.replace("-", "_")) + + +def _include_tag(opt_str, value, parser): + include_tags.add(value.replace("-", "_")) + + +pre_configure = [] +post_configure = [] + + +def pre(fn): + pre_configure.append(fn) + return fn + + +def post(fn): + post_configure.append(fn) + return fn + + +@pre +def _setup_options(opt, file_config): + global options + options = opt + + +@pre +def _register_sqlite_numeric_dialect(opt, file_config): + from sqlalchemy.dialects import registry + + registry.register( + "sqlite.pysqlite_numeric", + "sqlalchemy.dialects.sqlite.pysqlite", + "_SQLiteDialect_pysqlite_numeric", + ) + registry.register( + "sqlite.pysqlite_dollar", + "sqlalchemy.dialects.sqlite.pysqlite", + "_SQLiteDialect_pysqlite_dollar", + ) + + +@post +def __ensure_cext(opt, file_config): + if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1": + from sqlalchemy.util import has_compiled_ext + + try: + has_compiled_ext(raise_=True) + except ImportError as err: + raise AssertionError( + "REQUIRE_SQLALCHEMY_CEXT is set but can't import the " + "cython extensions" + ) from err + + +@post +def _init_symbols(options, file_config): + from sqlalchemy.testing import config + + config._fixture_functions = _fixture_fn_class() + + +@pre +def _set_disable_asyncio(opt, file_config): + if opt.disable_asyncio: + asyncio.ENABLE_ASYNCIO = False + + +@post +def _engine_uri(options, file_config): + from sqlalchemy import testing + from sqlalchemy.testing import config + from sqlalchemy.testing import provision + from sqlalchemy.engine import url as sa_url + + if options.dburi: + db_urls = list(options.dburi) + else: + db_urls = [] + + extra_drivers = options.dbdriver or [] + + if options.db: + for db_token in options.db: + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): + raise RuntimeError( + "Unknown URI specifier '%s'. " + "Specify --dbs for known uris." % db + ) + else: + db_urls.append(file_config.get("db", db)) + + if not db_urls: + db_urls.append(file_config.get("db", "default")) + + config._current = None + + if options.write_idents and provision.FOLLOWER_IDENT: + for db_url in [sa_url.make_url(db_url) for db_url in db_urls]: + with open(options.write_idents, "a") as file_: + file_.write( + f"{provision.FOLLOWER_IDENT} " + f"{db_url.render_as_string(hide_password=False)}\n" + ) + + expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers)) + + for db_url in expanded_urls: + log.info("Adding database URL: %s", db_url) + + cfg = provision.setup_config( + db_url, options, file_config, provision.FOLLOWER_IDENT + ) + if not config._current: + cfg.set_as_current(cfg, testing) + + +@post +def _requirements(options, file_config): + requirement_cls = file_config.get("sqla_testing", "requirement_cls") + _setup_requirements(requirement_cls) + + +def _setup_requirements(argument): + from sqlalchemy.testing import config + from sqlalchemy import testing + + modname, clsname = argument.split(":") + + # importlib.import_module() only introduced in 2.7, a little + # late + mod = __import__(modname) + for component in modname.split(".")[1:]: + mod = getattr(mod, component) + req_cls = getattr(mod, clsname) + + config.requirements = testing.requires = req_cls() + + config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy + + +@post +def _prep_testing_database(options, file_config): + from sqlalchemy.testing import config + + if options.dropfirst: + from sqlalchemy.testing import provision + + for cfg in config.Config.all_configs(): + provision.drop_all_schema_objects(cfg, cfg.db) + + +@post +def _post_setup_options(opt, file_config): + from sqlalchemy.testing import config + + config.options = options + config.file_config = file_config + + +@post +def _setup_profiling(options, file_config): + from sqlalchemy.testing import profiling + + profiling._profile_stats = profiling.ProfileStatsFile( + file_config.get("sqla_testing", "profile_file"), + sort=options.profilesort, + dump=options.profiledump, + ) + + +def want_class(name, cls): + if not issubclass(cls, fixtures.TestBase): + return False + elif name.startswith("_"): + return False + else: + return True + + +def want_method(cls, fn): + if not fn.__name__.startswith("test_"): + return False + elif fn.__module__ is None: + return False + else: + return True + + +def generate_sub_tests(cls, module, markers): + if "backend" in markers or "sparse_backend" in markers: + sparse = "sparse_backend" in markers + for cfg in _possible_configs_for_cls(cls, sparse=sparse): + orig_name = cls.__name__ + + # we can have special chars in these names except for the + # pytest junit plugin, which is tripped up by the brackets + # and periods, so sanitize + + alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub(r"_+$", "", alpha_name) + name = "%s_%s" % (cls.__name__, alpha_name) + subcls = type( + name, + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, + ) + setattr(module, name, subcls) + yield subcls + else: + yield cls + + +def start_test_class_outside_fixtures(cls): + _do_skips(cls) + _setup_engine(cls) + + +def stop_test_class(cls): + # close sessions, immediate connections, etc. + fixtures.stop_test_class_inside_fixtures(cls) + + # close outstanding connection pool connections, dispose of + # additional engines + engines.testing_reaper.stop_test_class_inside_fixtures() + + +def stop_test_class_outside_fixtures(cls): + engines.testing_reaper.stop_test_class_outside_fixtures() + provision.stop_test_class_outside_fixtures(config, config.db, cls) + try: + if not options.low_connections: + assertions.global_cleanup_assertions() + finally: + _restore_engine() + + +def _restore_engine(): + if config._current: + config._current.reset(testing) + + +def final_process_cleanup(): + engines.testing_reaper.final_cleanup() + assertions.global_cleanup_assertions() + _restore_engine() + + +def _setup_engine(cls): + if getattr(cls, "__engine_options__", None): + opts = dict(cls.__engine_options__) + opts["scope"] = "class" + eng = engines.testing_engine(options=opts) + config._current.push_engine(eng, testing) + + +def before_test(test, test_module_name, test_class, test_name): + # format looks like: + # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" + + name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__) + + id_ = "%s.%s.%s" % (test_module_name, name, test_name) + + profiling._start_current_test(id_) + + +def after_test(test): + fixtures.after_test() + engines.testing_reaper.after_test() + + +def after_test_fixtures(test): + engines.testing_reaper.after_test_outside_fixtures(test) + + +def _possible_configs_for_cls(cls, reasons=None, sparse=False): + all_configs = set(config.Config.all_configs()) + + if cls.__unsupported_on__: + spec = exclusions.db_spec(*cls.__unsupported_on__) + for config_obj in list(all_configs): + if spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, "__only_on__", None): + spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) + for config_obj in list(all_configs): + if not spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, "__only_on_config__", None): + all_configs.intersection_update([cls.__only_on_config__]) + + if hasattr(cls, "__requires__"): + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__requires__: + check = getattr(requirements, requirement) + + skip_reasons = check.matching_config_reasons(config_obj) + if skip_reasons: + all_configs.remove(config_obj) + if reasons is not None: + reasons.extend(skip_reasons) + break + + if hasattr(cls, "__prefer_requires__"): + non_preferred = set() + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__prefer_requires__: + check = getattr(requirements, requirement) + + if not check.enabled_for_config(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if sparse: + # pick only one config from each base dialect + # sorted so we get the same backend each time selecting the highest + # server version info. + per_dialect = {} + for cfg in reversed( + sorted( + all_configs, + key=lambda cfg: ( + cfg.db.name, + cfg.db.driver, + cfg.db.dialect.server_version_info, + ), + ) + ): + db = cfg.db.name + if db not in per_dialect: + per_dialect[db] = cfg + return per_dialect.values() + + return all_configs + + +def _do_skips(cls): + reasons = [] + all_configs = _possible_configs_for_cls(cls, reasons) + + if getattr(cls, "__skip_if__", False): + for c in getattr(cls, "__skip_if__"): + if c(): + config.skip_test( + "'%s' skipped by %s" % (cls.__name__, c.__name__) + ) + + if not all_configs: + msg = "'%s.%s' unsupported on any DB implementation %s%s" % ( + cls.__module__, + cls.__name__, + ", ".join( + "'%s(%s)+%s'" + % ( + config_obj.db.name, + ".".join( + str(dig) + for dig in exclusions._server_version(config_obj.db) + ), + config_obj.db.driver, + ) + for config_obj in config.Config.all_configs() + ), + ", ".join(reasons), + ) + config.skip_test(msg) + elif hasattr(cls, "__prefer_backends__"): + non_preferred = set() + spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) + for config_obj in all_configs: + if not spec(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if config._current not in all_configs: + _setup_config(all_configs.pop(), cls) + + +def _setup_config(config_obj, ctx): + config._current.push(config_obj, testing) + + +class FixtureFunctions(abc.ABC): + @abc.abstractmethod + def skip_test_exception(self, *arg, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def combinations(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def param_ident(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def fixture(self, *arg, **kw): + raise NotImplementedError() + + def get_current_test_name(self): + raise NotImplementedError() + + @abc.abstractmethod + def mark_base_test_class(self) -> Any: + raise NotImplementedError() + + @abc.abstractproperty + def add_to_marker(self): + raise NotImplementedError() + + +_fixture_fn_class = None + + +def set_fixture_functions(fixture_fn_class): + global _fixture_fn_class + _fixture_fn_class = fixture_fn_class diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py new file mode 100644 index 0000000..1a4d4bb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py @@ -0,0 +1,868 @@ +# testing/plugin/pytestplugin.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import argparse +import collections +from functools import update_wrapper +import inspect +import itertools +import operator +import os +import re +import sys +from typing import TYPE_CHECKING +import uuid + +import pytest + +try: + # installed by bootstrap.py + if not TYPE_CHECKING: + import sqla_plugin_base as plugin_base +except ImportError: + # assume we're a package, use traditional import + from . import plugin_base + + +def pytest_addoption(parser): + group = parser.getgroup("sqlalchemy") + + def make_option(name, **kw): + callback_ = kw.pop("callback", None) + if callback_: + + class CallableAction(argparse.Action): + def __call__( + self, parser, namespace, values, option_string=None + ): + callback_(option_string, values, parser) + + kw["action"] = CallableAction + + zeroarg_callback = kw.pop("zeroarg_callback", None) + if zeroarg_callback: + + class CallableAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, # noqa + ): + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): + zeroarg_callback(option_string, values, parser) + + kw["action"] = CallableAction + + group.addoption(name, **kw) + + plugin_base.setup_options(make_option) + + +def pytest_configure(config: pytest.Config): + plugin_base.read_config(config.rootpath) + if plugin_base.exclude_tags or plugin_base.include_tags: + new_expr = " and ".join( + list(plugin_base.include_tags) + + [f"not {tag}" for tag in plugin_base.exclude_tags] + ) + + if config.option.markexpr: + config.option.markexpr += f" and {new_expr}" + else: + config.option.markexpr = new_expr + + if config.pluginmanager.hasplugin("xdist"): + config.pluginmanager.register(XDistHooks()) + + if hasattr(config, "workerinput"): + plugin_base.restore_important_follower_config(config.workerinput) + plugin_base.configure_follower(config.workerinput["follower_ident"]) + else: + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): + os.remove(config.option.write_idents) + + plugin_base.pre_begin(config.option) + + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) + + plugin_base.set_fixture_functions(PytestFixtureFunctions) + + if config.option.dump_pyannotate: + global DUMP_PYANNOTATE + DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): + if DUMP_PYANNOTATE: + from pyannotate_runtime import collect_types + + collect_types.start() + yield + if DUMP_PYANNOTATE: + collect_types.stop() + + +def _log_sqlalchemy_info(session): + import sqlalchemy + from sqlalchemy import __version__ + from sqlalchemy.util import has_compiled_ext + from sqlalchemy.util._has_cy import _CYEXTENSION_MSG + + greet = "sqlalchemy installation" + site = "no user site" if sys.flags.no_user_site else "user site loaded" + msgs = [ + f"SQLAlchemy {__version__} ({site})", + f"Path: {sqlalchemy.__file__}", + ] + + if has_compiled_ext(): + from sqlalchemy.cyextension import util + + msgs.append(f"compiled extension enabled, e.g. {util.__file__} ") + else: + msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}") + + pm = session.config.pluginmanager.get_plugin("terminalreporter") + if pm: + pm.write_sep("=", greet) + for m in msgs: + pm.write_line(m) + else: + # fancy pants reporter not found, fallback to plain print + print("=" * 25, greet, "=" * 25) + for m in msgs: + print(m) + + +def pytest_sessionstart(session): + from sqlalchemy.testing import asyncio + + _log_sqlalchemy_info(session) + asyncio._assume_async(plugin_base.post_begin) + + +def pytest_sessionfinish(session): + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup) + + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_unconfigure(config): + from sqlalchemy.testing import asyncio + + asyncio._shutdown() + + +def pytest_collection_finish(session): + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + + def _filter(filename): + filename = os.path.normpath(os.path.abspath(filename)) + if "lib/sqlalchemy" not in os.path.commonpath( + [filename, lib_sqlalchemy] + ): + return None + if "testing" in filename: + return None + + return filename + + collect_types.init_types_collection(filter_filename=_filter) + + +class XDistHooks: + def pytest_configure_node(self, node): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + + # the master for each node fills workerinput dictionary + # which pytest-xdist will transfer to the subprocess + + plugin_base.memoize_important_follower_config(node.workerinput) + + node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] + + asyncio._maybe_async_provisioning( + provision.create_follower_db, node.workerinput["follower_ident"] + ) + + def pytest_testnodedown(self, node, error): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning( + provision.drop_follower_db, node.workerinput["follower_ident"] + ) + + +def pytest_collection_modifyitems(session, config, items): + # look for all those classes that specify __backend__ and + # expand them out into per-database test cases. + + # this is much easier to do within pytest_pycollect_makeitem, however + # pytest is iterating through cls.__dict__ as makeitem is + # called which causes a "dictionary changed size" error on py3k. + # I'd submit a pullreq for them to turn it into a list first, but + # it's to suit the rather odd use case here which is that we are adding + # new classes to a module on the fly. + + from sqlalchemy.testing import asyncio + + rebuilt_items = collections.defaultdict( + lambda: collections.defaultdict(list) + ) + + items[:] = [ + item + for item in items + if item.getparent(pytest.Class) is not None + and not item.getparent(pytest.Class).name.startswith("_") + ] + + test_classes = {item.getparent(pytest.Class) for item in items} + + def collect(element): + for inst_or_fn in element.collect(): + if isinstance(inst_or_fn, pytest.Collector): + yield from collect(inst_or_fn) + else: + yield inst_or_fn + + def setup_test_classes(): + for test_class in test_classes: + # transfer legacy __backend__ and __sparse_backend__ symbols + # to be markers + add_markers = set() + if getattr(test_class.cls, "__backend__", False) or getattr( + test_class.cls, "__only_on__", False + ): + add_markers = {"backend"} + elif getattr(test_class.cls, "__sparse_backend__", False): + add_markers = {"sparse_backend"} + else: + add_markers = frozenset() + + existing_markers = { + mark.name for mark in test_class.iter_markers() + } + add_markers = add_markers - existing_markers + all_markers = existing_markers.union(add_markers) + + for marker in add_markers: + test_class.add_marker(marker) + + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.module, all_markers + ): + if sub_cls is not test_class.cls: + per_cls_dict = rebuilt_items[test_class.cls] + + module = test_class.getparent(pytest.Module) + + new_cls = pytest.Class.from_parent( + name=sub_cls.__name__, parent=module + ) + for marker in add_markers: + new_cls.add_marker(marker) + + for fn in collect(new_cls): + per_cls_dict[fn.name].append(fn) + + # class requirements will sometimes need to access the DB to check + # capabilities, so need to do this for async + asyncio._maybe_async_provisioning(setup_test_classes) + + newitems = [] + for item in items: + cls_ = item.cls + if cls_ in rebuilt_items: + newitems.extend(rebuilt_items[cls_][item.name]) + else: + newitems.append(item) + + # seems like the functions attached to a test class aren't sorted already? + # is that true and why's that? (when using unittest, they're sorted) + items[:] = sorted( + newitems, + key=lambda item: ( + item.getparent(pytest.Module).name, + item.getparent(pytest.Class).name, + item.name, + ), + ) + + +def pytest_pycollect_makeitem(collector, name, obj): + if inspect.isclass(obj) and plugin_base.want_class(name, obj): + from sqlalchemy.testing import config + + if config.any_async: + obj = _apply_maybe_async(obj) + + return [ + pytest.Class.from_parent( + name=parametrize_cls.__name__, parent=collector + ) + for parametrize_cls in _parametrize_cls(collector.module, obj) + ] + elif ( + inspect.isfunction(obj) + and collector.cls is not None + and plugin_base.want_method(collector.cls, obj) + ): + # None means, fall back to default logic, which includes + # method-level parametrize + return None + else: + # empty list means skip this item + return [] + + +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + +def _apply_maybe_async(obj, recurse=True): + from sqlalchemy.testing import asyncio + + for name, value in vars(obj).items(): + if ( + (callable(value) or isinstance(value, classmethod)) + and not getattr(value, "_maybe_async_applied", False) + and (name.startswith("test_")) + and not _is_wrapped_coroutine_function(value) + ): + is_classmethod = False + if isinstance(value, classmethod): + value = value.__func__ + is_classmethod = True + + @_pytest_fn_decorator + def make_async(fn, *args, **kwargs): + return asyncio._maybe_async(fn, *args, **kwargs) + + do_async = make_async(value) + if is_classmethod: + do_async = classmethod(do_async) + do_async._maybe_async_applied = True + + setattr(obj, name, do_async) + if recurse: + for cls in obj.mro()[1:]: + if cls != object: + _apply_maybe_async(cls, False) + return obj + + +def _parametrize_cls(module, cls): + """implement a class-based version of pytest parametrize.""" + + if "_sa_parametrize" not in cls.__dict__: + return [cls] + + _sa_parametrize = cls._sa_parametrize + classes = [] + for full_param_set in itertools.product( + *[params for argname, params in _sa_parametrize] + ): + cls_variables = {} + + for argname, param in zip( + [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set + ): + if not argname: + raise TypeError("need argnames for class-based combinations") + argname_split = re.split(r",\s*", argname) + for arg, val in zip(argname_split, param.values): + cls_variables[arg] = val + parametrized_name = "_".join( + re.sub(r"\W", "", token) + for param in full_param_set + for token in param.id.split("-") + ) + name = "%s_%s" % (cls.__name__, parametrized_name) + newcls = type.__new__(type, name, (cls,), cls_variables) + setattr(module, name, newcls) + classes.append(newcls) + return classes + + +_current_class = None + + +def pytest_runtest_setup(item): + from sqlalchemy.testing import asyncio + + # pytest_runtest_setup runs *before* pytest fixtures with scope="class". + # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest + # for the whole class and has to run things that are across all current + # databases, so we run this outside of the pytest fixture system altogether + # and ensure asyncio greenlet if any engines are async + + global _current_class + + if isinstance(item, pytest.Function) and _current_class is None: + asyncio._maybe_async_provisioning( + plugin_base.start_test_class_outside_fixtures, + item.cls, + ) + _current_class = item.getparent(pytest.Class) + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item, nextitem): + # runs inside of pytest function fixture scope + # after test function runs + + from sqlalchemy.testing import asyncio + + asyncio._maybe_async(plugin_base.after_test, item) + + yield + # this is now after all the fixture teardown have run, the class can be + # finalized. Since pytest v7 this finalizer can no longer be added in + # pytest_runtest_setup since the class has not yet been setup at that + # time. + # See https://github.com/pytest-dev/pytest/issues/9343 + global _current_class, _current_report + + if _current_class is not None and ( + # last test or a new class + nextitem is None + or nextitem.getparent(pytest.Class) is not _current_class + ): + _current_class = None + + try: + asyncio._maybe_async_provisioning( + plugin_base.stop_test_class_outside_fixtures, item.cls + ) + except Exception as e: + # in case of an exception during teardown attach the original + # error to the exception message, otherwise it will get lost + if _current_report.failed: + if not e.args: + e.args = ( + "__Original test failure__:\n" + + _current_report.longreprtext, + ) + elif e.args[-1] and isinstance(e.args[-1], str): + args = list(e.args) + args[-1] += ( + "\n__Original test failure__:\n" + + _current_report.longreprtext + ) + e.args = tuple(args) + else: + e.args += ( + "__Original test failure__", + _current_report.longreprtext, + ) + raise + finally: + _current_report = None + + +def pytest_runtest_call(item): + # runs inside of pytest function fixture scope + # before test function runs + + from sqlalchemy.testing import asyncio + + asyncio._maybe_async( + plugin_base.before_test, + item, + item.module.__name__, + item.cls, + item.name, + ) + + +_current_report = None + + +def pytest_runtest_logreport(report): + global _current_report + if report.when == "call": + _current_report = report + + +@pytest.fixture(scope="class") +def setup_class_methods(request): + from sqlalchemy.testing import asyncio + + cls = request.cls + + if hasattr(cls, "setup_test_class"): + asyncio._maybe_async(cls.setup_test_class) + + yield + + if hasattr(cls, "teardown_test_class"): + asyncio._maybe_async(cls.teardown_test_class) + + asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): + from sqlalchemy.testing import asyncio + + # called for each test + + self = request.instance + + # before this fixture runs: + + # 1. function level "autouse" fixtures under py3k (examples: TablesTest + # define tables / data, MappedTest define tables / mappers / data) + + # 2. was for p2k. no longer applies + + # 3. run outer xdist-style setup + if hasattr(self, "setup_test"): + asyncio._maybe_async(self.setup_test) + + # alembic test suite is using setUp and tearDown + # xdist methods; support these in the test suite + # for the near term + if hasattr(self, "setUp"): + asyncio._maybe_async(self.setUp) + + # inside the yield: + # 4. function level fixtures defined on test functions themselves, + # e.g. "connection", "metadata" run next + + # 5. pytest hook pytest_runtest_call then runs + + # 6. test itself runs + + yield + + # yield finishes: + + # 7. function level fixtures defined on test functions + # themselves, e.g. "connection" rolls back the transaction, "metadata" + # emits drop all + + # 8. pytest hook pytest_runtest_teardown hook runs, this is associated + # with fixtures close all sessions, provisioning.stop_test_class(), + # engines.testing_reaper -> ensure all connection pool connections + # are returned, engines created by testing_engine that aren't the + # config engine are disposed + + asyncio._maybe_async(plugin_base.after_test_fixtures, self) + + # 10. run xdist-style teardown + if hasattr(self, "tearDown"): + asyncio._maybe_async(self.tearDown) + + if hasattr(self, "teardown_test"): + asyncio._maybe_async(self.teardown_test) + + # 11. was for p2k. no longer applies + + # 12. function level "autouse" fixtures under py3k (examples: TablesTest / + # MappedTest delete table data, possibly drop tables and clear mappers + # depending on the flags defined by the test class) + + +def _pytest_fn_decorator(target): + """Port of langhelpers.decorator with pytest-specific tricks.""" + + from sqlalchemy.util.langhelpers import format_argspec_plus + from sqlalchemy.util.compat import inspect_getfullargspec + + def _exec_code_in_env(code, env, fn_name): + # note this is affected by "from __future__ import annotations" at + # the top; exec'ed code will use non-evaluated annotations + # which allows us to be more flexible with code rendering + # in format_argpsec_plus() + exec(code, env) + return env[fn_name] + + def decorate(fn, add_positional_parameters=()): + spec = inspect_getfullargspec(fn) + if add_positional_parameters: + spec.args.extend(add_positional_parameters) + + metadata = dict( + __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__ + ) + metadata.update(format_argspec_plus(spec, grouped=False)) + code = ( + """\ +def %(name)s%(grouped_args)s: + return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s) +""" + % metadata + ) + decorated = _exec_code_in_env( + code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__ + ) + if not add_positional_parameters: + decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + else: + # this is the pytest hacky part. don't do a full update wrapper + # because pytest is really being sneaky about finding the args + # for the wrapped function + decorated.__module__ = fn.__module__ + decorated.__name__ = fn.__name__ + if hasattr(fn, "pytestmark"): + decorated.pytestmark = fn.pytestmark + return decorated + + return decorate + + +class PytestFixtureFunctions(plugin_base.FixtureFunctions): + def skip_test_exception(self, *arg, **kw): + return pytest.skip.Exception(*arg, **kw) + + @property + def add_to_marker(self): + return pytest.mark + + def mark_base_test_class(self): + return pytest.mark.usefixtures( + "setup_class_methods", "setup_test_methods" + ) + + _combination_id_fns = { + "i": lambda obj: obj, + "r": repr, + "s": str, + "n": lambda obj: ( + obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__ + ), + } + + def combinations(self, *arg_sets, **kw): + """Facade for pytest.mark.parametrize. + + Automatically derives argument names from the callable which in our + case is always a method on a class with positional arguments. + + ids for parameter sets are derived using an optional template. + + """ + from sqlalchemy.testing import exclusions + + if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"): + arg_sets = list(arg_sets[0]) + + argnames = kw.pop("argnames", None) + + def _filter_exclusions(args): + result = [] + gathered_exclusions = [] + for a in args: + if isinstance(a, exclusions.compound): + gathered_exclusions.append(a) + else: + result.append(a) + + return result, gathered_exclusions + + id_ = kw.pop("id_", None) + + tobuild_pytest_params = [] + has_exclusions = False + if id_: + _combination_id_fns = self._combination_id_fns + + # because itemgetter is not consistent for one argument vs. + # multiple, make it multiple in all cases and use a slice + # to omit the first argument + _arg_getter = operator.itemgetter( + 0, + *[ + idx + for idx, char in enumerate(id_) + if char in ("n", "r", "s", "a") + ], + ) + fns = [ + (operator.itemgetter(idx), _combination_id_fns[char]) + for idx, char in enumerate(id_) + if char in _combination_id_fns + ] + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + parameters = _arg_getter(fn_params)[1:] + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + ( + parameters, + param_exclusions, + "-".join( + comb_fn(getter(arg)) for getter, comb_fn in fns + ), + ) + ) + + else: + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + (fn_params, param_exclusions, None) + ) + + pytest_params = [] + for parameters, param_exclusions, id_ in tobuild_pytest_params: + if has_exclusions: + parameters += (param_exclusions,) + + param = pytest.param(*parameters, id=id_) + pytest_params.append(param) + + def decorate(fn): + if inspect.isclass(fn): + if has_exclusions: + raise NotImplementedError( + "exclusions not supported for class level combinations" + ) + if "_sa_parametrize" not in fn.__dict__: + fn._sa_parametrize = [] + fn._sa_parametrize.append((argnames, pytest_params)) + return fn + else: + _fn_argnames = inspect.getfullargspec(fn).args[1:] + if argnames is None: + _argnames = _fn_argnames + else: + _argnames = re.split(r", *", argnames) + + if has_exclusions: + existing_exl = sum( + 1 for n in _fn_argnames if n.startswith("_exclusions") + ) + current_exclusion_name = f"_exclusions_{existing_exl}" + _argnames += [current_exclusion_name] + + @_pytest_fn_decorator + def check_exclusions(fn, *args, **kw): + _exclusions = args[-1] + if _exclusions: + exlu = exclusions.compound().add(*_exclusions) + fn = exlu(fn) + return fn(*args[:-1], **kw) + + fn = check_exclusions( + fn, add_positional_parameters=(current_exclusion_name,) + ) + + return pytest.mark.parametrize(_argnames, pytest_params)(fn) + + return decorate + + def param_ident(self, *parameters): + ident = parameters[0] + return pytest.param(*parameters[1:], id=ident) + + def fixture(self, *arg, **kw): + from sqlalchemy.testing import config + from sqlalchemy.testing import asyncio + + # wrapping pytest.fixture function. determine if + # decorator was called as @fixture or @fixture(). + if len(arg) > 0 and callable(arg[0]): + # was called as @fixture(), we have the function to wrap. + fn = arg[0] + arg = arg[1:] + else: + # was called as @fixture, don't have the function yet. + fn = None + + # create a pytest.fixture marker. because the fn is not being + # passed, this is always a pytest.FixtureFunctionMarker() + # object (or whatever pytest is calling it when you read this) + # that is waiting for a function. + fixture = pytest.fixture(*arg, **kw) + + # now apply wrappers to the function, including fixture itself + + def wrap(fn): + if config.any_async: + fn = asyncio._maybe_async_wrapper(fn) + # other wrappers may be added here + + # now apply FixtureFunctionMarker + fn = fixture(fn) + + return fn + + if fn: + return wrap(fn) + else: + return wrap + + def get_current_test_name(self): + return os.environ.get("PYTEST_CURRENT_TEST") + + def async_test(self, fn): + from sqlalchemy.testing import asyncio + + @_pytest_fn_decorator + def decorate(fn, *args, **kwargs): + asyncio._run_coroutine_function(fn, *args, **kwargs) + + return decorate(fn) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/profiling.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/profiling.py new file mode 100644 index 0000000..b9093c9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/profiling.py @@ -0,0 +1,324 @@ +# testing/profiling.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +from __future__ import annotations + +import collections +import contextlib +import os +import platform +import pstats +import re +import sys + +from . import config +from .util import gc_collect +from ..util import has_compiled_ext + + +try: + import cProfile +except ImportError: + cProfile = None + +_profile_stats = None +"""global ProfileStatsFileInstance. + +plugin_base assigns this at the start of all tests. + +""" + + +_current_test = None +"""String id of current test. + +plugin_base assigns this at the start of each test using +_start_current_test. + +""" + + +def _start_current_test(id_): + global _current_test + _current_test = id_ + + if _profile_stats.force_write: + _profile_stats.reset_count() + + +class ProfileStatsFile: + """Store per-platform/fn profiling results in a file. + + There was no json module available when this was written, but now + the file format which is very deterministically line oriented is kind of + handy in any case for diffs and merges. + + """ + + def __init__(self, filename, sort="cumulative", dump=None): + self.force_write = ( + config.options is not None and config.options.force_write_profiles + ) + self.write = self.force_write or ( + config.options is not None and config.options.write_profiles + ) + self.fname = os.path.abspath(filename) + self.short_fname = os.path.split(self.fname)[-1] + self.data = collections.defaultdict( + lambda: collections.defaultdict(dict) + ) + self.dump = dump + self.sort = sort + self._read() + if self.write: + # rewrite for the case where features changed, + # etc. + self._write() + + @property + def platform_key(self): + dbapi_key = config.db.name + "_" + config.db.driver + + if config.db.name == "sqlite" and config.db.dialect._is_url_file_db( + config.db.url + ): + dbapi_key += "_file" + + # keep it at 2.7, 3.1, 3.2, etc. for now. + py_version = ".".join([str(v) for v in sys.version_info[0:2]]) + + platform_tokens = [ + platform.machine(), + platform.system().lower(), + platform.python_implementation().lower(), + py_version, + dbapi_key, + ] + + platform_tokens.append("dbapiunicode") + _has_cext = has_compiled_ext() + platform_tokens.append(_has_cext and "cextensions" or "nocextensions") + return "_".join(platform_tokens) + + def has_stats(self): + test_key = _current_test + return ( + test_key in self.data and self.platform_key in self.data[test_key] + ) + + def result(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + + if "counts" not in per_platform: + per_platform["counts"] = counts = [] + else: + counts = per_platform["counts"] + + if "current_count" not in per_platform: + per_platform["current_count"] = current_count = 0 + else: + current_count = per_platform["current_count"] + + has_count = len(counts) > current_count + + if not has_count: + counts.append(callcount) + if self.write: + self._write() + result = None + else: + result = per_platform["lineno"], counts[current_count] + per_platform["current_count"] += 1 + return result + + def reset_count(self): + test_key = _current_test + # since self.data is a defaultdict, don't access a key + # if we don't know it's there first. + if test_key not in self.data: + return + per_fn = self.data[test_key] + if self.platform_key not in per_fn: + return + per_platform = per_fn[self.platform_key] + if "counts" in per_platform: + per_platform["counts"][:] = [] + + def replace(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + counts = per_platform["counts"] + current_count = per_platform["current_count"] + if current_count < len(counts): + counts[current_count - 1] = callcount + else: + counts[-1] = callcount + if self.write: + self._write() + + def _header(self): + return ( + "# %s\n" + "# This file is written out on a per-environment basis.\n" + "# For each test in aaa_profiling, the corresponding " + "function and \n" + "# environment is located within this file. " + "If it doesn't exist,\n" + "# the test is skipped.\n" + "# If a callcount does exist, it is compared " + "to what we received. \n" + "# assertions are raised if the counts do not match.\n" + "# \n" + "# To add a new callcount test, apply the function_call_count \n" + "# decorator and re-run the tests using the --write-profiles \n" + "# option - this file will be rewritten including the new count.\n" + "# \n" + ) % (self.fname) + + def _read(self): + try: + profile_f = open(self.fname) + except OSError: + return + for lineno, line in enumerate(profile_f): + line = line.strip() + if not line or line.startswith("#"): + continue + + test_key, platform_key, counts = line.split() + per_fn = self.data[test_key] + per_platform = per_fn[platform_key] + c = [int(count) for count in counts.split(",")] + per_platform["counts"] = c + per_platform["lineno"] = lineno + 1 + per_platform["current_count"] = 0 + profile_f.close() + + def _write(self): + print("Writing profile file %s" % self.fname) + profile_f = open(self.fname, "w") + profile_f.write(self._header()) + for test_key in sorted(self.data): + per_fn = self.data[test_key] + profile_f.write("\n# TEST: %s\n\n" % test_key) + for platform_key in sorted(per_fn): + per_platform = per_fn[platform_key] + c = ",".join(str(count) for count in per_platform["counts"]) + profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) + profile_f.close() + + +def function_call_count(variance=0.05, times=1, warmup=0): + """Assert a target for a test case's function call count. + + The main purpose of this assertion is to detect changes in + callcounts for various functions - the actual number is not as important. + Callcounts are stored in a file keyed to Python version and OS platform + information. This file is generated automatically for new tests, + and versioned so that unexpected changes in callcounts will be detected. + + """ + + # use signature-rewriting decorator function so that pytest fixtures + # still work on py27. In Py3, update_wrapper() alone is good enough, + # likely due to the introduction of __signature__. + + from sqlalchemy.util import decorator + + @decorator + def wrap(fn, *args, **kw): + for warm in range(warmup): + fn(*args, **kw) + + timerange = range(times) + with count_functions(variance=variance): + for time in timerange: + rv = fn(*args, **kw) + return rv + + return wrap + + +@contextlib.contextmanager +def count_functions(variance=0.05): + if cProfile is None: + raise config._skip_test_exception("cProfile is not installed") + + if not _profile_stats.has_stats() and not _profile_stats.write: + config.skip_test( + "No profiling stats available on this " + "platform for this function. Run tests with " + "--write-profiles to add statistics to %s for " + "this platform." % _profile_stats.short_fname + ) + + gc_collect() + + pr = cProfile.Profile() + pr.enable() + # began = time.time() + yield + # ended = time.time() + pr.disable() + + # s = StringIO() + stats = pstats.Stats(pr, stream=sys.stdout) + + # timespent = ended - began + callcount = stats.total_calls + + expected = _profile_stats.result(callcount) + + if expected is None: + expected_count = None + else: + line_no, expected_count = expected + + print("Pstats calls: %d Expected %s" % (callcount, expected_count)) + stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort)) + stats.print_stats() + if _profile_stats.dump: + base, ext = os.path.splitext(_profile_stats.dump) + test_name = _current_test.split(".")[-1] + dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile") + stats.dump_stats(dumpfile) + print("Dumped stats to file %s" % dumpfile) + # stats.print_callers() + if _profile_stats.force_write: + _profile_stats.replace(callcount) + elif expected_count: + deviance = int(callcount * variance) + failed = abs(callcount - expected_count) > deviance + + if failed: + if _profile_stats.write: + _profile_stats.replace(callcount) + else: + raise AssertionError( + "Adjusted function call count %s not within %s%% " + "of expected %s, platform %s. Rerun with " + "--write-profiles to " + "regenerate this callcount." + % ( + callcount, + (variance * 100), + expected_count, + _profile_stats.platform_key, + ) + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/provision.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/provision.py new file mode 100644 index 0000000..e50c6eb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/provision.py @@ -0,0 +1,496 @@ +# testing/provision.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import collections +import logging + +from . import config +from . import engines +from . import util +from .. import exc +from .. import inspect +from ..engine import url as sa_url +from ..sql import ddl +from ..sql import schema + + +log = logging.getLogger(__name__) + +FOLLOWER_IDENT = None + + +class register: + def __init__(self, decorator=None): + self.fns = {} + self.decorator = decorator + + @classmethod + def init(cls, fn): + return register().for_db("*")(fn) + + @classmethod + def init_decorator(cls, decorator): + return register(decorator).for_db("*") + + def for_db(self, *dbnames): + def decorate(fn): + if self.decorator: + fn = self.decorator(fn) + for dbname in dbnames: + self.fns[dbname] = fn + return self + + return decorate + + def __call__(self, cfg, *arg, **kw): + if isinstance(cfg, str): + url = sa_url.make_url(cfg) + elif isinstance(cfg, sa_url.URL): + url = cfg + else: + url = cfg.db.url + backend = url.get_backend_name() + if backend in self.fns: + return self.fns[backend](cfg, *arg, **kw) + else: + return self.fns["*"](cfg, *arg, **kw) + + +def create_follower_db(follower_ident): + for cfg in _configs_for_db_operation(): + log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url) + create_db(cfg, cfg.db, follower_ident) + + +def setup_config(db_url, options, file_config, follower_ident): + # load the dialect, which should also have it set up its provision + # hooks + + dialect = sa_url.make_url(db_url).get_dialect() + + dialect.load_provisioning() + + if follower_ident: + db_url = follower_url_from_main(db_url, follower_ident) + db_opts = {} + update_db_opts(db_url, db_opts, options) + db_opts["scope"] = "global" + eng = engines.testing_engine(db_url, db_opts) + post_configure_engine(db_url, eng, follower_ident) + eng.connect().close() + + cfg = config.Config.register(eng, db_opts, options, file_config) + + # a symbolic name that tests can use if they need to disambiguate + # names across databases + if follower_ident: + config.ident = follower_ident + + if follower_ident: + configure_follower(cfg, follower_ident) + return cfg + + +def drop_follower_db(follower_ident): + for cfg in _configs_for_db_operation(): + log.info("DROP database %s, URI %r", follower_ident, cfg.db.url) + drop_db(cfg, cfg.db, follower_ident) + + +def generate_db_urls(db_urls, extra_drivers): + """Generate a set of URLs to test given configured URLs plus additional + driver names. + + Given:: + + --dburi postgresql://db1 \ + --dburi postgresql://db2 \ + --dburi postgresql://db2 \ + --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true + + Noting that the default postgresql driver is psycopg2, the output + would be:: + + postgresql+psycopg2://db1 + postgresql+asyncpg://db1 + postgresql+psycopg2://db2 + postgresql+psycopg2://db3 + + That is, for the driver in a --dburi, we want to keep that and use that + driver for each URL it's part of . For a driver that is only + in --dbdrivers, we want to use it just once for one of the URLs. + for a driver that is both coming from --dburi as well as --dbdrivers, + we want to keep it in that dburi. + + Driver specific query options can be specified by added them to the + driver name. For example, to enable the async fallback option for + asyncpg:: + + --dburi postgresql://db1 \ + --dbdriver=asyncpg?async_fallback=true + + """ + urls = set() + + backend_to_driver_we_already_have = collections.defaultdict(set) + + urls_plus_dialects = [ + (url_obj, url_obj.get_dialect()) + for url_obj in [sa_url.make_url(db_url) for db_url in db_urls] + ] + + for url_obj, dialect in urls_plus_dialects: + # use get_driver_name instead of dialect.driver to account for + # "_async" virtual drivers like oracledb and psycopg + driver_name = url_obj.get_driver_name() + backend_to_driver_we_already_have[dialect.name].add(driver_name) + + backend_to_driver_we_need = {} + + for url_obj, dialect in urls_plus_dialects: + backend = dialect.name + dialect.load_provisioning() + + if backend not in backend_to_driver_we_need: + backend_to_driver_we_need[backend] = extra_per_backend = set( + extra_drivers + ).difference(backend_to_driver_we_already_have[backend]) + else: + extra_per_backend = backend_to_driver_we_need[backend] + + for driver_url in _generate_driver_urls(url_obj, extra_per_backend): + if driver_url in urls: + continue + urls.add(driver_url) + yield driver_url + + +def _generate_driver_urls(url, extra_drivers): + main_driver = url.get_driver_name() + extra_drivers.discard(main_driver) + + url = generate_driver_url(url, main_driver, "") + yield url + + for drv in list(extra_drivers): + if "?" in drv: + driver_only, query_str = drv.split("?", 1) + + else: + driver_only = drv + query_str = None + + new_url = generate_driver_url(url, driver_only, query_str) + if new_url: + extra_drivers.remove(drv) + + yield new_url + + +@register.init +def generate_driver_url(url, driver, query_str): + backend = url.get_backend_name() + + new_url = url.set( + drivername="%s+%s" % (backend, driver), + ) + if query_str: + new_url = new_url.update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +def _configs_for_db_operation(): + hosts = set() + + for cfg in config.Config.all_configs(): + cfg.db.dispose() + + for cfg in config.Config.all_configs(): + url = cfg.db.url + backend = url.get_backend_name() + host_conf = (backend, url.username, url.host, url.database) + + if host_conf not in hosts: + yield cfg + hosts.add(host_conf) + + for cfg in config.Config.all_configs(): + cfg.db.dispose() + + +@register.init +def drop_all_schema_objects_pre_tables(cfg, eng): + pass + + +@register.init +def drop_all_schema_objects_post_tables(cfg, eng): + pass + + +def drop_all_schema_objects(cfg, eng): + drop_all_schema_objects_pre_tables(cfg, eng) + + drop_views(cfg, eng) + + if config.requirements.materialized_views.enabled: + drop_materialized_views(cfg, eng) + + inspector = inspect(eng) + + consider_schemas = (None,) + if config.requirements.schemas.enabled_for_config(cfg): + consider_schemas += (cfg.test_schema, cfg.test_schema_2) + util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas) + + drop_all_schema_objects_post_tables(cfg, eng) + + if config.requirements.sequences.enabled_for_config(cfg): + with eng.begin() as conn: + for seq in inspector.get_sequence_names(): + conn.execute(ddl.DropSequence(schema.Sequence(seq))) + if config.requirements.schemas.enabled_for_config(cfg): + for schema_name in [cfg.test_schema, cfg.test_schema_2]: + for seq in inspector.get_sequence_names( + schema=schema_name + ): + conn.execute( + ddl.DropSequence( + schema.Sequence(seq, schema=schema_name) + ) + ) + + +def drop_views(cfg, eng): + inspector = inspect(eng) + + try: + view_names = inspector.get_view_names() + except NotImplementedError: + pass + else: + with eng.begin() as conn: + for vname in view_names: + conn.execute( + ddl._DropView(schema.Table(vname, schema.MetaData())) + ) + + if config.requirements.schemas.enabled_for_config(cfg): + try: + view_names = inspector.get_view_names(schema=cfg.test_schema) + except NotImplementedError: + pass + else: + with eng.begin() as conn: + for vname in view_names: + conn.execute( + ddl._DropView( + schema.Table( + vname, + schema.MetaData(), + schema=cfg.test_schema, + ) + ) + ) + + +def drop_materialized_views(cfg, eng): + inspector = inspect(eng) + + mview_names = inspector.get_materialized_view_names() + + with eng.begin() as conn: + for vname in mview_names: + conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}") + + if config.requirements.schemas.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names( + schema=cfg.test_schema + ) + with eng.begin() as conn: + for vname in mview_names: + conn.exec_driver_sql( + f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}" + ) + + +@register.init +def create_db(cfg, eng, ident): + """Dynamically create a database for testing. + + Used when a test run will employ multiple processes, e.g., when run + via `tox` or `pytest -n4`. + """ + raise NotImplementedError( + "no DB creation routine for cfg: %s" % (eng.url,) + ) + + +@register.init +def drop_db(cfg, eng, ident): + """Drop a database that we dynamically created for testing.""" + raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,)) + + +def _adapt_update_db_opts(fn): + insp = util.inspect_getfullargspec(fn) + if len(insp.args) == 3: + return fn + else: + return lambda db_url, db_opts, _options: fn(db_url, db_opts) + + +@register.init_decorator(_adapt_update_db_opts) +def update_db_opts(db_url, db_opts, options): + """Set database options (db_opts) for a test database that we created.""" + + +@register.init +def post_configure_engine(url, engine, follower_ident): + """Perform extra steps after configuring an engine for testing. + + (For the internal dialects, currently only used by sqlite, oracle) + """ + + +@register.init +def follower_url_from_main(url, ident): + """Create a connection URL for a dynamically-created test database. + + :param url: the connection URL specified when the test run was invoked + :param ident: the pytest-xdist "worker identifier" to be used as the + database name + """ + url = sa_url.make_url(url) + return url.set(database=ident) + + +@register.init +def configure_follower(cfg, ident): + """Create dialect-specific config settings for a follower database.""" + pass + + +@register.init +def run_reap_dbs(url, ident): + """Remove databases that were created during the test process, after the + process has ended. + + This is an optional step that is invoked for certain backends that do not + reliably release locks on the database as long as a process is still in + use. For the internal dialects, this is currently only necessary for + mssql and oracle. + """ + + +def reap_dbs(idents_file): + log.info("Reaping databases...") + + urls = collections.defaultdict(set) + idents = collections.defaultdict(set) + dialects = {} + + with open(idents_file) as file_: + for line in file_: + line = line.strip() + db_name, db_url = line.split(" ") + url_obj = sa_url.make_url(db_url) + if db_name not in dialects: + dialects[db_name] = url_obj.get_dialect() + dialects[db_name].load_provisioning() + url_key = (url_obj.get_backend_name(), url_obj.host) + urls[url_key].add(db_url) + idents[url_key].add(db_name) + + for url_key in urls: + url = list(urls[url_key])[0] + ident = idents[url_key] + run_reap_dbs(url, ident) + + +@register.init +def temp_table_keyword_args(cfg, eng): + """Specify keyword arguments for creating a temporary Table. + + Dialect-specific implementations of this method will return the + kwargs that are passed to the Table method when creating a temporary + table for testing, e.g., in the define_temp_tables method of the + ComponentReflectionTest class in suite/test_reflection.py + """ + raise NotImplementedError( + "no temp table keyword args routine for cfg: %s" % (eng.url,) + ) + + +@register.init +def prepare_for_drop_tables(config, connection): + pass + + +@register.init +def stop_test_class_outside_fixtures(config, db, testcls): + pass + + +@register.init +def get_temp_table_name(cfg, eng, base_name): + """Specify table name for creating a temporary Table. + + Dialect-specific implementations of this method will return the + name to use when creating a temporary table for testing, + e.g., in the define_temp_tables method of the + ComponentReflectionTest class in suite/test_reflection.py + + Default to just the base name since that's what most dialects will + use. The mssql dialect's implementation will need a "#" prepended. + """ + return base_name + + +@register.init +def set_default_schema_on_connection(cfg, dbapi_connection, schema_name): + raise NotImplementedError( + "backend does not implement a schema name set function: %s" + % (cfg.db.url,) + ) + + +@register.init +def upsert( + cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False +): + """return the backends insert..on conflict / on dupe etc. construct. + + while we should add a backend-neutral upsert construct as well, such as + insert().upsert(), it's important that we continue to test the + backend-specific insert() constructs since if we do implement + insert().upsert(), that would be using a different codepath for the things + we need to test like insertmanyvalues, etc. + + """ + raise NotImplementedError( + f"backend does not include an upsert implementation: {cfg.db.url}" + ) + + +@register.init +def normalize_sequence(cfg, sequence): + """Normalize sequence parameters for dialect that don't start with 1 + by default. + + The default implementation does nothing + """ + return sequence diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/requirements.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/requirements.py new file mode 100644 index 0000000..31aac74 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/requirements.py @@ -0,0 +1,1783 @@ +# testing/requirements.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +External dialect test suites should subclass SuiteRequirements +to provide specific inclusion/exclusions. + +""" + +from __future__ import annotations + +import platform + +from . import asyncio as _test_asyncio +from . import exclusions +from .exclusions import only_on +from .. import create_engine +from .. import util +from ..pool import QueuePool + + +class Requirements: + pass + + +class SuiteRequirements(Requirements): + @property + def create_table(self): + """target platform can emit basic CreateTable DDL.""" + + return exclusions.open() + + @property + def drop_table(self): + """target platform can emit basic DropTable DDL.""" + + return exclusions.open() + + @property + def table_ddl_if_exists(self): + """target platform supports IF NOT EXISTS / IF EXISTS for tables.""" + + return exclusions.closed() + + @property + def index_ddl_if_exists(self): + """target platform supports IF NOT EXISTS / IF EXISTS for indexes.""" + + return exclusions.closed() + + @property + def uuid_data_type(self): + """Return databases that support the UUID datatype.""" + + return exclusions.closed() + + @property + def foreign_keys(self): + """Target database must support foreign keys.""" + + return exclusions.open() + + @property + def foreign_keys_reflect_as_index(self): + """Target database creates an index that's reflected for + foreign keys.""" + + return exclusions.closed() + + @property + def unique_index_reflect_as_unique_constraints(self): + """Target database reflects unique indexes as unique constrains.""" + + return exclusions.closed() + + @property + def unique_constraints_reflect_as_index(self): + """Target database reflects unique constraints as indexes.""" + + return exclusions.closed() + + @property + def table_value_constructor(self): + """Database / dialect supports a query like:: + + SELECT * FROM VALUES ( (c1, c2), (c1, c2), ...) + AS some_table(col1, col2) + + SQLAlchemy generates this with the :func:`_sql.values` function. + + """ + return exclusions.closed() + + @property + def standard_cursor_sql(self): + """Target database passes SQL-92 style statements to cursor.execute() + when a statement like select() or insert() is run. + + A very small portion of dialect-level tests will ensure that certain + conditions are present in SQL strings, and these tests use very basic + SQL that will work on any SQL-like platform in order to assert results. + + It's normally a given for any pep-249 DBAPI that a statement like + "SELECT id, name FROM table WHERE some_table.id=5" will work. + However, there are dialects that don't actually produce SQL Strings + and instead may work with symbolic objects instead, or dialects that + aren't working with SQL, so for those this requirement can be marked + as excluded. + + """ + + return exclusions.open() + + @property + def on_update_cascade(self): + """target database must support ON UPDATE..CASCADE behavior in + foreign keys.""" + + return exclusions.open() + + @property + def non_updating_cascade(self): + """target database must *not* support ON UPDATE..CASCADE behavior in + foreign keys.""" + return exclusions.closed() + + @property + def deferrable_fks(self): + return exclusions.closed() + + @property + def on_update_or_deferrable_fks(self): + # TODO: exclusions should be composable, + # somehow only_if([x, y]) isn't working here, negation/conjunctions + # getting confused. + return exclusions.only_if( + lambda: self.on_update_cascade.enabled + or self.deferrable_fks.enabled + ) + + @property + def queue_pool(self): + """target database is using QueuePool""" + + def go(config): + return isinstance(config.db.pool, QueuePool) + + return exclusions.only_if(go) + + @property + def self_referential_foreign_keys(self): + """Target database must support self-referential foreign keys.""" + + return exclusions.open() + + @property + def foreign_key_ddl(self): + """Target database must support the DDL phrases for FOREIGN KEY.""" + + return exclusions.open() + + @property + def named_constraints(self): + """target database must support names for constraints.""" + + return exclusions.open() + + @property + def implicitly_named_constraints(self): + """target database must apply names to unnamed constraints.""" + + return exclusions.open() + + @property + def unusual_column_name_characters(self): + """target database allows column names that have unusual characters + in them, such as dots, spaces, slashes, or percent signs. + + The column names are as always in such a case quoted, however the + DB still needs to support those characters in the name somehow. + + """ + return exclusions.open() + + @property + def subqueries(self): + """Target database must support subqueries.""" + + return exclusions.open() + + @property + def offset(self): + """target database can render OFFSET, or an equivalent, in a + SELECT. + """ + + return exclusions.open() + + @property + def bound_limit_offset(self): + """target database can render LIMIT and/or OFFSET using a bound + parameter + """ + + return exclusions.open() + + @property + def sql_expression_limit_offset(self): + """target database can render LIMIT and/or OFFSET with a complete + SQL expression, such as one that uses the addition operator. + parameter + """ + + return exclusions.open() + + @property + def parens_in_union_contained_select_w_limit_offset(self): + """Target database must support parenthesized SELECT in UNION + when LIMIT/OFFSET is specifically present. + + E.g. (SELECT ...) UNION (SELECT ..) + + This is known to fail on SQLite. + + """ + return exclusions.open() + + @property + def parens_in_union_contained_select_wo_limit_offset(self): + """Target database must support parenthesized SELECT in UNION + when OFFSET/LIMIT is specifically not present. + + E.g. (SELECT ... LIMIT ..) UNION (SELECT .. OFFSET ..) + + This is known to fail on SQLite. It also fails on Oracle + because without LIMIT/OFFSET, there is currently no step that + creates an additional subquery. + + """ + return exclusions.open() + + @property + def boolean_col_expressions(self): + """Target database must support boolean expressions as columns""" + + return exclusions.closed() + + @property + def nullable_booleans(self): + """Target database allows boolean columns to store NULL.""" + + return exclusions.open() + + @property + def nullsordering(self): + """Target backends that support nulls ordering.""" + + return exclusions.closed() + + @property + def standalone_binds(self): + """target database/driver supports bound parameters as column + expressions without being in the context of a typed column. + """ + return exclusions.open() + + @property + def standalone_null_binds_whereclause(self): + """target database/driver supports bound parameters with NULL in the + WHERE clause, in situations where it has to be typed. + + """ + return exclusions.open() + + @property + def intersect(self): + """Target database must support INTERSECT or equivalent.""" + return exclusions.closed() + + @property + def except_(self): + """Target database must support EXCEPT or equivalent (i.e. MINUS).""" + return exclusions.closed() + + @property + def window_functions(self): + """Target database must support window functions.""" + return exclusions.closed() + + @property + def ctes(self): + """Target database supports CTEs""" + + return exclusions.closed() + + @property + def ctes_with_update_delete(self): + """target database supports CTES that ride on top of a normal UPDATE + or DELETE statement which refers to the CTE in a correlated subquery. + + """ + + return exclusions.closed() + + @property + def ctes_on_dml(self): + """target database supports CTES which consist of INSERT, UPDATE + or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)""" + + return exclusions.closed() + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + + return exclusions.open() + + @property + def fetch_rows_post_commit(self): + """target platform will allow cursor.fetchone() to proceed after a + COMMIT. + + Typically this refers to an INSERT statement with RETURNING which + is invoked within "autocommit". If the row can be returned + after the autocommit, then this rule can be open. + + """ + + return exclusions.open() + + @property + def group_by_complex_expression(self): + """target platform supports SQL expressions in GROUP BY + + e.g. + + SELECT x + y AS somelabel FROM table GROUP BY x + y + + """ + + return exclusions.open() + + @property + def sane_rowcount(self): + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_sane_rowcount, + "driver doesn't support 'sane' rowcount", + ) + + @property + def sane_multi_rowcount(self): + return exclusions.fails_if( + lambda config: not config.db.dialect.supports_sane_multi_rowcount, + "driver %(driver)s %(doesnt_support)s 'sane' multi row count", + ) + + @property + def sane_rowcount_w_returning(self): + return exclusions.fails_if( + lambda config: not ( + config.db.dialect.supports_sane_rowcount_returning + ), + "driver doesn't support 'sane' rowcount when returning is on", + ) + + @property + def empty_inserts(self): + """target platform supports INSERT with no values, i.e. + INSERT DEFAULT VALUES or equivalent.""" + + return exclusions.only_if( + lambda config: config.db.dialect.supports_empty_insert + or config.db.dialect.supports_default_values + or config.db.dialect.supports_default_metavalue, + "empty inserts not supported", + ) + + @property + def empty_inserts_executemany(self): + """target platform supports INSERT with no values, i.e. + INSERT DEFAULT VALUES or equivalent, within executemany()""" + + return self.empty_inserts + + @property + def insert_from_select(self): + """target platform supports INSERT from a SELECT.""" + + return exclusions.open() + + @property + def delete_returning(self): + """target platform supports DELETE ... RETURNING.""" + + return exclusions.only_if( + lambda config: config.db.dialect.delete_returning, + "%(database)s %(does_support)s 'DELETE ... RETURNING'", + ) + + @property + def insert_returning(self): + """target platform supports INSERT ... RETURNING.""" + + return exclusions.only_if( + lambda config: config.db.dialect.insert_returning, + "%(database)s %(does_support)s 'INSERT ... RETURNING'", + ) + + @property + def update_returning(self): + """target platform supports UPDATE ... RETURNING.""" + + return exclusions.only_if( + lambda config: config.db.dialect.update_returning, + "%(database)s %(does_support)s 'UPDATE ... RETURNING'", + ) + + @property + def insert_executemany_returning(self): + """target platform supports RETURNING when INSERT is used with + executemany(), e.g. multiple parameter sets, indicating + as many rows come back as do parameter sets were passed. + + """ + + return exclusions.only_if( + lambda config: config.db.dialect.insert_executemany_returning, + "%(database)s %(does_support)s 'RETURNING of " + "multiple rows with INSERT executemany'", + ) + + @property + def insertmanyvalues(self): + return exclusions.only_if( + lambda config: config.db.dialect.supports_multivalues_insert + and config.db.dialect.insert_returning + and config.db.dialect.use_insertmanyvalues, + "%(database)s %(does_support)s 'insertmanyvalues functionality", + ) + + @property + def tuple_in(self): + """Target platform supports the syntax + "(x, y) IN ((x1, y1), (x2, y2), ...)" + """ + + return exclusions.closed() + + @property + def tuple_in_w_empty(self): + """Target platform tuple IN w/ empty set""" + return self.tuple_in + + @property + def duplicate_names_in_cursor_description(self): + """target platform supports a SELECT statement that has + the same name repeated more than once in the columns list.""" + + return exclusions.open() + + @property + def denormalized_names(self): + """Target database must have 'denormalized', i.e. + UPPERCASE as case insensitive names.""" + + return exclusions.skip_if( + lambda config: not config.db.dialect.requires_name_normalize, + "Backend does not require denormalized names.", + ) + + @property + def multivalues_inserts(self): + """target database must support multiple VALUES clauses in an + INSERT statement.""" + + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_multivalues_insert, + "Backend does not support multirow inserts.", + ) + + @property + def implements_get_lastrowid(self): + """target dialect implements the executioncontext.get_lastrowid() + method without reliance on RETURNING. + + """ + return exclusions.open() + + @property + def arraysize(self): + """dialect includes the required pep-249 attribute + ``cursor.arraysize``""" + + return exclusions.open() + + @property + def emulated_lastrowid(self): + """target dialect retrieves cursor.lastrowid, or fetches + from a database-side function after an insert() construct executes, + within the get_lastrowid() method. + + Only dialects that "pre-execute", or need RETURNING to get last + inserted id, would return closed/fail/skip for this. + + """ + return exclusions.closed() + + @property + def emulated_lastrowid_even_with_sequences(self): + """target dialect retrieves cursor.lastrowid or an equivalent + after an insert() construct executes, even if the table has a + Sequence on it. + + """ + return exclusions.closed() + + @property + def dbapi_lastrowid(self): + """target platform includes a 'lastrowid' accessor on the DBAPI + cursor object. + + """ + return exclusions.closed() + + @property + def views(self): + """Target database must support VIEWs.""" + + return exclusions.closed() + + @property + def schemas(self): + """Target database must support external schemas, and have one + named 'test_schema'.""" + + return only_on(lambda config: config.db.dialect.supports_schemas) + + @property + def cross_schema_fk_reflection(self): + """target system must support reflection of inter-schema + foreign keys""" + return exclusions.closed() + + @property + def foreign_key_constraint_name_reflection(self): + """Target supports reflection of FOREIGN KEY constraints and + will return the name of the constraint that was used in the + "CONSTRAINT <name> FOREIGN KEY" DDL. + + MySQL prior to version 8 and MariaDB prior to version 10.5 + don't support this. + + """ + return exclusions.closed() + + @property + def implicit_default_schema(self): + """target system has a strong concept of 'default' schema that can + be referred to implicitly. + + basically, PostgreSQL. + + """ + return exclusions.closed() + + @property + def default_schema_name_switch(self): + """target dialect implements provisioning module including + set_default_schema_on_connection""" + + return exclusions.closed() + + @property + def server_side_cursors(self): + """Target dialect must support server side cursors.""" + + return exclusions.only_if( + [lambda config: config.db.dialect.supports_server_side_cursors], + "no server side cursors support", + ) + + @property + def sequences(self): + """Target database must support SEQUENCEs.""" + + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) + + @property + def no_sequences(self): + """the opposite of "sequences", DB does not support sequences at + all.""" + + return exclusions.NotPredicate(self.sequences) + + @property + def sequences_optional(self): + """Target database supports sequences, but also optionally + as a means of generating new PK values.""" + + return exclusions.only_if( + [ + lambda config: config.db.dialect.supports_sequences + and config.db.dialect.sequences_optional + ], + "no sequence support, or sequences not optional", + ) + + @property + def supports_lastrowid(self): + """target database / driver supports cursor.lastrowid as a means + of retrieving the last inserted primary key value. + + note that if the target DB supports sequences also, this is still + assumed to work. This is a new use case brought on by MariaDB 10.3. + + """ + return exclusions.only_if( + [lambda config: config.db.dialect.postfetch_lastrowid] + ) + + @property + def no_lastrowid_support(self): + """the opposite of supports_lastrowid""" + return exclusions.only_if( + [lambda config: not config.db.dialect.postfetch_lastrowid] + ) + + @property + def reflects_pk_names(self): + return exclusions.closed() + + @property + def table_reflection(self): + """target database has general support for table reflection""" + return exclusions.open() + + @property + def reflect_tables_no_columns(self): + """target database supports creation and reflection of tables with no + columns, or at least tables that seem to have no columns.""" + + return exclusions.closed() + + @property + def comment_reflection(self): + """Indicates if the database support table comment reflection""" + return exclusions.closed() + + @property + def comment_reflection_full_unicode(self): + """Indicates if the database support table comment reflection in the + full unicode range, including emoji etc. + """ + return exclusions.closed() + + @property + def constraint_comment_reflection(self): + """indicates if the database support comments on constraints + and their reflection""" + return exclusions.closed() + + @property + def view_column_reflection(self): + """target database must support retrieval of the columns in a view, + similarly to how a table is inspected. + + This does not include the full CREATE VIEW definition. + + """ + return self.views + + @property + def view_reflection(self): + """target database must support inspection of the full CREATE VIEW + definition.""" + return self.views + + @property + def schema_reflection(self): + return self.schemas + + @property + def schema_create_delete(self): + """target database supports schema create and dropped with + 'CREATE SCHEMA' and 'DROP SCHEMA'""" + return exclusions.closed() + + @property + def primary_key_constraint_reflection(self): + return exclusions.open() + + @property + def foreign_key_constraint_reflection(self): + return exclusions.open() + + @property + def foreign_key_constraint_option_reflection_ondelete(self): + return exclusions.closed() + + @property + def fk_constraint_option_reflection_ondelete_restrict(self): + return exclusions.closed() + + @property + def fk_constraint_option_reflection_ondelete_noaction(self): + return exclusions.closed() + + @property + def foreign_key_constraint_option_reflection_onupdate(self): + return exclusions.closed() + + @property + def fk_constraint_option_reflection_onupdate_restrict(self): + return exclusions.closed() + + @property + def temp_table_reflection(self): + return exclusions.open() + + @property + def temp_table_reflect_indexes(self): + return self.temp_table_reflection + + @property + def temp_table_names(self): + """target dialect supports listing of temporary table names""" + return exclusions.closed() + + @property + def has_temp_table(self): + """target dialect supports checking a single temp table name""" + return exclusions.closed() + + @property + def temporary_tables(self): + """target database supports temporary tables""" + return exclusions.open() + + @property + def temporary_views(self): + """target database supports temporary views""" + return exclusions.closed() + + @property + def index_reflection(self): + return exclusions.open() + + @property + def index_reflects_included_columns(self): + return exclusions.closed() + + @property + def indexes_with_ascdesc(self): + """target database supports CREATE INDEX with per-column ASC/DESC.""" + return exclusions.open() + + @property + def reflect_indexes_with_ascdesc(self): + """target database supports reflecting INDEX with per-column + ASC/DESC.""" + return exclusions.open() + + @property + def reflect_indexes_with_ascdesc_as_expression(self): + """target database supports reflecting INDEX with per-column + ASC/DESC but reflects them as expressions (like oracle).""" + return exclusions.closed() + + @property + def indexes_with_expressions(self): + """target database supports CREATE INDEX against SQL expressions.""" + return exclusions.closed() + + @property + def reflect_indexes_with_expressions(self): + """target database supports reflection of indexes with + SQL expressions.""" + return exclusions.closed() + + @property + def unique_constraint_reflection(self): + """target dialect supports reflection of unique constraints""" + return exclusions.open() + + @property + def check_constraint_reflection(self): + """target dialect supports reflection of check constraints""" + return exclusions.closed() + + @property + def duplicate_key_raises_integrity_error(self): + """target dialect raises IntegrityError when reporting an INSERT + with a primary key violation. (hint: it should) + + """ + return exclusions.open() + + @property + def unbounded_varchar(self): + """Target database must support VARCHAR with no length""" + + return exclusions.open() + + @property + def unicode_data_no_special_types(self): + """Target database/dialect can receive / deliver / compare data with + non-ASCII characters in plain VARCHAR, TEXT columns, without the need + for special "national" datatypes like NVARCHAR or similar. + + """ + return exclusions.open() + + @property + def unicode_data(self): + """Target database/dialect must support Python unicode objects with + non-ASCII characters represented, delivered as bound parameters + as well as in result rows. + + """ + return exclusions.open() + + @property + def unicode_ddl(self): + """Target driver must support some degree of non-ascii symbol + names. + """ + return exclusions.closed() + + @property + def symbol_names_w_double_quote(self): + """Target driver can create tables with a name like 'some " table'""" + return exclusions.open() + + @property + def datetime_interval(self): + """target dialect supports rendering of a datetime.timedelta as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + return exclusions.closed() + + @property + def datetime_literals(self): + """target dialect supports rendering of a date, time, or datetime as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + + return exclusions.closed() + + @property + def datetime(self): + """target dialect supports representation of Python + datetime.datetime() objects.""" + + return exclusions.open() + + @property + def datetime_timezone(self): + """target dialect supports representation of Python + datetime.datetime() with tzinfo with DateTime(timezone=True).""" + + return exclusions.closed() + + @property + def time_timezone(self): + """target dialect supports representation of Python + datetime.time() with tzinfo with Time(timezone=True).""" + + return exclusions.closed() + + @property + def date_implicit_bound(self): + """target dialect when given a date object will bind it such + that the database server knows the object is a date, and not + a plain string. + + """ + return exclusions.open() + + @property + def time_implicit_bound(self): + """target dialect when given a time object will bind it such + that the database server knows the object is a time, and not + a plain string. + + """ + return exclusions.open() + + @property + def datetime_implicit_bound(self): + """target dialect when given a datetime object will bind it such + that the database server knows the object is a datetime, and not + a plain string. + + """ + return exclusions.open() + + @property + def datetime_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects.""" + + return exclusions.open() + + @property + def timestamp_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects but only + if TIMESTAMP is used.""" + return exclusions.closed() + + @property + def timestamp_microseconds_implicit_bound(self): + """target dialect when given a datetime object which also includes + a microseconds portion when using the TIMESTAMP data type + will bind it such that the database server knows + the object is a datetime with microseconds, and not a plain string. + + """ + return self.timestamp_microseconds + + @property + def datetime_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return exclusions.closed() + + @property + def date(self): + """target dialect supports representation of Python + datetime.date() objects.""" + + return exclusions.open() + + @property + def date_coerces_from_datetime(self): + """target dialect accepts a datetime object as the target + of a date column.""" + + return exclusions.open() + + @property + def date_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return exclusions.closed() + + @property + def time(self): + """target dialect supports representation of Python + datetime.time() objects.""" + + return exclusions.open() + + @property + def time_microseconds(self): + """target dialect supports representation of Python + datetime.time() with microsecond objects.""" + + return exclusions.open() + + @property + def binary_comparisons(self): + """target database/driver can allow BLOB/BINARY fields to be compared + against a bound parameter value. + """ + + return exclusions.open() + + @property + def binary_literals(self): + """target backend supports simple binary literals, e.g. an + expression like:: + + SELECT CAST('foo' AS BINARY) + + Where ``BINARY`` is the type emitted from :class:`.LargeBinary`, + e.g. it could be ``BLOB`` or similar. + + Basically fails on Oracle. + + """ + + return exclusions.open() + + @property + def autocommit(self): + """target dialect supports 'AUTOCOMMIT' as an isolation_level""" + return exclusions.closed() + + @property + def isolation_level(self): + """target dialect supports general isolation level settings. + + Note that this requirement, when enabled, also requires that + the get_isolation_levels() method be implemented. + + """ + return exclusions.closed() + + def get_isolation_levels(self, config): + """Return a structure of supported isolation levels for the current + testing dialect. + + The structure indicates to the testing suite what the expected + "default" isolation should be, as well as the other values that + are accepted. The dictionary has two keys, "default" and "supported". + The "supported" key refers to a list of all supported levels and + it should include AUTOCOMMIT if the dialect supports it. + + If the :meth:`.DefaultRequirements.isolation_level` requirement is + not open, then this method has no return value. + + E.g.:: + + >>> testing.requirements.get_isolation_levels() + { + "default": "READ_COMMITTED", + "supported": [ + "SERIALIZABLE", "READ UNCOMMITTED", + "READ COMMITTED", "REPEATABLE READ", + "AUTOCOMMIT" + ] + } + """ + with config.db.connect() as conn: + try: + supported = conn.dialect.get_isolation_level_values( + conn.connection.dbapi_connection + ) + except NotImplementedError: + return None + else: + return { + "default": conn.dialect.default_isolation_level, + "supported": supported, + } + + @property + def get_isolation_level_values(self): + """target dialect supports the + :meth:`_engine.Dialect.get_isolation_level_values` + method added in SQLAlchemy 2.0. + + """ + + def go(config): + with config.db.connect() as conn: + try: + conn.dialect.get_isolation_level_values( + conn.connection.dbapi_connection + ) + except NotImplementedError: + return False + else: + return True + + return exclusions.only_if(go) + + @property + def dialect_level_isolation_level_param(self): + """test that the dialect allows the 'isolation_level' argument + to be handled by DefaultDialect""" + + def go(config): + try: + e = create_engine( + config.db.url, isolation_level="READ COMMITTED" + ) + except: + return False + else: + return ( + e.dialect._on_connect_isolation_level == "READ COMMITTED" + ) + + return exclusions.only_if(go) + + @property + def json_type(self): + """target platform implements a native JSON type.""" + + return exclusions.closed() + + @property + def json_array_indexes(self): + """target platform supports numeric array indexes + within a JSON structure""" + + return self.json_type + + @property + def json_index_supplementary_unicode_element(self): + return exclusions.open() + + @property + def legacy_unconditional_json_extract(self): + """Backend has a JSON_EXTRACT or similar function that returns a + valid JSON string in all cases. + + Used to test a legacy feature and is not needed. + + """ + return exclusions.closed() + + @property + def precision_numerics_general(self): + """target backend has general support for moderately high-precision + numerics.""" + return exclusions.open() + + @property + def precision_numerics_enotation_small(self): + """target backend supports Decimal() objects using E notation + to represent very small values.""" + return exclusions.closed() + + @property + def precision_numerics_enotation_large(self): + """target backend supports Decimal() objects using E notation + to represent very large values.""" + return exclusions.open() + + @property + def precision_numerics_many_significant_digits(self): + """target backend supports values with many digits on both sides, + such as 319438950232418390.273596, 87673.594069654243 + + """ + return exclusions.closed() + + @property + def cast_precision_numerics_many_significant_digits(self): + """same as precision_numerics_many_significant_digits but within the + context of a CAST statement (hello MySQL) + + """ + return self.precision_numerics_many_significant_digits + + @property + def implicit_decimal_binds(self): + """target backend will return a selected Decimal as a Decimal, not + a string. + + e.g.:: + + expr = decimal.Decimal("15.7563") + + value = e.scalar( + select(literal(expr)) + ) + + assert value == expr + + See :ticket:`4036` + + """ + + return exclusions.open() + + @property + def numeric_received_as_decimal_untyped(self): + """target backend will return result columns that are explicitly + against NUMERIC or similar precision-numeric datatypes (not including + FLOAT or INT types) as Python Decimal objects, and not as floats + or ints, including when no SQLAlchemy-side typing information is + associated with the statement (e.g. such as a raw SQL string). + + This should be enabled if either the DBAPI itself returns Decimal + objects, or if the dialect has set up DBAPI-specific return type + handlers such that Decimal objects come back automatically. + + """ + return exclusions.open() + + @property + def nested_aggregates(self): + """target database can select an aggregate from a subquery that's + also using an aggregate + + """ + return exclusions.open() + + @property + def recursive_fk_cascade(self): + """target database must support ON DELETE CASCADE on a self-referential + foreign key + + """ + return exclusions.open() + + @property + def precision_numerics_retains_significant_digits(self): + """A precision numeric type will return empty significant digits, + i.e. a value such as 10.000 will come back in Decimal form with + the .000 maintained.""" + + return exclusions.closed() + + @property + def infinity_floats(self): + """The Float type can persist and load float('inf'), float('-inf').""" + + return exclusions.closed() + + @property + def float_or_double_precision_behaves_generically(self): + return exclusions.closed() + + @property + def precision_generic_float_type(self): + """target backend will return native floating point numbers with at + least seven decimal places when using the generic Float type. + + """ + return exclusions.open() + + @property + def literal_float_coercion(self): + """target backend will return the exact float value 15.7563 + with only four significant digits from this statement: + + SELECT :param + + where :param is the Python float 15.7563 + + i.e. it does not return 15.75629997253418 + + """ + return exclusions.open() + + @property + def floats_to_four_decimals(self): + """target backend can return a floating-point number with four + significant digits (such as 15.7563) accurately + (i.e. without FP inaccuracies, such as 15.75629997253418). + + """ + return exclusions.open() + + @property + def fetch_null_from_numeric(self): + """target backend doesn't crash when you try to select a NUMERIC + value that has a value of NULL. + + Added to support Pyodbc bug #351. + """ + + return exclusions.open() + + @property + def float_is_numeric(self): + """target backend uses Numeric for Float/Dual""" + + return exclusions.open() + + @property + def text_type(self): + """Target database must support an unbounded Text() " + "type such as TEXT or CLOB""" + + return exclusions.open() + + @property + def empty_strings_varchar(self): + """target database can persist/return an empty string with a + varchar. + + """ + return exclusions.open() + + @property + def empty_strings_text(self): + """target database can persist/return an empty string with an + unbounded text.""" + + return exclusions.open() + + @property + def expressions_against_unbounded_text(self): + """target database supports use of an unbounded textual field in a + WHERE clause.""" + + return exclusions.open() + + @property + def selectone(self): + """target driver must support the literal statement 'select 1'""" + return exclusions.open() + + @property + def savepoints(self): + """Target database must support savepoints.""" + + return exclusions.closed() + + @property + def two_phase_transactions(self): + """Target database must support two-phase transactions.""" + + return exclusions.closed() + + @property + def update_from(self): + """Target must support UPDATE..FROM syntax""" + return exclusions.closed() + + @property + def delete_from(self): + """Target must support DELETE FROM..FROM or DELETE..USING syntax""" + return exclusions.closed() + + @property + def update_where_target_in_subquery(self): + """Target must support UPDATE (or DELETE) where the same table is + present in a subquery in the WHERE clause. + + This is an ANSI-standard syntax that apparently MySQL can't handle, + such as:: + + UPDATE documents SET flag=1 WHERE documents.title IN + (SELECT max(documents.title) AS title + FROM documents GROUP BY documents.user_id + ) + + """ + return exclusions.open() + + @property + def mod_operator_as_percent_sign(self): + """target database must use a plain percent '%' as the 'modulus' + operator.""" + return exclusions.closed() + + @property + def percent_schema_names(self): + """target backend supports weird identifiers with percent signs + in them, e.g. 'some % column'. + + this is a very weird use case but often has problems because of + DBAPIs that use python formatting. It's not a critical use + case either. + + """ + return exclusions.closed() + + @property + def order_by_col_from_union(self): + """target database supports ordering by a column from a SELECT + inside of a UNION + + E.g. (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id + + """ + return exclusions.open() + + @property + def order_by_label_with_expression(self): + """target backend supports ORDER BY a column label within an + expression. + + Basically this:: + + select data as foo from test order by foo || 'bar' + + Lots of databases including PostgreSQL don't support this, + so this is off by default. + + """ + return exclusions.closed() + + @property + def order_by_collation(self): + def check(config): + try: + self.get_order_by_collation(config) + return False + except NotImplementedError: + return True + + return exclusions.skip_if(check) + + def get_order_by_collation(self, config): + raise NotImplementedError() + + @property + def unicode_connections(self): + """Target driver must support non-ASCII characters being passed at + all. + """ + return exclusions.open() + + @property + def graceful_disconnects(self): + """Target driver must raise a DBAPI-level exception, such as + InterfaceError, when the underlying connection has been closed + and the execute() method is called. + """ + return exclusions.open() + + @property + def independent_connections(self): + """ + Target must support simultaneous, independent database connections. + """ + return exclusions.open() + + @property + def independent_readonly_connections(self): + """ + Target must support simultaneous, independent database connections + that will be used in a readonly fashion. + + """ + return exclusions.open() + + @property + def skip_mysql_on_windows(self): + """Catchall for a large variety of MySQL on Windows failures""" + return exclusions.open() + + @property + def ad_hoc_engines(self): + """Test environment must allow ad-hoc engine/connection creation. + + DBs that scale poorly for many connections, even when closed, i.e. + Oracle, may use the "--low-connections" option which flags this + requirement as not present. + + """ + return exclusions.skip_if( + lambda config: config.options.low_connections + ) + + @property + def no_windows(self): + return exclusions.skip_if(self._running_on_windows()) + + def _running_on_windows(self): + return exclusions.LambdaPredicate( + lambda: platform.system() == "Windows", + description="running on Windows", + ) + + @property + def timing_intensive(self): + from . import config + + return config.add_to_marker.timing_intensive + + @property + def memory_intensive(self): + from . import config + + return config.add_to_marker.memory_intensive + + @property + def threading_with_mock(self): + """Mark tests that use threading and mock at the same time - stability + issues have been observed with coverage + + """ + return exclusions.skip_if( + lambda config: config.options.has_coverage, + "Stability issues with coverage", + ) + + @property + def sqlalchemy2_stubs(self): + def check(config): + try: + __import__("sqlalchemy-stubs.ext.mypy") + except ImportError: + return False + else: + return True + + return exclusions.only_if(check) + + @property + def no_sqlalchemy2_stubs(self): + def check(config): + try: + __import__("sqlalchemy-stubs.ext.mypy") + except ImportError: + return False + else: + return True + + return exclusions.skip_if(check) + + @property + def python38(self): + return exclusions.only_if( + lambda: util.py38, "Python 3.8 or above required" + ) + + @property + def python39(self): + return exclusions.only_if( + lambda: util.py39, "Python 3.9 or above required" + ) + + @property + def python310(self): + return exclusions.only_if( + lambda: util.py310, "Python 3.10 or above required" + ) + + @property + def python311(self): + return exclusions.only_if( + lambda: util.py311, "Python 3.11 or above required" + ) + + @property + def python312(self): + return exclusions.only_if( + lambda: util.py312, "Python 3.12 or above required" + ) + + @property + def cpython(self): + return exclusions.only_if( + lambda: util.cpython, "cPython interpreter needed" + ) + + @property + def is64bit(self): + return exclusions.only_if(lambda: util.is64bit, "64bit required") + + @property + def patch_library(self): + def check_lib(): + try: + __import__("patch") + except ImportError: + return False + else: + return True + + return exclusions.only_if(check_lib, "patch library needed") + + @property + def predictable_gc(self): + """target platform must remove all cycles unconditionally when + gc.collect() is called, as well as clean out unreferenced subclasses. + + """ + return self.cpython + + @property + def no_coverage(self): + """Test should be skipped if coverage is enabled. + + This is to block tests that exercise libraries that seem to be + sensitive to coverage, such as PostgreSQL notice logging. + + """ + return exclusions.skip_if( + lambda config: config.options.has_coverage, + "Issues observed when coverage is enabled", + ) + + def _has_mysql_on_windows(self, config): + return False + + def _has_mysql_fully_case_sensitive(self, config): + return False + + @property + def sqlite(self): + return exclusions.skip_if(lambda: not self._has_sqlite()) + + @property + def cextensions(self): + return exclusions.skip_if( + lambda: not util.has_compiled_ext(), + "Cython extensions not installed", + ) + + def _has_sqlite(self): + from sqlalchemy import create_engine + + try: + create_engine("sqlite://") + return True + except ImportError: + return False + + @property + def async_dialect(self): + """dialect makes use of await_() to invoke operations on the DBAPI.""" + + return exclusions.closed() + + @property + def asyncio(self): + return self.greenlet + + @property + def no_greenlet(self): + def go(config): + try: + import greenlet # noqa: F401 + except ImportError: + return True + else: + return False + + return exclusions.only_if(go) + + @property + def greenlet(self): + def go(config): + if not _test_asyncio.ENABLE_ASYNCIO: + return False + + try: + import greenlet # noqa: F401 + except ImportError: + return False + else: + return True + + return exclusions.only_if(go) + + @property + def computed_columns(self): + "Supports computed columns" + return exclusions.closed() + + @property + def computed_columns_stored(self): + "Supports computed columns with `persisted=True`" + return exclusions.closed() + + @property + def computed_columns_virtual(self): + "Supports computed columns with `persisted=False`" + return exclusions.closed() + + @property + def computed_columns_default_persisted(self): + """If the default persistence is virtual or stored when `persisted` + is omitted""" + return exclusions.closed() + + @property + def computed_columns_reflect_persisted(self): + """If persistence information is returned by the reflection of + computed columns""" + return exclusions.closed() + + @property + def supports_distinct_on(self): + """If a backend supports the DISTINCT ON in a select""" + return exclusions.closed() + + @property + def supports_is_distinct_from(self): + """Supports some form of "x IS [NOT] DISTINCT FROM y" construct. + Different dialects will implement their own flavour, e.g., + sqlite will emit "x IS NOT y" instead of "x IS DISTINCT FROM y". + + .. seealso:: + + :meth:`.ColumnOperators.is_distinct_from` + + """ + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_is_distinct_from, + "driver doesn't support an IS DISTINCT FROM construct", + ) + + @property + def identity_columns(self): + """If a backend supports GENERATED { ALWAYS | BY DEFAULT } + AS IDENTITY""" + return exclusions.closed() + + @property + def identity_columns_standard(self): + """If a backend supports GENERATED { ALWAYS | BY DEFAULT } + AS IDENTITY with a standard syntax. + This is mainly to exclude MSSql. + """ + return exclusions.closed() + + @property + def regexp_match(self): + """backend supports the regexp_match operator.""" + return exclusions.closed() + + @property + def regexp_replace(self): + """backend supports the regexp_replace operator.""" + return exclusions.closed() + + @property + def fetch_first(self): + """backend supports the fetch first clause.""" + return exclusions.closed() + + @property + def fetch_percent(self): + """backend supports the fetch first clause with percent.""" + return exclusions.closed() + + @property + def fetch_ties(self): + """backend supports the fetch first clause with ties.""" + return exclusions.closed() + + @property + def fetch_no_order_by(self): + """backend supports the fetch first without order by""" + return exclusions.closed() + + @property + def fetch_offset_with_options(self): + """backend supports the offset when using fetch first with percent + or ties. basically this is "not mssql" + """ + return exclusions.closed() + + @property + def fetch_expression(self): + """backend supports fetch / offset with expression in them, like + + SELECT * FROM some_table + OFFSET 1 + 1 ROWS FETCH FIRST 1 + 1 ROWS ONLY + """ + return exclusions.closed() + + @property + def autoincrement_without_sequence(self): + """If autoincrement=True on a column does not require an explicit + sequence. This should be false only for oracle. + """ + return exclusions.open() + + @property + def generic_classes(self): + "If X[Y] can be implemented with ``__class_getitem__``. py3.7+" + return exclusions.open() + + @property + def json_deserializer_binary(self): + "indicates if the json_deserializer function is called with bytes" + return exclusions.closed() + + @property + def reflect_table_options(self): + """Target database must support reflecting table_options.""" + return exclusions.closed() + + @property + def materialized_views(self): + """Target database must support MATERIALIZED VIEWs.""" + return exclusions.closed() + + @property + def materialized_views_reflect_pk(self): + """Target database reflect MATERIALIZED VIEWs pks.""" + return exclusions.closed() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/schema.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/schema.py new file mode 100644 index 0000000..7dfd33d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/schema.py @@ -0,0 +1,224 @@ +# testing/schema.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import sys + +from . import config +from . import exclusions +from .. import event +from .. import schema +from .. import types as sqltypes +from ..orm import mapped_column as _orm_mapped_column +from ..util import OrderedDict + +__all__ = ["Table", "Column"] + +table_options = {} + + +def Table(*args, **kw) -> schema.Table: + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} + + kw.update(table_options) + + if exclusions.against(config._current, "mysql"): + if ( + "mysql_engine" not in kw + and "mysql_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mysql_engine"] = "InnoDB" + else: + # there are in fact test fixtures that rely upon MyISAM, + # due to MySQL / MariaDB having poor FK behavior under innodb, + # such as a self-referential table can't be deleted from at + # once without attending to per-row dependencies. We'd need to + # add special steps to some fixtures if we want to not + # explicitly state MyISAM here + kw["mysql_engine"] = "MyISAM" + elif exclusions.against(config._current, "mariadb"): + if ( + "mariadb_engine" not in kw + and "mariadb_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mariadb_engine"] = "InnoDB" + else: + kw["mariadb_engine"] = "MyISAM" + + return schema.Table(*args, **kw) + + +def mapped_column(*args, **kw): + """An orm.mapped_column wrapper/hook for dialect-specific tweaks.""" + + return _schema_column(_orm_mapped_column, args, kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + return _schema_column(schema.Column, args, kw) + + +def _schema_column(factory, args, kw): + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} + + if not config.requirements.foreign_key_ddl.enabled_for_config(config): + args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] + + construct = factory(*args, **kw) + + if factory is schema.Column: + col = construct + else: + col = construct.column + + if test_opts.get("test_needs_autoincrement", False) and kw.get( + "primary_key", False + ): + if col.default is None and col.server_default is None: + col.autoincrement = True + + # allow any test suite to pick up on this + col.info["test_needs_autoincrement"] = True + + # hardcoded rule for oracle; this should + # be moved out + if exclusions.against(config._current, "oracle"): + + def add_seq(c, tbl): + c._init_items( + schema.Sequence( + _truncate_name( + config.db.dialect, tbl.name + "_" + c.name + "_seq" + ), + optional=True, + ) + ) + + event.listen(col, "after_parent_attach", add_seq, propagate=True) + return construct + + +class eq_type_affinity: + """Helper to compare types inside of datastructures based on affinity. + + E.g.:: + + eq_( + inspect(connection).get_columns("foo"), + [ + { + "name": "id", + "type": testing.eq_type_affinity(sqltypes.INTEGER), + "nullable": False, + "default": None, + "autoincrement": False, + }, + { + "name": "data", + "type": testing.eq_type_affinity(sqltypes.NullType), + "nullable": True, + "default": None, + "autoincrement": False, + }, + ], + ) + + """ + + def __init__(self, target): + self.target = sqltypes.to_instance(target) + + def __eq__(self, other): + return self.target._type_affinity is other._type_affinity + + def __ne__(self, other): + return self.target._type_affinity is not other._type_affinity + + +class eq_compile_type: + """similar to eq_type_affinity but uses compile""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target == other.compile() + + def __ne__(self, other): + return self.target != other.compile() + + +class eq_clause_element: + """Helper to compare SQL structures based on compare()""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target.compare(other) + + def __ne__(self, other): + return not self.target.compare(other) + + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return ( + name[0 : max(dialect.max_identifier_length - 6, 0)] + + "_" + + hex(hash(name) % 64)[2:] + ) + else: + return name + + +def pep435_enum(name): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + value_to_member[value] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + value_to_member = {} + + @classmethod + def get(cls, value): + return value_to_member[value] + + someenum = type( + name, + (object,), + {"__members__": __members__, "__init__": __init__, "get": get}, + ) + + # getframe() trick for pickling I don't understand courtesy + # Python namedtuple() + try: + module = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + if module is not None: + someenum.__module__ = module + + return someenum diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 0000000..a146cb3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py @@ -0,0 +1,19 @@ +# testing/suite/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from .test_cte import * # noqa +from .test_ddl import * # noqa +from .test_deprecations import * # noqa +from .test_dialect import * # noqa +from .test_insert import * # noqa +from .test_reflection import * # noqa +from .test_results import * # noqa +from .test_rowcount import * # noqa +from .test_select import * # noqa +from .test_sequence import * # noqa +from .test_types import * # noqa +from .test_unicode_ddl import * # noqa +from .test_update_delete import * # noqa diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b8ae9f7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6bf56dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..99ddd6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a114600 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..325c9a1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c212fa4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..858574b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..624d1f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4ebccba --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5edde44 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..576fdcf --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..79cb268 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1c8d3ff --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1be9d55 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py new file mode 100644 index 0000000..5d37880 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py @@ -0,0 +1,211 @@ +# testing/suite/test_cte.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import ForeignKey +from ... import Integer +from ... import select +from ... import String +from ... import testing + + +class CTETest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("ctes",) + + run_inserts = "each" + run_deletes = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) + + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "d1", "parent_id": None}, + {"id": 2, "data": "d2", "parent_id": 1}, + {"id": 3, "data": "d3", "parent_id": 1}, + {"id": 4, "data": "d4", "parent_id": 3}, + {"id": 5, "data": "d5", "parent_id": 3}, + ], + ) + + def test_select_nonrecursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + result = connection.execute( + select(cte.c.data).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4",)]) + + def test_select_recursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) + + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select(st1).where(st1.c.id == cte_alias.c.parent_id) + ) + result = connection.execute( + select(cte.c.data) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], + ) + + def test_insert_from_select_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(cte) + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.update_from + def test_update_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.delete_from + def test_delete_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) + + @testing.requires.ctes_with_update_delete + def test_delete_scalar_subq_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data + == select(cte.c.data) + .where(cte.c.id == some_other_table.c.id) + .scalar_subquery() + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 0000000..3d9b8ec --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,389 @@ +# testing/suite/test_ddl.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import random + +from . import testing +from .. import config +from .. import fixtures +from .. import util +from ..assertions import eq_ +from ..assertions import is_false +from ..assertions import is_true +from ..config import requirements +from ..schema import Table +from ... import CheckConstraint +from ... import Column +from ... import ForeignKeyConstraint +from ... import Index +from ... import inspect +from ... import Integer +from ... import schema +from ... import String +from ... import UniqueConstraint + + +class TableDDLTest(fixtures.TestBase): + __backend__ = True + + def _simple_fixture(self, schema=None): + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + schema=schema, + ) + + def _underscore_fixture(self): + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) + + def _table_index_fixture(self, schema=None): + table = self._simple_fixture(schema=schema) + idx = Index("test_index", table.c.data) + return table, idx + + def _simple_roundtrip(self, table): + with config.db.begin() as conn: + conn.execute(table.insert().values((1, "some data"))) + result = conn.execute(table.select()) + eq_(result.first(), (1, "some data")) + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.create_table + @requirements.schemas + @util.provide_metadata + def test_create_table_schema(self): + table = self._simple_fixture(schema=config.test_schema) + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) + + @requirements.create_table + @util.provide_metadata + def test_underscore_names(self): + table = self._underscore_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.comment_reflection + @util.provide_metadata + def test_add_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), + {"text": "a comment"}, + ) + + @requirements.comment_reflection + @util.provide_metadata + def test_drop_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + connection.execute(schema.DropTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), {"text": None} + ) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_create_table_if_not_exists(self, connection): + table = self._simple_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + is_true(inspect(connection).has_table("test_table")) + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_create_index_if_not_exists(self, connection): + table, idx = self._table_index_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + is_true(inspect(connection).has_table("test_table")) + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_drop_table_if_exists(self, connection): + table = self._simple_fixture() + + table.create(connection) + + is_true(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + is_false(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_drop_index_if_exists(self, connection): + table, idx = self._table_index_fixture() + + table.create(connection) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + +class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest): + pass + + +class LongNameBlowoutTest(fixtures.TestBase): + """test the creation of a variety of DDL structures and ensure + label length limits pass on backends + + """ + + __backend__ = True + + def fk(self, metadata, connection): + convention = { + "fk": "foreign_key_%(table_name)s_" + "%(column_0_N_name)s_" + "%(referred_table_name)s_" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(20)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + test_needs_fk=True, + ) + + cons = ForeignKeyConstraint( + ["aid"], ["a_things_with_stuff.id_long_column_name"] + ) + Table( + "b_related_things_of_value", + metadata, + Column( + "aid", + ), + cons, + test_needs_fk=True, + ) + actual_name = cons.name + + metadata.create_all(connection) + + if testing.requires.foreign_key_constraint_name_reflection.enabled: + insp = inspect(connection) + fks = insp.get_foreign_keys("b_related_things_of_value") + reflected_name = fks[0]["name"] + + return actual_name, reflected_name + else: + return actual_name, None + + def pk(self, metadata, connection): + convention = { + "pk": "primary_key_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer, primary_key=True), + ) + cons = a.primary_key + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + pk = insp.get_pk_constraint("a_things_with_stuff") + reflected_name = pk["name"] + return actual_name, reflected_name + + def ix(self, metadata, connection): + convention = { + "ix": "index_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + ) + cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ix = insp.get_indexes("a_things_with_stuff") + reflected_name = ix[0]["name"] + return actual_name, reflected_name + + def uq(self, metadata, connection): + convention = { + "uq": "unique_constraint_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = UniqueConstraint("id_long_column_name", "id_another_long_name") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + uq = insp.get_unique_constraints("a_things_with_stuff") + reflected_name = uq[0]["name"] + return actual_name, reflected_name + + def ck(self, metadata, connection): + convention = { + "ck": "check_constraint_%(table_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = CheckConstraint("some_long_column_name > 5") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("some_long_column_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ck = insp.get_check_constraints("a_things_with_stuff") + reflected_name = ck[0]["name"] + return actual_name, reflected_name + + @testing.combinations( + ("fk",), + ("pk",), + ("ix",), + ("ck", testing.requires.check_constraint_reflection.as_skips()), + ("uq", testing.requires.unique_constraint_reflection.as_skips()), + argnames="type_", + ) + def test_long_convention_name(self, type_, metadata, connection): + actual_name, reflected_name = getattr(self, type_)( + metadata, connection + ) + + assert len(actual_name) > 255 + + if reflected_name is not None: + overlap = actual_name[0 : len(reflected_name)] + if len(overlap) < len(actual_name): + eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5]) + else: + eq_(overlap, reflected_name) + + +__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py new file mode 100644 index 0000000..07970c0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py @@ -0,0 +1,153 @@ +# testing/suite/test_deprecations.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import select +from ... import testing +from ... import union + + +class DeprecatedCompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, conn, select, result, params=()): + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + # note we've had to remove one use case entirely, which is this + # one. the Select gets its FROMS from the WHERE clause and the + # columns clause, but not the ORDER BY, which means the old ".c" system + # allowed you to "order_by(s.c.foo)" to get an unnamed column in the + # ORDER BY without adding the SELECT into the FROM and breaking the + # query. Users will have to adjust for this use case if they were doing + # it before. + def _dont_test_select_from_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py new file mode 100644 index 0000000..6964720 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py @@ -0,0 +1,740 @@ +# testing/suite/test_dialect.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import importlib + +from . import testing +from .. import assert_raises +from .. import config +from .. import engines +from .. import eq_ +from .. import fixtures +from .. import is_not_none +from .. import is_true +from .. import ne_ +from .. import provide_metadata +from ..assertions import expect_raises +from ..assertions import expect_raises_message +from ..config import requirements +from ..provision import set_default_schema_on_connection +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import dialects +from ... import event +from ... import exc +from ... import Integer +from ... import literal_column +from ... import select +from ... import String +from ...sql.compiler import Compiled +from ...util import inspect_getfullargspec + + +class PingTest(fixtures.TestBase): + __backend__ = True + + def test_do_ping(self): + with testing.db.connect() as conn: + is_true( + testing.db.dialect.do_ping(conn.connection.dbapi_connection) + ) + + +class ArgSignatureTest(fixtures.TestBase): + """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have + ``**kw``, for #8988. + + This test uses runtime code inspection. Does not need to be a + ``__backend__`` test as it only needs to run once provided all target + dialects have been imported. + + For third party dialects, the suite would be run with that third + party as a "--dburi", which means its compiler classes will have been + imported by the time this test runs. + + """ + + def _all_subclasses(): # type: ignore # noqa + for d in dialects.__all__: + if not d.startswith("_"): + importlib.import_module("sqlalchemy.dialects.%s" % d) + + stack = [Compiled] + + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + yield cls + + @testing.fixture(params=list(_all_subclasses())) + def all_subclasses(self, request): + yield request.param + + def test_all_visit_methods_accept_kw(self, all_subclasses): + cls = all_subclasses + + for k in cls.__dict__: + if k.startswith("visit_"): + meth = getattr(cls, k) + + insp = inspect_getfullargspec(meth) + is_not_none( + insp.varkw, + f"Compiler visit method {cls.__name__}.{k}() does " + "not accommodate for **kw in its argument signature", + ) + + +class ExceptionTest(fixtures.TablesTest): + """Test basic exception wrapping. + + DBAPIs vary a lot in exception behavior so to actually anticipate + specific exceptions from real round trips, we need to be conservative. + + """ + + run_deletes = "each" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @requirements.duplicate_key_raises_integrity_error + def test_integrity_error(self): + with config.db.connect() as conn: + trans = conn.begin() + conn.execute( + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} + ) + + assert_raises( + exc.IntegrityError, + conn.execute, + self.tables.manual_pk.insert(), + {"id": 1, "data": "d1"}, + ) + + trans.rollback() + + def test_exception_with_non_ascii(self): + with config.db.connect() as conn: + try: + # try to create an error message that likely has non-ascii + # characters in the DBAPI's message string. unfortunately + # there's no way to make this happen with some drivers like + # mysqlclient, pymysql. this at least does produce a non- + # ascii error message for cx_oracle, psycopg2 + conn.execute(select(literal_column("méil"))) + assert False + except exc.DBAPIError as err: + err_str = str(err) + + assert str(err.orig) in str(err) + + assert isinstance(err_str, str) + + +class IsolationLevelTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("isolation_level",) + + def _get_non_default_isolation_level(self): + levels = requirements.get_isolation_levels(config) + + default = levels["default"] + supported = levels["supported"] + + s = set(supported).difference(["AUTOCOMMIT", default]) + if s: + return s.pop() + else: + config.skip_test("no non-default isolation level available") + + def test_default_isolation_level(self): + eq_( + config.db.dialect.default_isolation_level, + requirements.get_isolation_levels(config)["default"], + ) + + def test_non_default_isolation_level(self): + non_default = self._get_non_default_isolation_level() + + with config.db.connect() as conn: + existing = conn.get_isolation_level() + + ne_(existing, non_default) + + conn.execution_options(isolation_level=non_default) + + eq_(conn.get_isolation_level(), non_default) + + conn.dialect.reset_isolation_level( + conn.connection.dbapi_connection + ) + + eq_(conn.get_isolation_level(), existing) + + def test_all_levels(self): + levels = requirements.get_isolation_levels(config) + + all_levels = levels["supported"] + + for level in set(all_levels).difference(["AUTOCOMMIT"]): + with config.db.connect() as conn: + conn.execution_options(isolation_level=level) + + eq_(conn.get_isolation_level(), level) + + trans = conn.begin() + trans.rollback() + + eq_(conn.get_isolation_level(), level) + + with config.db.connect() as conn: + eq_( + conn.get_isolation_level(), + levels["default"], + ) + + @testing.requires.get_isolation_level_values + def test_invalid_level_execution_option(self, connection_no_trans): + """test for the new get_isolation_level_values() method""" + + connection = connection_no_trans + with expect_raises_message( + exc.ArgumentError, + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for '%s' are %s" + % ( + "FOO", + connection.dialect.name, + ", ".join( + requirements.get_isolation_levels(config)["supported"] + ), + ), + ): + connection.execution_options(isolation_level="FOO") + + @testing.requires.get_isolation_level_values + @testing.requires.dialect_level_isolation_level_param + def test_invalid_level_engine_param(self, testing_engine): + """test for the new get_isolation_level_values() method + and support for the dialect-level 'isolation_level' parameter. + + """ + + eng = testing_engine(options=dict(isolation_level="FOO")) + with expect_raises_message( + exc.ArgumentError, + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for '%s' are %s" + % ( + "FOO", + eng.dialect.name, + ", ".join( + requirements.get_isolation_levels(config)["supported"] + ), + ), + ): + eng.connect() + + @testing.requires.independent_readonly_connections + def test_dialect_user_setting_is_restored(self, testing_engine): + levels = requirements.get_isolation_levels(config) + default = levels["default"] + supported = ( + sorted( + set(levels["supported"]).difference([default, "AUTOCOMMIT"]) + ) + )[0] + + e = testing_engine(options={"isolation_level": supported}) + + with e.connect() as conn: + eq_(conn.get_isolation_level(), supported) + + with e.connect() as conn: + conn.execution_options(isolation_level=default) + eq_(conn.get_isolation_level(), default) + + with e.connect() as conn: + eq_(conn.get_isolation_level(), supported) + + +class AutocommitIsolationTest(fixtures.TablesTest): + run_deletes = "each" + + __requires__ = ("autocommit",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) + + def _test_conn_autocommits(self, conn, autocommit): + trans = conn.begin() + conn.execute( + self.tables.some_table.insert(), {"id": 1, "data": "some data"} + ) + trans.rollback() + + eq_( + conn.scalar(select(self.tables.some_table.c.id)), + 1 if autocommit else None, + ) + conn.rollback() + + with conn.begin(): + conn.execute(self.tables.some_table.delete()) + + def test_autocommit_on(self, connection_no_trans): + conn = connection_no_trans + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(c2, True) + + c2.dialect.reset_isolation_level(c2.connection.dbapi_connection) + + self._test_conn_autocommits(conn, False) + + def test_autocommit_off(self, connection_no_trans): + conn = connection_no_trans + self._test_conn_autocommits(conn, False) + + def test_turn_autocommit_off_via_default_iso_level( + self, connection_no_trans + ): + conn = connection_no_trans + conn = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(conn, True) + + conn.execution_options( + isolation_level=requirements.get_isolation_levels(config)[ + "default" + ] + ) + self._test_conn_autocommits(conn, False) + + @testing.requires.independent_readonly_connections + @testing.variation("use_dialect_setting", [True, False]) + def test_dialect_autocommit_is_restored( + self, testing_engine, use_dialect_setting + ): + """test #10147""" + + if use_dialect_setting: + e = testing_engine(options={"isolation_level": "AUTOCOMMIT"}) + else: + e = testing_engine().execution_options( + isolation_level="AUTOCOMMIT" + ) + + levels = requirements.get_isolation_levels(config) + + default = levels["default"] + + with e.connect() as conn: + self._test_conn_autocommits(conn, True) + + with e.connect() as conn: + conn.execution_options(isolation_level=default) + self._test_conn_autocommits(conn, False) + + with e.connect() as conn: + self._test_conn_autocommits(conn, True) + + +class EscapingTest(fixtures.TestBase): + @provide_metadata + def test_percent_sign_round_trip(self): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + m = self.metadata + t = Table("t", m, Column("data", String(50))) + t.create(config.db) + with config.db.begin() as conn: + conn.execute(t.insert(), dict(data="some % value")) + conn.execute(t.insert(), dict(data="some %% other value")) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some % value'") + ) + ), + "some % value", + ) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("default_schema_name_switch",) + + def test_control_case(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + with eng.connect(): + pass + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_wont_work_wo_insert(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + + @event.listens_for(eng, "connect") + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_schema_change_on_connect(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + def test_schema_change_works_w_transactions(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, *arg): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + trans = conn.begin() + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + trans.rollback() + + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + +class FutureWeCanSetDefaultSchemaWEventsTest( + fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest +): + pass + + +class DifficultParametersTest(fixtures.TestBase): + __backend__ = True + + tough_parameters = testing.combinations( + ("boring",), + ("per cent",), + ("per % cent",), + ("%percent",), + ("par(ens)",), + ("percent%(ens)yah",), + ("col:ons",), + ("_starts_with_underscore",), + ("dot.s",), + ("more :: %colons%",), + ("_name",), + ("___name",), + ("[BracketsAndCase]",), + ("42numbers",), + ("percent%signs",), + ("has spaces",), + ("/slashes/",), + ("more/slashes",), + ("q?marks",), + ("1param",), + ("1col:on",), + argnames="paramname", + ) + + @tough_parameters + @config.requirements.unusual_column_name_characters + def test_round_trip_same_named_column( + self, paramname, connection, metadata + ): + name = paramname + + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column(name, String(50), nullable=False), + ) + + # table is created + t.create(connection) + + # automatic param generated by insert + connection.execute(t.insert().values({"id": 1, name: "some name"})) + + # automatic param generated by criteria, plus selecting the column + stmt = select(t.c[name]).where(t.c[name] == "some name") + + eq_(connection.scalar(stmt), "some name") + + # use the name in a param explicitly + stmt = select(t.c[name]).where(t.c[name] == bindparam(name)) + + row = connection.execute(stmt, {name: "some name"}).first() + + # name works as the key from cursor.description + eq_(row._mapping[name], "some name") + + # use expanding IN + stmt = select(t.c[name]).where( + t.c[name].in_(["some name", "some other_name"]) + ) + + row = connection.execute(stmt).first() + + @testing.fixture + def multirow_fixture(self, metadata, connection): + mytable = Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(50)), + Column("desc", String(50)), + ) + + mytable.create(connection) + + connection.execute( + mytable.insert(), + [ + {"myid": 1, "name": "a", "desc": "a_desc"}, + {"myid": 2, "name": "b", "desc": "b_desc"}, + {"myid": 3, "name": "c", "desc": "c_desc"}, + {"myid": 4, "name": "d", "desc": "d_desc"}, + ], + ) + yield mytable + + @tough_parameters + def test_standalone_bindparam_escape( + self, paramname, connection, multirow_fixture + ): + tbl1 = multirow_fixture + stmt = select(tbl1.c.myid).where( + tbl1.c.name == bindparam(paramname, value="x") + ) + res = connection.scalar(stmt, {paramname: "c"}) + eq_(res, 3) + + @tough_parameters + def test_standalone_bindparam_escape_expanding( + self, paramname, connection, multirow_fixture + ): + tbl1 = multirow_fixture + stmt = ( + select(tbl1.c.myid) + .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"]))) + .order_by(tbl1.c.myid) + ) + + res = connection.scalars(stmt, {paramname: ["d", "a"]}).all() + eq_(res, [1, 4]) + + +class ReturningGuardsTest(fixtures.TablesTest): + """test that the various 'returning' flags are set appropriately""" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "t", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @testing.fixture + def run_stmt(self, connection): + t = self.tables.t + + def go(stmt, executemany, id_param_name, expect_success): + stmt = stmt.returning(t.c.id) + + if executemany: + if not expect_success: + # for RETURNING executemany(), we raise our own + # error as this is independent of general RETURNING + # support + with expect_raises_message( + exc.StatementError, + rf"Dialect {connection.dialect.name}\+" + f"{connection.dialect.driver} with " + f"current server capabilities does not support " + f".*RETURNING when executemany is used", + ): + result = connection.execute( + stmt, + [ + {id_param_name: 1, "data": "d1"}, + {id_param_name: 2, "data": "d2"}, + {id_param_name: 3, "data": "d3"}, + ], + ) + else: + result = connection.execute( + stmt, + [ + {id_param_name: 1, "data": "d1"}, + {id_param_name: 2, "data": "d2"}, + {id_param_name: 3, "data": "d3"}, + ], + ) + eq_(result.all(), [(1,), (2,), (3,)]) + else: + if not expect_success: + # for RETURNING execute(), we pass all the way to the DB + # and let it fail + with expect_raises(exc.DBAPIError): + connection.execute( + stmt, {id_param_name: 1, "data": "d1"} + ) + else: + result = connection.execute( + stmt, {id_param_name: 1, "data": "d1"} + ) + eq_(result.all(), [(1,)]) + + return go + + def test_insert_single(self, connection, run_stmt): + t = self.tables.t + + stmt = t.insert() + + run_stmt(stmt, False, "id", connection.dialect.insert_returning) + + def test_insert_many(self, connection, run_stmt): + t = self.tables.t + + stmt = t.insert() + + run_stmt( + stmt, True, "id", connection.dialect.insert_executemany_returning + ) + + def test_update_single(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.update().where(t.c.id == bindparam("b_id")) + + run_stmt(stmt, False, "b_id", connection.dialect.update_returning) + + def test_update_many(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.update().where(t.c.id == bindparam("b_id")) + + run_stmt( + stmt, True, "b_id", connection.dialect.update_executemany_returning + ) + + def test_delete_single(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.delete().where(t.c.id == bindparam("b_id")) + + run_stmt(stmt, False, "b_id", connection.dialect.delete_returning) + + def test_delete_many(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.delete().where(t.c.id == bindparam("b_id")) + + run_stmt( + stmt, True, "b_id", connection.dialect.delete_executemany_returning + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py new file mode 100644 index 0000000..1cff044 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py @@ -0,0 +1,630 @@ +# testing/suite/test_insert.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from decimal import Decimal +import uuid + +from . import testing +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import Double +from ... import Float +from ... import Identity +from ... import Integer +from ... import literal +from ... import literal_column +from ... import Numeric +from ... import select +from ... import String +from ...types import LargeBinary +from ...types import UUID +from ...types import Uuid + + +class LastrowidTest(fixtures.TablesTest): + run_deletes = "each" + + __backend__ = True + + __requires__ = "implements_get_lastrowid", "autoincrement_insert" + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + implicit_returning=False, + ) + + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + implicit_returning=False, + ) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + def test_autoincrement_on_insert(self, connection): + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + @requirements.dbapi_lastrowid + def test_native_lastrowid_autoinc(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + lastrowid = r.lastrowid + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(lastrowid, pk) + + +class InsertBehaviorTest(fixtures.TablesTest): + run_deletes = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "no_implicit_returning", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + implicit_returning=False, + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) + + @testing.variation("style", ["plain", "return_defaults"]) + @testing.variation("executemany", [True, False]) + def test_no_results_for_non_returning_insert( + self, connection, style, executemany + ): + """test another INSERT issue found during #10453""" + + table = self.tables.no_implicit_returning + + stmt = table.insert() + if style.return_defaults: + stmt = stmt.return_defaults() + + if executemany: + data = [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + ] + else: + data = {"data": "d1"} + + r = connection.execute(stmt, data) + assert not r.returns_rows + + @requirements.autoincrement_insert + def test_autoclose_on_insert(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment + # an insert where the PK was taken from a row that the dialect + # selected, as is the case for mssql/pyodbc, will still report + # returns_rows as true because there's a cursor description. in that + # case, the row had to have been consumed at least. + assert not r.returns_rows or r.fetchone() is None + + @requirements.insert_returning + def test_autoclose_on_insert_implicit_returning(self, connection): + r = connection.execute( + # return_defaults() ensures RETURNING will be used, + # new in 2.0 as sqlite/mariadb offer both RETURNING and + # cursor.lastrowid + self.tables.autoinc_pk.insert().return_defaults(), + dict(data="some data"), + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # note we are experimenting with having this be True + # as of I8091919d45421e3f53029b8660427f844fee0228 . + # implicit returning has fetched the row, but it still is a + # "returns rows" + assert r.returns_rows + + # and we should be able to fetchone() on it, we just get no row + eq_(r.fetchone(), None) + + # and the keys, etc. + eq_(r.keys(), ["id"]) + + # but the dialect took in the row already. not really sure + # what the best behavior is. + + @requirements.empty_inserts + def test_empty_insert(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert()) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + eq_(len(r.all()), 1) + + @requirements.empty_inserts_executemany + def test_empty_insert_multiple(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}]) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + + eq_(len(r.all()), 3) + + @requirements.insert_from_select + def test_insert_from_select_autoinc(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + connection.execute( + src_table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + eq_(result.fetchall(), [("data2",), ("data3",)]) + + @requirements.insert_from_select + def test_insert_from_select_autoinc_no_rows(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + + eq_(result.fetchall(), []) + + @requirements.insert_from_select + def test_insert_from_select(self, connection): + table = self.tables.manual_pk + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table.c.data).order_by(table.c.data) + ).fetchall(), + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], + ) + + @requirements.insert_from_select + def test_insert_from_select_with_defaults(self, connection): + table = self.tables.includes_defaults + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table).order_by(table.c.data, table.c.id) + ).fetchall(), + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], + ) + + +class ReturningTest(fixtures.TablesTest): + run_create_tables = "each" + __requires__ = "insert_returning", "autoincrement_insert" + __backend__ = True + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + @requirements.fetch_rows_post_commit + def test_explicit_returning_pk_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_explicit_returning_pk_no_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_autoincrement_on_insert_implicit_returning(self, connection): + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id_implicit_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + @requirements.insert_executemany_returning + def test_insertmanyvalues_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert().returning( + self.tables.autoinc_pk.c.id + ), + [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + ], + ) + rall = r.all() + + pks = connection.execute(select(self.tables.autoinc_pk.c.id)) + + eq_(rall, pks.all()) + + @testing.combinations( + (Double(), 8.5514716, True), + ( + Double(53), + 8.5514716, + True, + testing.requires.float_or_double_precision_behaves_generically, + ), + (Float(), 8.5514, True), + ( + Float(8), + 8.5514, + True, + testing.requires.float_or_double_precision_behaves_generically, + ), + ( + Numeric(precision=15, scale=12, asdecimal=False), + 8.5514716, + True, + testing.requires.literal_float_coercion, + ), + ( + Numeric(precision=15, scale=12, asdecimal=True), + Decimal("8.5514716"), + False, + ), + argnames="type_,value,do_rounding", + ) + @testing.variation("sort_by_parameter_order", [True, False]) + @testing.variation("multiple_rows", [True, False]) + def test_insert_w_floats( + self, + connection, + metadata, + sort_by_parameter_order, + type_, + value, + do_rounding, + multiple_rows, + ): + """test #9701. + + this tests insertmanyvalues as well as decimal / floating point + RETURNING types + + """ + + t = Table( + # Oracle backends seems to be getting confused if + # this table is named the same as the one + # in test_imv_returning_datatypes. use a different name + "f_t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("value", type_), + ) + + t.create(connection) + + result = connection.execute( + t.insert().returning( + t.c.id, + t.c.value, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), + ) + + if multiple_rows: + i_range = range(1, 11) + else: + i_range = range(1, 2) + + # we want to test only that we are getting floating points back + # with some degree of the original value maintained, that it is not + # being truncated to an integer. there's too much variation in how + # drivers return floats, which should not be relied upon to be + # exact, for us to just compare as is (works for PG drivers but not + # others) so we use rounding here. There's precedent for this + # in suite/test_types.py::NumericTest as well + + if do_rounding: + eq_( + {(id_, round(val_, 5)) for id_, val_ in result}, + {(id_, round(value, 5)) for id_ in i_range}, + ) + + eq_( + { + round(val_, 5) + for val_ in connection.scalars(select(t.c.value)) + }, + {round(value, 5)}, + ) + else: + eq_( + set(result), + {(id_, value) for id_ in i_range}, + ) + + eq_( + set(connection.scalars(select(t.c.value))), + {value}, + ) + + @testing.combinations( + ( + "non_native_uuid", + Uuid(native_uuid=False), + uuid.uuid4(), + ), + ( + "non_native_uuid_str", + Uuid(as_uuid=False, native_uuid=False), + str(uuid.uuid4()), + ), + ( + "generic_native_uuid", + Uuid(native_uuid=True), + uuid.uuid4(), + testing.requires.uuid_data_type, + ), + ( + "generic_native_uuid_str", + Uuid(as_uuid=False, native_uuid=True), + str(uuid.uuid4()), + testing.requires.uuid_data_type, + ), + ("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type), + ( + "LargeBinary1", + LargeBinary(), + b"this is binary", + ), + ("LargeBinary2", LargeBinary(), b"7\xe7\x9f"), + argnames="type_,value", + id_="iaa", + ) + @testing.variation("sort_by_parameter_order", [True, False]) + @testing.variation("multiple_rows", [True, False]) + @testing.requires.insert_returning + def test_imv_returning_datatypes( + self, + connection, + metadata, + sort_by_parameter_order, + type_, + value, + multiple_rows, + ): + """test #9739, #9808 (similar to #9701). + + this tests insertmanyvalues in conjunction with various datatypes. + + These tests are particularly for the asyncpg driver which needs + most types to be explicitly cast for the new IMV format + + """ + t = Table( + "d_t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("value", type_), + ) + + t.create(connection) + + result = connection.execute( + t.insert().returning( + t.c.id, + t.c.value, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), + ) + + if multiple_rows: + i_range = range(1, 11) + else: + i_range = range(1, 2) + + eq_( + set(result), + {(id_, value) for id_ in i_range}, + ) + + eq_( + set(connection.scalars(select(t.c.value))), + {value}, + ) + + +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 0000000..f257d2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py @@ -0,0 +1,3128 @@ +# testing/suite/test_reflection.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import operator +import re + +import sqlalchemy as sa +from .. import config +from .. import engines +from .. import eq_ +from .. import expect_raises +from .. import expect_raises_message +from .. import expect_warnings +from .. import fixtures +from .. import is_ +from ..provision import get_temp_table_name +from ..provision import temp_table_keyword_args +from ..schema import Column +from ..schema import Table +from ... import event +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import String +from ... import testing +from ... import types as sql_types +from ...engine import Inspector +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...exc import NoSuchTableError +from ...exc import UnreflectableTableError +from ...schema import DDL +from ...schema import Index +from ...sql.elements import quoted_name +from ...sql.schema import BLANK_SCHEMA +from ...testing import ComparesIndexes +from ...testing import ComparesTables +from ...testing import is_false +from ...testing import is_true +from ...testing import mock + + +metadata, users = None, None + + +class OneConnectionTablesTest(fixtures.TablesTest): + @classmethod + def setup_bind(cls): + # TODO: when temp tables are subject to server reset, + # this will also have to disable that server reset from + # happening + if config.requirements.independent_connections.enabled: + from sqlalchemy import pool + + return engines.testing_engine( + options=dict(poolclass=pool.StaticPool, scope="class"), + ) + else: + return config.db + + +class HasTableTest(OneConnectionTablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + if testing.requires.schemas.enabled: + Table( + "test_table_s", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + + if testing.requires.view_reflection: + cls.define_views(metadata) + if testing.requires.has_temp_table.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def define_views(cls, metadata): + query = "CREATE VIEW vv AS SELECT id, data FROM test_table" + + event.listen(metadata, "after_create", DDL(query)) + event.listen(metadata, "before_drop", DDL("DROP VIEW vv")) + + if testing.requires.schemas.enabled: + query = ( + "CREATE VIEW %s.vv AS SELECT id, data FROM %s.test_table_s" + % ( + config.test_schema, + config.test_schema, + ) + ) + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, + "before_drop", + DDL("DROP VIEW %s.vv" % (config.test_schema)), + ) + + @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + + @classmethod + def define_temp_tables(cls, metadata): + kw = temp_table_keyword_args(config, config.db) + table_name = cls.temp_table_name() + user_tmp = Table( + table_name, + metadata, + Column("id", sa.INT, primary_key=True), + Column("name", sa.VARCHAR(50)), + **kw, + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp_%s" % config.ident + ), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + def test_has_table(self): + with config.db.begin() as conn: + is_true(config.db.dialect.has_table(conn, "test_table")) + is_false(config.db.dialect.has_table(conn, "test_table_s")) + is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + + def test_has_table_cache(self, metadata): + insp = inspect(config.db) + is_true(insp.has_table("test_table")) + nt = Table("new_table", metadata, Column("col", Integer)) + is_false(insp.has_table("new_table")) + nt.create(config.db) + try: + is_false(insp.has_table("new_table")) + insp.clear_cache() + is_true(insp.has_table("new_table")) + finally: + nt.drop(config.db) + + @testing.requires.schemas + def test_has_table_schema(self): + with config.db.begin() as conn: + is_false( + config.db.dialect.has_table( + conn, "test_table", schema=config.test_schema + ) + ) + is_true( + config.db.dialect.has_table( + conn, "test_table_s", schema=config.test_schema + ) + ) + is_false( + config.db.dialect.has_table( + conn, "nonexistent_table", schema=config.test_schema + ) + ) + + @testing.requires.schemas + def test_has_table_nonexistent_schema(self): + with config.db.begin() as conn: + is_false( + config.db.dialect.has_table( + conn, "test_table", schema="nonexistent_schema" + ) + ) + + @testing.requires.views + def test_has_table_view(self, connection): + insp = inspect(connection) + is_true(insp.has_table("vv")) + + @testing.requires.has_temp_table + def test_has_table_temp_table(self, connection): + insp = inspect(connection) + temp_table_name = self.temp_table_name() + is_true(insp.has_table(temp_table_name)) + + @testing.requires.has_temp_table + @testing.requires.view_reflection + @testing.requires.temporary_views + def test_has_table_temp_view(self, connection): + insp = inspect(connection) + is_true(insp.has_table("user_tmp_v")) + + @testing.requires.views + @testing.requires.schemas + def test_has_table_view_schema(self, connection): + insp = inspect(connection) + is_true(insp.has_table("vv", config.test_schema)) + + +class HasIndexTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("data2", String(50)), + ) + Index("my_idx", tt.c.data) + + if testing.requires.schemas.enabled: + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + Index("my_idx_s", tt.c.data) + + kind = testing.combinations("dialect", "inspector", argnames="kind") + + def _has_index(self, kind, conn): + if kind == "dialect": + return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k) + else: + return inspect(conn).has_index + + @kind + def test_has_index(self, kind, connection, metadata): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx") + assert not meth("test_table", "my_idx_s") + assert not meth("nonexistent_table", "my_idx") + assert not meth("test_table", "nonexistent_idx") + + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + idx = Index("my_idx_2", self.tables.test_table.c.data2) + tbl = Table( + "test_table_2", + metadata, + Column("foo", Integer), + Index("my_idx_3", "foo"), + ) + idx.create(connection) + tbl.create(connection) + try: + if kind == "inspector": + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + meth.__self__.clear_cache() + assert meth("test_table", "my_idx_2") is True + assert meth("test_table_2", "my_idx_3") is True + finally: + tbl.drop(connection) + idx.drop(connection) + + @testing.requires.schemas + @kind + def test_has_index_schema(self, kind, connection): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx_s", schema=config.test_schema) + assert not meth("test_table", "my_idx", schema=config.test_schema) + assert not meth( + "nonexistent_table", "my_idx_s", schema=config.test_schema + ) + assert not meth( + "test_table", "nonexistent_idx_s", schema=config.test_schema + ) + + +class BizarroCharacterFKResolutionTest(fixtures.TestBase): + """tests for #10275""" + + __backend__ = True + + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("per % cent",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref( + self, connection, metadata, use_composite, tablename, columnname + ): + tt = Table( + tablename, + metadata, + Column(columnname, Integer, key="id", primary_key=True), + test_needs_fk=True, + ) + if use_composite: + tt.append_column(Column("id2", Integer, primary_key=True)) + + if use_composite: + Table( + "other", + metadata, + Column("id", Integer, primary_key=True), + Column("ref", Integer), + Column("ref2", Integer), + sa.ForeignKeyConstraint(["ref", "ref2"], [tt.c.id, tt.c.id2]), + test_needs_fk=True, + ) + else: + Table( + "other", + metadata, + Column("id", Integer, primary_key=True), + Column("ref", ForeignKey(tt.c.id)), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + m2 = MetaData() + + o2 = Table("other", m2, autoload_with=connection) + t1 = m2.tables[tablename] + + assert o2.c.ref.references(t1.c[0]) + if use_composite: + assert o2.c.ref2.references(t1.c[1]) + + +class QuotedNameArgumentTest(fixtures.TablesTest): + run_create_tables = "once" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "quote ' one", + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name="pk quote ' one"), + sa.Index("ix quote ' one", "name"), + sa.UniqueConstraint( + "data", + name="uq quote' one", + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name="fk quote ' one" + ), + sa.CheckConstraint("name != 'foo'", name="ck quote ' one"), + comment=r"""quote ' one comment""", + test_needs_fk=True, + ) + + if testing.requires.symbol_names_w_double_quote.enabled: + Table( + 'quote " two', + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name='pk quote " two'), + sa.Index('ix quote " two', "name"), + sa.UniqueConstraint( + "data", + name='uq quote" two', + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name='fk quote " two' + ), + sa.CheckConstraint("name != 'foo'", name='ck quote " two '), + comment=r"""quote " two comment""", + test_needs_fk=True, + ) + + Table( + "related", + metadata, + Column("id", Integer, primary_key=True), + Column("related", Integer), + test_needs_fk=True, + ) + + if testing.requires.view_column_reflection.enabled: + if testing.requires.symbol_names_w_double_quote.enabled: + names = [ + "quote ' one", + 'quote " two', + ] + else: + names = [ + "quote ' one", + ] + for name in names: + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + config.db.dialect.identifier_preparer.quote( + "view %s" % name + ), + config.db.dialect.identifier_preparer.quote(name), + ) + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, + "before_drop", + DDL( + "DROP VIEW %s" + % config.db.dialect.identifier_preparer.quote( + "view %s" % name + ) + ), + ) + + def quote_fixtures(fn): + return testing.combinations( + ("quote ' one",), + ('quote " two', testing.requires.symbol_names_w_double_quote), + )(fn) + + @quote_fixtures + def test_get_table_options(self, name): + insp = inspect(config.db) + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options(name) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options(name) + + @quote_fixtures + @testing.requires.view_column_reflection + def test_get_view_definition(self, name): + insp = inspect(config.db) + assert insp.get_view_definition("view %s" % name) + + @quote_fixtures + def test_get_columns(self, name): + insp = inspect(config.db) + assert insp.get_columns(name) + + @quote_fixtures + def test_get_pk_constraint(self, name): + insp = inspect(config.db) + assert insp.get_pk_constraint(name) + + @quote_fixtures + def test_get_foreign_keys(self, name): + insp = inspect(config.db) + assert insp.get_foreign_keys(name) + + @quote_fixtures + def test_get_indexes(self, name): + insp = inspect(config.db) + assert insp.get_indexes(name) + + @quote_fixtures + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, name): + insp = inspect(config.db) + assert insp.get_unique_constraints(name) + + @quote_fixtures + @testing.requires.comment_reflection + def test_get_table_comment(self, name): + insp = inspect(config.db) + assert insp.get_table_comment(name) + + @quote_fixtures + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, name): + insp = inspect(config.db) + assert insp.get_check_constraints(name) + + +def _multi_combination(fn): + schema = testing.combinations( + None, + ( + lambda: config.test_schema, + testing.requires.schemas, + ), + argnames="schema", + ) + scope = testing.combinations( + ObjectScope.DEFAULT, + ObjectScope.TEMPORARY, + ObjectScope.ANY, + argnames="scope", + ) + kind = testing.combinations( + ObjectKind.TABLE, + ObjectKind.VIEW, + ObjectKind.MATERIALIZED_VIEW, + ObjectKind.ANY, + ObjectKind.ANY_VIEW, + ObjectKind.TABLE | ObjectKind.VIEW, + ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW, + argnames="kind", + ) + filter_names = testing.combinations(True, False, argnames="use_filter") + + return schema(scope(kind(filter_names(fn)))) + + +class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest): + run_inserts = run_deletes = None + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.define_reflected_tables(metadata, None) + if testing.requires.schemas.enabled: + cls.define_reflected_tables(metadata, testing.config.test_schema) + + @classmethod + def define_reflected_tables(cls, metadata, schema): + if schema: + schema_prefix = schema + "." + else: + schema_prefix = "" + + if testing.requires.self_referential_foreign_keys.enabled: + parent_id_args = ( + ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ) + else: + parent_id_args = () + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(), nullable=False), + Column("parent_user_id", sa.Integer, *parent_id_args), + sa.CheckConstraint( + "test2 > 0", + name="zz_test2_gt_zero", + comment="users check constraint", + ), + sa.CheckConstraint("test2 <= 1000"), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + ForeignKey( + "%semail_addresses.address_id" % schema_prefix, + name="zz_email_add_id_fg", + comment="di fk comment", + ), + ), + Column( + "id_user", + sa.Integer, + ForeignKey("%susers.user_id" % schema_prefix), + ), + Column("data", sa.String(30), unique=True), + sa.CheckConstraint( + "address_id > 0 AND address_id < 1000", + name="address_id_gt_zero", + ), + sa.UniqueConstraint( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + comment="di unique comment", + ), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column("remote_user_id", sa.Integer, ForeignKey(users.c.user_id)), + Column("email_address", sa.String(20), index=True), + sa.PrimaryKeyConstraint( + "address_id", name="email_ad_pk", comment="ea pk comment" + ), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + Column("d3", sa.String(42), comment="Comment\nwith\rescapes"), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + comment="no\nconstraints\rhas\fescaped\vcomment", + ) + + if testing.requires.cross_schema_fk_reflection.enabled: + if schema is None: + Table( + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + Column( + "remote_id", + ForeignKey( + "%s.remote_table_2.id" % testing.config.test_schema + ), + ), + test_needs_fk=True, + schema=config.db.dialect.default_schema_name, + ) + else: + Table( + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column( + "local_id", + ForeignKey( + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), + ), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + Table( + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + + if testing.requires.index_reflection.enabled: + Index("users_t_idx", users.c.test1, users.c.test2, unique=True) + Index( + "users_all_idx", users.c.user_id, users.c.test2, users.c.test1 + ) + + if not schema: + # test_needs_fk is at the moment to force MySQL InnoDB + noncol_idx_test_nopk = Table( + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + noncol_idx_test_pk = Table( + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + if ( + testing.requires.indexes_with_ascdesc.enabled + and testing.requires.reflect_indexes_with_ascdesc.enabled + ): + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) + + if testing.requires.view_column_reflection.enabled: + cls.define_views(metadata, schema) + if not schema and testing.requires.temp_table_reflection.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + + @classmethod + def define_temp_tables(cls, metadata): + kw = temp_table_keyword_args(config, config.db) + table_name = cls.temp_table_name() + user_tmp = Table( + table_name, + metadata, + Column("id", sa.INT, primary_key=True), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + # disambiguate temp table unique constraint names. this is + # pretty arbitrary for a generic dialect however we are doing + # it to suit SQL Server which will produce name conflicts for + # unique constraints created against temp tables in different + # databases. + # https://www.arbinada.com/en/node/1645 + sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"), + sa.Index("user_tmp_ix", "foo"), + **kw, + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp_%s" % config.ident + ), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + @classmethod + def define_views(cls, metadata, schema): + if testing.requires.materialized_views.enabled: + materialized = {"dingalings"} + else: + materialized = set() + for table_name in ("users", "email_addresses", "dingalings"): + fullname = table_name + if schema: + fullname = f"{schema}.{table_name}" + view_name = fullname + "_v" + prefix = "MATERIALIZED " if table_name in materialized else "" + query = ( + f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}" + ) + + event.listen(metadata, "after_create", DDL(query)) + if table_name in materialized: + index_name = "mat_index" + if schema and testing.against("oracle"): + index_name = f"{schema}.{index_name}" + idx = f"CREATE INDEX {index_name} ON {view_name}(data)" + event.listen(metadata, "after_create", DDL(idx)) + event.listen( + metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}") + ) + + def _resolve_kind(self, kind, tables, views, materialized): + res = {} + if ObjectKind.TABLE in kind: + res.update(tables) + if ObjectKind.VIEW in kind: + res.update(views) + if ObjectKind.MATERIALIZED_VIEW in kind: + res.update(materialized) + return res + + def _resolve_views(self, views, materialized): + if not testing.requires.view_column_reflection.enabled: + materialized.clear() + views.clear() + elif not testing.requires.materialized_views.enabled: + views.update(materialized) + materialized.clear() + + def _resolve_names(self, schema, scope, filter_names, values): + scope_filter = lambda _: True # noqa: E731 + if scope is ObjectScope.DEFAULT: + scope_filter = lambda k: "tmp" not in k[1] # noqa: E731 + if scope is ObjectScope.TEMPORARY: + scope_filter = lambda k: "tmp" in k[1] # noqa: E731 + + removed = { + None: {"remote_table", "remote_table_2"}, + testing.config.test_schema: { + "local_table", + "noncol_idx_test_nopk", + "noncol_idx_test_pk", + "user_tmp_v", + self.temp_table_name(), + }, + } + if not testing.requires.cross_schema_fk_reflection.enabled: + removed[None].add("local_table") + removed[testing.config.test_schema].update( + ["remote_table", "remote_table_2"] + ) + if not testing.requires.index_reflection.enabled: + removed[None].update( + ["noncol_idx_test_nopk", "noncol_idx_test_pk"] + ) + if ( + not testing.requires.temp_table_reflection.enabled + or not testing.requires.temp_table_names.enabled + ): + removed[None].update(["user_tmp_v", self.temp_table_name()]) + if not testing.requires.temporary_views.enabled: + removed[None].update(["user_tmp_v"]) + + res = { + k: v + for k, v in values.items() + if scope_filter(k) + and k[1] not in removed[schema] + and (not filter_names or k[1] in filter_names) + } + return res + + def exp_options( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + materialized = {(schema, "dingalings_v"): mock.ANY} + views = { + (schema, "email_addresses_v"): mock.ANY, + (schema, "users_v"): mock.ANY, + (schema, "user_tmp_v"): mock.ANY, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): mock.ANY, + (schema, "dingalings"): mock.ANY, + (schema, "email_addresses"): mock.ANY, + (schema, "comment_test"): mock.ANY, + (schema, "no_constraints"): mock.ANY, + (schema, "local_table"): mock.ANY, + (schema, "remote_table"): mock.ANY, + (schema, "remote_table_2"): mock.ANY, + (schema, "noncol_idx_test_nopk"): mock.ANY, + (schema, "noncol_idx_test_pk"): mock.ANY, + (schema, self.temp_table_name()): mock.ANY, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_comments( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + empty = {"text": None} + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): empty, + (schema, "dingalings"): empty, + (schema, "email_addresses"): empty, + (schema, "comment_test"): { + "text": r"""the test % ' " \ table comment""" + }, + (schema, "no_constraints"): { + "text": "no\nconstraints\rhas\fescaped\vcomment" + }, + (schema, "local_table"): empty, + (schema, "remote_table"): empty, + (schema, "remote_table_2"): empty, + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): empty, + (schema, self.temp_table_name()): empty, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_columns( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def col( + name, auto=False, default=mock.ANY, comment=None, nullable=True + ): + res = { + "name": name, + "autoincrement": auto, + "type": mock.ANY, + "default": default, + "comment": comment, + "nullable": nullable, + } + if auto == "omit": + res.pop("autoincrement") + return res + + def pk(name, **kw): + kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw} + return col(name, **kw) + + materialized = { + (schema, "dingalings_v"): [ + col("dingaling_id", auto="omit", nullable=mock.ANY), + col("address_id"), + col("id_user"), + col("data"), + ] + } + views = { + (schema, "email_addresses_v"): [ + col("address_id", auto="omit", nullable=mock.ANY), + col("remote_user_id"), + col("email_address"), + ], + (schema, "users_v"): [ + col("user_id", auto="omit", nullable=mock.ANY), + col("test1", nullable=mock.ANY), + col("test2", nullable=mock.ANY), + col("parent_user_id"), + ], + (schema, "user_tmp_v"): [ + col("id", auto="omit", nullable=mock.ANY), + col("name"), + col("foo"), + ], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + pk("user_id"), + col("test1", nullable=False), + col("test2", nullable=False), + col("parent_user_id"), + ], + (schema, "dingalings"): [ + pk("dingaling_id"), + col("address_id"), + col("id_user"), + col("data"), + ], + (schema, "email_addresses"): [ + pk("address_id"), + col("remote_user_id"), + col("email_address"), + ], + (schema, "comment_test"): [ + pk("id", comment="id comment"), + col("data", comment="data % comment"), + col( + "d2", + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + col("d3", comment="Comment\nwith\rescapes"), + ], + (schema, "no_constraints"): [col("data")], + (schema, "local_table"): [pk("id"), col("data"), col("remote_id")], + (schema, "remote_table"): [pk("id"), col("local_id"), col("data")], + (schema, "remote_table_2"): [pk("id"), col("data")], + (schema, "noncol_idx_test_nopk"): [col("q")], + (schema, "noncol_idx_test_pk"): [pk("id"), col("q")], + (schema, self.temp_table_name()): [ + pk("id"), + col("name"), + col("foo"), + ], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_column_keys(self): + return {"name", "type", "nullable", "default"} + + def exp_pks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def pk(*cols, name=mock.ANY, comment=None): + return { + "constrained_columns": list(cols), + "name": name, + "comment": comment, + } + + empty = pk(name=None) + if testing.requires.materialized_views_reflect_pk.enabled: + materialized = {(schema, "dingalings_v"): pk("dingaling_id")} + else: + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): pk("user_id"), + (schema, "dingalings"): pk("dingaling_id"), + (schema, "email_addresses"): pk( + "address_id", name="email_ad_pk", comment="ea pk comment" + ), + (schema, "comment_test"): pk("id"), + (schema, "no_constraints"): empty, + (schema, "local_table"): pk("id"), + (schema, "remote_table"): pk("id"), + (schema, "remote_table_2"): pk("id"), + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): pk("id"), + (schema, self.temp_table_name()): pk("id"), + } + if not testing.requires.reflects_pk_names.enabled: + for val in tables.values(): + if val["name"] is not None: + val["name"] = mock.ANY + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_pk_keys(self): + return {"name", "constrained_columns"} + + def exp_fks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt: + def __eq__(self, other): + return ( + other is None + or config.db.dialect.default_schema_name == other + ) + + def fk( + cols, + ref_col, + ref_table, + ref_schema=schema, + name=mock.ANY, + comment=None, + ): + return { + "constrained_columns": cols, + "referred_columns": ref_col, + "name": name, + "options": mock.ANY, + "referred_schema": ( + ref_schema if ref_schema is not None else tt() + ), + "referred_table": ref_table, + "comment": comment, + } + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk") + ], + (schema, "dingalings"): [ + fk(["id_user"], ["user_id"], "users"), + fk( + ["address_id"], + ["address_id"], + "email_addresses", + name="zz_email_add_id_fg", + comment="di fk comment", + ), + ], + (schema, "email_addresses"): [ + fk(["remote_user_id"], ["user_id"], "users") + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + fk( + ["remote_id"], + ["id"], + "remote_table_2", + ref_schema=config.test_schema, + ) + ], + (schema, "remote_table"): [ + fk(["local_id"], ["id"], "local_table", ref_schema=None) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + if not testing.requires.self_referential_foreign_keys.enabled: + tables[(schema, "users")].clear() + if not testing.requires.named_constraints.enabled: + for vals in tables.values(): + for val in vals: + if val["name"] is not mock.ANY: + val["name"] = mock.ANY + + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_fk_keys(self): + return { + "name", + "constrained_columns", + "referred_schema", + "referred_table", + "referred_columns", + } + + def exp_indexes( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def idx( + *cols, + name, + unique=False, + column_sorting=None, + duplicates=False, + fk=False, + ): + fk_req = testing.requires.foreign_keys_reflect_as_index + dup_req = testing.requires.unique_constraints_reflect_as_index + sorting_expression = ( + testing.requires.reflect_indexes_with_ascdesc_as_expression + ) + + if (fk and not fk_req.enabled) or ( + duplicates and not dup_req.enabled + ): + return () + res = { + "unique": unique, + "column_names": list(cols), + "name": name, + "dialect_options": mock.ANY, + "include_columns": [], + } + if column_sorting: + res["column_sorting"] = column_sorting + if sorting_expression.enabled: + res["expressions"] = orig = res["column_names"] + res["column_names"] = [ + None if c in column_sorting else c for c in orig + ] + + if duplicates: + res["duplicates_constraint"] = name + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + if materialized: + materialized[(schema, "dingalings_v")].extend( + idx("data", name="mat_index") + ) + tables = { + (schema, "users"): [ + *idx("parent_user_id", name="user_id_fk", fk=True), + *idx("user_id", "test2", "test1", name="users_all_idx"), + *idx("test1", "test2", name="users_t_idx", unique=True), + ], + (schema, "dingalings"): [ + *idx("data", name=mock.ANY, unique=True, duplicates=True), + *idx("id_user", name=mock.ANY, fk=True), + *idx( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + unique=True, + duplicates=True, + ), + ], + (schema, "email_addresses"): [ + *idx("email_address", name=mock.ANY), + *idx("remote_user_id", name=mock.ANY, fk=True), + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + *idx("remote_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table"): [ + *idx("local_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [ + *idx( + "q", + name="noncol_idx_nopk", + column_sorting={"q": ("desc",)}, + ) + ], + (schema, "noncol_idx_test_pk"): [ + *idx( + "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)} + ) + ], + (schema, self.temp_table_name()): [ + *idx("foo", name="user_tmp_ix"), + *idx( + "name", + name=f"user_tmp_uq_{config.ident}", + duplicates=True, + unique=True, + ), + ], + } + if ( + not testing.requires.indexes_with_ascdesc.enabled + or not testing.requires.reflect_indexes_with_ascdesc.enabled + ): + tables[(schema, "noncol_idx_test_nopk")].clear() + tables[(schema, "noncol_idx_test_pk")].clear() + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_index_keys(self): + return {"name", "column_names", "unique"} + + def exp_ucs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + all_=False, + ): + def uc( + *cols, name, duplicates_index=None, is_index=False, comment=None + ): + req = testing.requires.unique_index_reflect_as_unique_constraints + if is_index and not req.enabled: + return () + res = { + "column_names": list(cols), + "name": name, + "comment": comment, + } + if duplicates_index: + res["duplicates_index"] = duplicates_index + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + *uc( + "test1", + "test2", + name="users_t_idx", + duplicates_index="users_t_idx", + is_index=True, + ) + ], + (schema, "dingalings"): [ + *uc("data", name=mock.ANY, duplicates_index=mock.ANY), + *uc( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + duplicates_index="zz_dingalings_multiple", + comment="di unique comment", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [ + *uc("name", name=f"user_tmp_uq_{config.ident}") + ], + } + if all_: + return {**materialized, **views, **tables} + else: + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_unique_cst_keys(self): + return {"name", "column_names"} + + def exp_ccs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt(str): + def __eq__(self, other): + res = ( + other.lower() + .replace("(", "") + .replace(")", "") + .replace("`", "") + ) + return self in res + + def cc(text, name, comment=None): + return {"sqltext": tt(text), "name": name, "comment": comment} + + # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")}) + # assert 0 + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + cc("test2 <= 1000", mock.ANY), + cc( + "test2 > 0", + "zz_test2_gt_zero", + comment="users check constraint", + ), + ], + (schema, "dingalings"): [ + cc( + "address_id > 0 and address_id < 1000", + name="address_id_gt_zero", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_cc_keys(self): + return {"name", "sqltext"} + + @testing.requires.schema_reflection + def test_get_schema_names(self, connection): + insp = inspect(connection) + + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema(self, connection): + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + def test_get_schema_names_w_translate_map(self, connection): + """test #7300""" + + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema_w_translate_map(self, connection): + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + @testing.requires.schema_create_delete + def test_schema_cache(self, connection): + insp = inspect(connection) + + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + connection.execute(DDL("CREATE SCHEMA foo_bar")) + try: + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + insp.clear_cache() + is_true("foo_bar" in insp.get_schema_names()) + is_true(insp.has_schema("foo_bar")) + finally: + connection.execute(DDL("DROP SCHEMA foo_bar")) + + @testing.requires.schema_reflection + def test_dialect_initialize(self): + engine = engines.testing_engine() + inspect(engine) + assert hasattr(engine.dialect, "default_schema_name") + + @testing.requires.schema_reflection + def test_get_default_schema_name(self, connection): + insp = inspect(connection) + eq_(insp.default_schema_name, connection.dialect.default_schema_name) + + @testing.combinations( + None, + ("foreign_key", testing.requires.foreign_key_constraint_reflection), + argnames="order_by", + ) + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_table_names(self, connection, order_by, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + _ignore_tables = { + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", + "no_constraints", + } + + insp = inspect(connection) + + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] + + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) + + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_view_names(self, connection, use_schema): + insp = inspect(connection) + if use_schema: + schema = config.test_schema + else: + schema = None + table_names = insp.get_view_names(schema) + if testing.requires.materialized_views.enabled: + eq_(sorted(table_names), ["email_addresses_v", "users_v"]) + eq_(insp.get_materialized_view_names(schema), ["dingalings_v"]) + else: + answer = ["dingalings_v", "email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) + + @testing.requires.temp_table_names + def test_get_temp_table_names(self, connection): + insp = inspect(connection) + temp_table_names = insp.get_temp_table_names() + eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"]) + + @testing.requires.view_reflection + @testing.requires.temporary_views + def test_get_temp_view_names(self, connection): + insp = inspect(connection) + temp_table_names = insp.get_temp_view_names() + eq_(sorted(temp_table_names), ["user_tmp_v"]) + + @testing.requires.comment_reflection + def test_get_comments(self, connection): + self._test_get_comments(connection) + + @testing.requires.comment_reflection + @testing.requires.schemas + def test_get_comments_with_schema(self, connection): + self._test_get_comments(connection, testing.config.test_schema) + + def _test_get_comments(self, connection, schema=None): + insp = inspect(connection) + exp = self.exp_comments(schema=schema) + eq_( + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + eq_( + insp.get_table_comment("users", schema=schema), + exp[(schema, "users")], + ) + + eq_( + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + no_cst = self.tables.no_constraints.name + eq_( + insp.get_table_comment(no_cst, schema=schema), + exp[(schema, no_cst)], + ) + + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + if use_views: + table_names = ["users_v", "email_addresses_v", "dingalings_v"] + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) + for table_name, table in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + is_true(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): + ctype_def = sql_types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + is_true( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) + + if not col.primary_key: + assert cols[i]["default"] is None + + # The case of a table with no column + # is tested below in TableNoColumnsTest + + @testing.requires.temp_table_reflection + def test_reflect_table_temp_table(self, connection): + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + + reflected_user_tmp = Table( + table_name, MetaData(), autoload_with=connection + ) + self.assert_tables_equal( + user_tmp, reflected_user_tmp, strict_constraints=False + ) + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self, connection): + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + insp = inspect(connection) + cols = insp.get_columns(table_name) + is_true(len(cols) > 0, len(cols)) + + for i, col in enumerate(user_tmp.columns): + eq_(col.name, cols[i]["name"]) + + @testing.requires.temp_table_reflection + @testing.requires.view_column_reflection + @testing.requires.temporary_views + def test_get_temp_view_columns(self, connection): + insp = inspect(connection) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None + + users, addresses = self.tables.users, self.tables.email_addresses + insp = inspect(connection) + exp = self.exp_pks(schema=schema) + + users_cons = insp.get_pk_constraint(users.name, schema=schema) + self._check_list( + [users_cons], [exp[(schema, users.name)]], self._required_pk_keys + ) + + addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) + exp_cols = exp[(schema, addresses.name)]["constrained_columns"] + eq_(addr_cons["constrained_columns"], exp_cols) + + with testing.requires.reflects_pk_names.fail_if(): + eq_(addr_cons["name"], "email_ad_pk") + + no_cst = self.tables.no_constraints.name + self._check_list( + [insp.get_pk_constraint(no_cst, schema=schema)], + [exp[(schema, no_cst)]], + self._required_pk_keys, + ) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + insp = inspect(connection) + expected_schema = schema + # users + + if testing.requires.self_referential_foreign_keys.enabled: + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) + fkey1 = users_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + eq_(fkey1["name"], "user_id_fk") + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) + + # addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) + fkey1 = addr_fkeys[0] + + with testing.requires.implicitly_named_constraints.fail_if(): + is_true(fkey1["name"] is not None) + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) + + no_cst = self.tables.no_constraints.name + eq_(insp.get_foreign_keys(no_cst, schema=schema), []) + + @testing.requires.cross_schema_fk_reflection + @testing.requires.schemas + def test_get_inter_schema_foreign_keys(self, connection): + local_table, remote_table, remote_table_2 = self.tables( + "%s.local_table" % connection.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, + ) + + insp = inspect(connection) + + local_fkeys = insp.get_foreign_keys(local_table.name) + eq_(len(local_fkeys), 1) + + fkey1 = local_fkeys[0] + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) + + remote_fkeys = insp.get_foreign_keys( + remote_table.name, schema=testing.config.test_schema + ) + eq_(len(remote_fkeys), 1) + + fkey2 = remote_fkeys[0] + + is_true( + fkey2["referred_schema"] + in ( + None, + connection.dialect.default_schema_name, + ) + ) + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.index_reflection + def test_get_indexes(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = inspect(connection) + indexes = insp.get_indexes("users", schema=schema) + exp = self.exp_indexes(schema=schema) + self._check_list( + indexes, exp[(schema, "users")], self._required_index_keys + ) + + no_cst = self.tables.no_constraints.name + self._check_list( + insp.get_indexes(no_cst, schema=schema), + exp[(schema, no_cst)], + self._required_index_keys, + ) + + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) + @testing.requires.index_reflection + @testing.requires.indexes_with_ascdesc + @testing.requires.reflect_indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) + indexes = insp.get_indexes(tname) + # reflecting an index that has "x DESC" in it as the column. + # the DB may or may not give us "x", but make sure we get the index + # back, it has a name, it's connected to the table. + expected_indexes = self.exp_indexes()[(None, tname)] + self._check_list(indexes, expected_indexes, self._required_index_keys) + + t = Table(tname, MetaData(), autoload_with=connection) + eq_(len(t.indexes), 1) + is_(list(t.indexes)[0].table, t) + eq_(list(t.indexes)[0].name, ixname) + + @testing.requires.temp_table_reflection + @testing.requires.unique_constraint_reflection + def test_get_temp_table_unique_constraints(self, connection): + insp = inspect(connection) + name = self.temp_table_name() + reflected = insp.get_unique_constraints(name) + exp = self.exp_ucs(all_=True)[(None, name)] + self._check_list(reflected, exp, self._required_index_keys) + + @testing.requires.temp_table_reflect_indexes + def test_get_temp_table_indexes(self, connection): + insp = inspect(connection) + table_name = self.temp_table_name() + indexes = insp.get_indexes(table_name) + for ind in indexes: + ind.pop("dialect_options", None) + expected = [ + {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_( + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + expected, + ) + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, metadata, connection, use_schema): + # SQLite dialect needs to parse the names of the constraints + # separately from what it gets from PRAGMA index_list(), and + # then matches them up. so same set of column_names in two + # constraints will confuse it. Perhaps we should no longer + # bother with index_list() here since we have the whole + # CREATE TABLE? + + if use_schema: + schema = config.test_schema + else: + schema = None + uniques = sorted( + [ + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, + ], + key=operator.itemgetter("name"), + ) + table = Table( + "testtbl", + metadata, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), + # reserved identifiers + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, + ) + for uc in uniques: + table.append_constraint( + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) + ) + table.create(connection) + + insp = inspect(connection) + reflected = sorted( + insp.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), + ) + + names_that_duplicate_index = set() + + eq_(len(uniques), len(reflected)) + + for orig, refl in zip(uniques, reflected): + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + dupe = refl.pop("duplicates_index", None) + if dupe: + names_that_duplicate_index.add(dupe) + eq_(refl.pop("comment", None), None) + eq_(orig, refl) + + reflected_metadata = MetaData() + reflected = Table( + "testtbl", + reflected_metadata, + autoload_with=connection, + schema=schema, + ) + + # test "deduplicates for index" logic. MySQL and Oracle + # "unique constraints" are actually unique indexes (with possible + # exception of a unique that is a dupe of another one in the case + # of Oracle). make sure # they aren't duplicated. + idx_names = {idx.name for idx in reflected.indexes} + uq_names = { + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + }.difference(["unique_c_a_b"]) + + assert not idx_names.intersection(uq_names) + if names_that_duplicate_index: + eq_(names_that_duplicate_index, idx_names) + eq_(uq_names, set()) + + no_cst = self.tables.no_constraints.name + eq_(insp.get_unique_constraints(no_cst, schema=schema), []) + + @testing.requires.view_reflection + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + for view in ["users_v", "email_addresses_v", "dingalings_v"]: + v = insp.get_view_definition(view, schema=schema) + is_true(bool(v)) + + @testing.requires.view_reflection + def test_get_view_definition_does_not_exist(self, connection): + insp = inspect(connection) + with expect_raises(NoSuchTableError): + insp.get_view_definition("view_does_not_exist") + with expect_raises(NoSuchTableError): + insp.get_view_definition("users") # a table + + @testing.requires.table_reflection + def test_autoincrement_col(self, connection): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(connection) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + def test_get_table_options(self, use_schema): + insp = inspect(config.db) + schema = config.test_schema if use_schema else None + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options("users", schema=schema) + is_true(isinstance(res, dict)) + # NOTE: can't really create a table with no option + res = insp.get_table_options("no_constraints", schema=schema) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options("users", schema=schema) + + @testing.combinations((True, testing.requires.schemas), False) + def test_multi_get_table_options(self, use_schema): + insp = inspect(config.db) + if testing.requires.reflect_table_options.enabled: + schema = config.test_schema if use_schema else None + res = insp.get_multi_table_options(schema=schema) + + exp = { + (schema, table): insp.get_table_options(table, schema=schema) + for table in insp.get_table_names(schema=schema) + } + eq_(res, exp) + else: + with expect_raises(NotImplementedError): + res = insp.get_multi_table_options() + + @testing.fixture + def get_multi_exp(self, connection): + def provide_fixture( + schema, scope, kind, use_filter, single_reflect_fn, exp_method + ): + insp = inspect(connection) + # call the reflection function at least once to avoid + # "Unexpected success" errors if the result is actually empty + # and NotImplementedError is not raised + single_reflect_fn(insp, "email_addresses") + kw = {"scope": scope, "kind": kind} + if schema: + schema = schema() + + filter_names = [] + + if ObjectKind.TABLE in kind: + filter_names.extend( + ["comment_test", "users", "does-not-exist"] + ) + if ObjectKind.VIEW in kind: + filter_names.extend(["email_addresses_v", "does-not-exist"]) + if ObjectKind.MATERIALIZED_VIEW in kind: + filter_names.extend(["dingalings_v", "does-not-exist"]) + + if schema: + kw["schema"] = schema + if use_filter: + kw["filter_names"] = filter_names + + exp = exp_method( + schema=schema, + scope=scope, + kind=kind, + filter_names=kw.get("filter_names"), + ) + kws = [kw] + if scope == ObjectScope.DEFAULT: + nkw = kw.copy() + nkw.pop("scope") + kws.append(nkw) + if kind == ObjectKind.TABLE: + nkw = kw.copy() + nkw.pop("kind") + kws.append(nkw) + + return inspect(connection), kws, exp + + return provide_fixture + + @testing.requires.reflect_table_options + @_multi_combination + def test_multi_get_table_options_tables( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_options, + self.exp_options, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_table_options(**kw) + eq_(result, exp) + + @testing.requires.comment_reflection + @_multi_combination + def test_get_multi_table_comment( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_comment, + self.exp_comments, + ) + for kw in kws: + insp.clear_cache() + eq_(insp.get_multi_table_comment(**kw), exp) + + def _check_expressions(self, result, exp, err_msg): + def _clean(text: str): + return re.sub(r"['\" ]", "", text).lower() + + if isinstance(exp, dict): + eq_({_clean(e): v for e, v in result.items()}, exp, err_msg) + else: + eq_([_clean(e) for e in result], exp, err_msg) + + def _check_list(self, result, exp, req_keys=None, msg=None): + if req_keys is None: + eq_(result, exp, msg) + else: + eq_(len(result), len(exp), msg) + for r, e in zip(result, exp): + for k in set(r) | set(e): + if k in req_keys or (k in r and k in e): + err_msg = f"{msg} - {k} - {r}" + if k in ("expressions", "column_sorting"): + self._check_expressions(r[k], e[k], err_msg) + else: + eq_(r[k], e[k], err_msg) + + def _check_table_dict(self, result, exp, req_keys=None, make_lists=False): + eq_(set(result.keys()), set(exp.keys())) + for k in result: + r, e = result[k], exp[k] + if make_lists: + r, e = [r], [e] + self._check_list(r, e, req_keys, k) + + @_multi_combination + def test_get_multi_columns( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_columns, + self.exp_columns, + ) + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_columns(**kw) + self._check_table_dict(result, exp, self._required_column_keys) + + @testing.requires.primary_key_constraint_reflection + @_multi_combination + def test_get_multi_pk_constraint( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_pk_constraint, + self.exp_pks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_pk_constraint(**kw) + self._check_table_dict( + result, exp, self._required_pk_keys, make_lists=True + ) + + def _adjust_sort(self, result, expected, key): + if not testing.requires.implicitly_named_constraints.enabled: + for obj in [result, expected]: + for val in obj.values(): + if len(val) > 1 and any( + v.get("name") in (None, mock.ANY) for v in val + ): + val.sort(key=key) + + @testing.requires.foreign_key_constraint_reflection + @_multi_combination + def test_get_multi_foreign_keys( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_foreign_keys, + self.exp_fks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_foreign_keys(**kw) + self._adjust_sort( + result, exp, lambda d: tuple(d["constrained_columns"]) + ) + self._check_table_dict(result, exp, self._required_fk_keys) + + @testing.requires.index_reflection + @_multi_combination + def test_get_multi_indexes( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_indexes, + self.exp_indexes, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_indexes(**kw) + self._check_table_dict(result, exp, self._required_index_keys) + + @testing.requires.unique_constraint_reflection + @_multi_combination + def test_get_multi_unique_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_unique_constraints, + self.exp_ucs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_unique_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["column_names"])) + self._check_table_dict(result, exp, self._required_unique_cst_keys) + + @testing.requires.check_constraint_reflection + @_multi_combination + def test_get_multi_check_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_check_constraints, + self.exp_ccs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_check_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"])) + self._check_table_dict(result, exp, self._required_cc_keys) + + @testing.combinations( + ("get_table_options", testing.requires.reflect_table_options), + "get_columns", + ( + "get_pk_constraint", + testing.requires.primary_key_constraint_reflection, + ), + ( + "get_foreign_keys", + testing.requires.foreign_key_constraint_reflection, + ), + ("get_indexes", testing.requires.index_reflection), + ( + "get_unique_constraints", + testing.requires.unique_constraint_reflection, + ), + ( + "get_check_constraints", + testing.requires.check_constraint_reflection, + ), + ("get_table_comment", testing.requires.comment_reflection), + argnames="method", + ) + def test_not_existing_table(self, method, connection): + insp = inspect(connection) + meth = getattr(insp, method) + with expect_raises(NoSuchTableError): + meth("table_does_not_exists") + + def test_unreflectable(self, connection): + mc = Inspector.get_multi_columns + + def patched(*a, **k): + ur = k.setdefault("unreflectable", {}) + ur[(None, "some_table")] = UnreflectableTableError("err") + return mc(*a, **k) + + with mock.patch.object(Inspector, "get_multi_columns", patched): + with expect_raises_message(UnreflectableTableError, "err"): + inspect(connection).reflect_table( + Table("some_table", MetaData()), None + ) + + @testing.combinations(True, False, argnames="use_schema") + @testing.combinations( + (True, testing.requires.views), False, argnames="views" + ) + def test_metadata(self, connection, use_schema, views): + m = MetaData() + schema = config.test_schema if use_schema else None + m.reflect(connection, schema=schema, views=views, resolve_fks=False) + + insp = inspect(connection) + tables = insp.get_table_names(schema) + if views: + tables += insp.get_view_names(schema) + try: + tables += insp.get_materialized_view_names(schema) + except NotImplementedError: + pass + if schema: + tables = [f"{schema}.{t}" for t in tables] + eq_(sorted(m.tables), sorted(tables)) + + @testing.requires.comment_reflection + def test_comments_unicode(self, connection, metadata): + Table( + "unicode_comments", + metadata, + Column("unicode", Integer, comment="é試蛇ẟΩ"), + Column("emoji", Integer, comment="☁️✨"), + comment="試蛇ẟΩ✨", + ) + + metadata.create_all(connection) + + insp = inspect(connection) + tc = insp.get_table_comment("unicode_comments") + eq_(tc, {"text": "試蛇ẟΩ✨"}) + + cols = insp.get_columns("unicode_comments") + value = {c["name"]: c["comment"] for c in cols} + exp = {"unicode": "é試蛇ẟΩ", "emoji": "☁️✨"} + eq_(value, exp) + + @testing.requires.comment_reflection_full_unicode + def test_comments_unicode_full(self, connection, metadata): + Table( + "unicode_comments", + metadata, + Column("emoji", Integer, comment="🐍🧙🝝🧙♂️🧙♀️"), + comment="🎩🁰🝑🤷♀️🤷♂️", + ) + + metadata.create_all(connection) + + insp = inspect(connection) + tc = insp.get_table_comment("unicode_comments") + eq_(tc, {"text": "🎩🁰🝑🤷♀️🤷♂️"}) + c = insp.get_columns("unicode_comments")[0] + eq_({c["name"]: c["comment"]}, {"emoji": "🐍🧙🝝🧙♂️🧙♀️"}) + + +class TableNoColumnsTest(fixtures.TestBase): + __requires__ = ("reflect_tables_no_columns",) + __backend__ = True + + @testing.fixture + def table_no_columns(self, connection, metadata): + Table("empty", metadata) + metadata.create_all(connection) + + @testing.fixture + def view_no_columns(self, connection, metadata): + Table("empty", metadata) + event.listen( + metadata, + "after_create", + DDL("CREATE VIEW empty_v AS SELECT * FROM empty"), + ) + + # for transactional DDL the transaction is rolled back before this + # drop statement is invoked + event.listen( + metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v") + ) + metadata.create_all(connection) + + def test_reflect_table_no_columns(self, connection, table_no_columns): + t2 = Table("empty", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + def test_get_columns_table_no_columns(self, connection, table_no_columns): + insp = inspect(connection) + eq_(insp.get_columns("empty"), []) + multi = insp.get_multi_columns() + eq_(multi, {(None, "empty"): []}) + + def test_reflect_incl_table_no_columns(self, connection, table_no_columns): + m = MetaData() + m.reflect(connection) + assert set(m.tables).intersection(["empty"]) + + @testing.requires.views + def test_reflect_view_no_columns(self, connection, view_no_columns): + t2 = Table("empty_v", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + @testing.requires.views + def test_get_columns_view_no_columns(self, connection, view_no_columns): + insp = inspect(connection) + eq_(insp.get_columns("empty_v"), []) + multi = insp.get_multi_columns(kind=ObjectKind.VIEW) + eq_(multi, {(None, "empty_v"): []}) + + +class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + Table( + "sa_cc", + metadata, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint( + "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing" + ), + schema=schema, + ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + reflected = sorted( + insp.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), + ) + + # trying to minimize effect of quoting, parenthesis, etc. + # may need to add more to this as new dialects get CHECK + # constraint reflection support + def normalize(sqltext): + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) + + reflected = [ + {"name": item["name"], "sqltext": normalize(item["sqltext"])} + for item in reflected + ] + eq_( + reflected, + [ + {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"}, + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + ], + ) + no_cst = "no_constraints" + eq_(insp.get_check_constraints(no_cst, schema=schema), []) + + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + Column("z", String(30)), + ) + + Index("t_idx", func.lower(t.c.x), t.c.z, func.lower(t.c.y)) + long_str = "long string " * 100 + Index("t_idx_long", func.coalesce(t.c.x, long_str)) + Index("t_idx_2", t.c.x) + + metadata.create_all(connection) + + insp = inspect(connection) + + expected = [ + { + "name": "t_idx_2", + "column_names": ["x"], + "unique": False, + "dialect_options": {}, + } + ] + + def completeIndex(entry): + if testing.requires.index_reflects_included_columns.enabled: + entry["include_columns"] = [] + entry["dialect_options"] = { + f"{connection.engine.name}_include": [] + } + else: + entry.setdefault("dialect_options", {}) + + completeIndex(expected[0]) + + class lower_index_str(str): + def __eq__(self, other): + ol = other.lower() + # test that lower and x or y are in the string + return "lower" in ol and ("x" in ol or "y" in ol) + + class coalesce_index_str(str): + def __eq__(self, other): + # test that coalesce and the string is in other + return "coalesce" in other.lower() and long_str in other + + if testing.requires.reflect_indexes_with_expressions.enabled: + expr_index = { + "name": "t_idx", + "column_names": [None, "z", None], + "expressions": [ + lower_index_str("lower(x)"), + "z", + lower_index_str("lower(y)"), + ], + "unique": False, + } + completeIndex(expr_index) + expected.insert(0, expr_index) + + expr_index_long = { + "name": "t_idx_long", + "column_names": [None], + "expressions": [ + coalesce_index_str(f"coalesce(x, '{long_str}')") + ], + "unique": False, + } + completeIndex(expr_index_long) + expected.append(expr_index_long) + + eq_(insp.get_indexes("t"), expected) + m2 = MetaData() + t2 = Table("t", m2, autoload_with=connection) + else: + with expect_warnings( + "Skipped unsupported reflection of expression-based " + "index t_idx" + ): + eq_(insp.get_indexes("t"), expected) + m2 = MetaData() + t2 = Table("t", m2, autoload_with=connection) + + self.compare_table_index_with_expected( + t2, expected, connection.engine.name + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + get_indexes = insp.get_indexes("t") + eq_( + get_indexes, + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + "dialect_options": mock.ANY, + } + ], + ) + eq_( + get_indexes[0]["dialect_options"][ + "%s_include" % connection.engine.name + ], + ["y"], + ) + + t2 = Table("t", MetaData(), autoload_with=connection) + eq_( + list(t2.indexes)[0].dialect_options[connection.engine.name][ + "include" + ], + ["y"], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)], + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] + + @testing.requires.table_reflection + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) + + @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) + + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + { + col["name"]: col["nullable"] + for col in inspect(connection).get_columns("t") + }, + {"a": True, "b": False}, + ) + + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate + + if expected is None: + expected = options + + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_({k: opts[k] for k in opts if opts[k]}, {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + + +class NormalizedNameTest(fixtures.TablesTest): + __requires__ = ("denormalized_names",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), + ) + Table( + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), + ) + + def test_reflect_lowercase_forced_tables(self): + m2 = MetaData() + t2_ref = Table( + quoted_name("t2", quote=True), m2, autoload_with=config.db + ) + t1_ref = m2.tables["t1"] + assert t2_ref.c.t1id.references(t1_ref.c.id) + + m3 = MetaData() + m3.reflect( + config.db, only=lambda name, m: name.lower() in ("t1", "t2") + ) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) + + def test_get_table_names(self): + tablenames = [ + t + for t in inspect(config.db).get_table_names() + if t.lower() in ("t1", "t2") + ] + + eq_(tablenames[0].upper(), tablenames[0].lower()) + eq_(tablenames[1].upper(), tablenames[1].lower()) + + +class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest): + def test_computed_col_default_not_set(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + col_data = {c["name"]: c for c in cols} + is_true("42" in col_data["with_default"]["default"]) + is_(col_data["normal"]["default"], None) + is_(col_data["computed_col"]["default"], None) + + def test_get_column_returns_computed(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + data = {c["name"]: c for c in cols} + for key in ("id", "normal", "with_default"): + is_true("computed" not in data[key]) + compData = data["computed_col"] + is_true("computed" in compData) + is_true("sqltext" in compData["computed"]) + eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42") + eq_( + "persisted" in compData["computed"], + testing.requires.computed_columns_reflect_persisted.enabled, + ) + if testing.requires.computed_columns_reflect_persisted.enabled: + eq_( + compData["computed"]["persisted"], + testing.requires.computed_columns_default_persisted.enabled, + ) + + def check_column(self, data, column, sqltext, persisted): + is_true("computed" in data[column]) + compData = data[column]["computed"] + eq_(self.normalize(compData["sqltext"]), sqltext) + if testing.requires.computed_columns_reflect_persisted.enabled: + is_true("persisted" in compData) + is_(compData["persisted"], persisted) + + def test_get_column_returns_persisted(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_column_table") + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal+42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal+2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal-42", + True, + ) + + @testing.requires.schemas + def test_get_column_returns_persisted_with_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns( + "computed_column_table", schema=config.test_schema + ) + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal/42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal/2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal*42", + True, + ) + + +class IdentityReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("identity_columns", "table_reflection") + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity()), + ) + Table( + "t2", + metadata, + Column( + "id2", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + ) + if testing.requires.schemas.enabled: + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity(always=True, start=20)), + schema=config.test_schema, + ) + + def check(self, value, exp, approx): + if testing.requires.identity_columns_standard.enabled: + common_keys = ( + "always", + "start", + "increment", + "minvalue", + "maxvalue", + "cycle", + "cache", + ) + for k in list(value): + if k not in common_keys: + value.pop(k) + if approx: + eq_(len(value), len(exp)) + for k in value: + if k == "minvalue": + is_true(value[k] <= exp[k]) + elif k in {"maxvalue", "cache"}: + is_true(value[k] >= exp[k]) + else: + eq_(value[k], exp[k], k) + else: + eq_(value, exp) + else: + eq_(value["start"], exp["start"]) + eq_(value["increment"], exp["increment"]) + + def test_reflect_identity(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1") + insp.get_columns("t2") + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + if "autoincrement" in col: + is_true(col["autoincrement"]) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=False, + start=1, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + elif col["name"] == "id2": + if "autoincrement" in col: + is_true(col["autoincrement"]) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + approx=False, + ) + + @testing.requires.schemas + def test_reflect_identity_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1", schema=config.test_schema) + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + if "autoincrement" in col: + is_true(col["autoincrement"]) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=20, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + + +class CompositeKeyReflectionTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tb1 = Table( + "tb1", + metadata, + Column("id", Integer), + Column("attr", Integer), + Column("name", sql_types.VARCHAR(20)), + sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"), + schema=None, + test_needs_fk=True, + ) + Table( + "tb2", + metadata, + Column("id", Integer, primary_key=True), + Column("pid", Integer), + Column("pattr", Integer), + Column("pname", sql_types.VARCHAR(20)), + sa.ForeignKeyConstraint( + ["pname", "pid", "pattr"], + [tb1.c.name, tb1.c.id, tb1.c.attr], + name="fk_tb1_name_id_attr", + ), + schema=None, + test_needs_fk=True, + ) + + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self, connection): + # test for issue #5661 + insp = inspect(connection) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) + + @testing.requires.foreign_key_constraint_reflection + def test_fk_column_order(self, connection): + # test for issue #5661 + insp = inspect(connection) + foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) + eq_(len(foreign_keys), 1) + fkey1 = foreign_keys[0] + eq_(fkey1.get("referred_columns"), ["name", "id", "attr"]) + eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"]) + + +__all__ = ( + "ComponentReflectionTest", + "ComponentReflectionTestExtra", + "TableNoColumnsTest", + "QuotedNameArgumentTest", + "BizarroCharacterFKResolutionTest", + "HasTableTest", + "HasIndexTest", + "NormalizedNameTest", + "ComputedReflectionTest", + "IdentityReflectionTest", + "CompositeKeyReflectionTest", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py new file mode 100644 index 0000000..b3f432f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py @@ -0,0 +1,468 @@ +# testing/suite/test_results.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import datetime + +from .. import engines +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import DateTime +from ... import func +from ... import Integer +from ... import select +from ... import sql +from ... import String +from ... import testing +from ... import text + + +class RowFetchTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + connection.execute( + cls.tables.has_dates.insert(), + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], + ) + + def test_via_attr(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row.id, 1) + eq_(row.data, "d1") + + def test_via_string(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping["id"], 1) + eq_(row._mapping["data"], "d1") + + def test_via_int(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row[0], 1) + eq_(row[1], "d1") + + def test_via_col_object(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping[self.tables.plain_pk.c.id], 1) + eq_(row._mapping[self.tables.plain_pk.c.data], "d1") + + @requirements.duplicate_names_in_cursor_description + def test_row_with_dupe_names(self, connection): + result = connection.execute( + select( + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ).order_by(self.tables.plain_pk.c.id) + ) + row = result.first() + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) + + def test_row_w_scalar_select(self, connection): + """test that a scalar select as a column is returned as such + and that type conversion works OK. + + (this is half a SQLAlchemy Core test and half to catch database + backends that may have unusual behavior with scalar selects.) + + """ + datetable = self.tables.has_dates + s = select(datetable.alias("x").c.today).scalar_subquery() + s2 = select(datetable.c.id, s.label("somelabel")) + row = connection.execute(s2).first() + + eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0)) + + +class PercentSchemaNamesTest(fixtures.TablesTest): + """tests using percent signs, spaces in table and column names. + + This didn't work for PostgreSQL / MySQL drivers for a long time + but is now supported. + + """ + + __requires__ = ("percent_schema_names",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) + cls.tables.lightweight_percent_table = sql.table( + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), + ) + + def test_single_roundtrip(self, connection): + percent_table = self.tables.percent_table + for params in [ + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ]: + connection.execute(percent_table.insert(), params) + self._assert_table(connection) + + def test_executemany_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + connection.execute( + percent_table.insert(), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + self._assert_table(connection) + + @requirements.insert_executemany_returning + def test_executemany_returning_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + result = connection.execute( + percent_table.insert().returning( + percent_table.c["percent%"], + percent_table.c["spaces % more spaces"], + ), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + eq_(result.all(), [(7, 11), (9, 10), (11, 9)]) + self._assert_table(connection) + + def _assert_table(self, conn): + percent_table = self.tables.percent_table + lightweight_percent_table = self.tables.lightweight_percent_table + + for table in ( + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): + eq_( + list( + conn.execute(table.select().order_by(table.c["percent%"])) + ), + [(5, 12), (7, 11), (9, 10), (11, 9)], + ) + + eq_( + list( + conn.execute( + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) + ) + ), + [(9, 10), (11, 9)], + ) + + row = conn.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row._mapping["percent%"], 5) + eq_(row._mapping["spaces % more spaces"], 12) + + eq_(row._mapping[table.c["percent%"]], 5) + eq_(row._mapping[table.c["spaces % more spaces"]], 12) + + conn.execute( + percent_table.update().values( + {percent_table.c["spaces % more spaces"]: 15} + ) + ) + + eq_( + list( + conn.execute( + percent_table.select().order_by( + percent_table.c["percent%"] + ) + ) + ), + [(5, 15), (7, 15), (9, 15), (11, 15)], + ) + + +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): + __requires__ = ("server_side_cursors",) + + __backend__ = True + + def _is_server_side(self, cursor): + # TODO: this is a huge issue as it prevents these tests from being + # usable by third party dialects. + if self.engine.dialect.driver == "psycopg2": + return bool(cursor.name) + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"): + return cursor.server_side + elif self.engine.dialect.driver == "mysqldb": + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "mariadbconnector": + return not cursor.buffered + elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"): + return cursor.server_side + elif self.engine.dialect.driver == "pg8000": + return getattr(cursor, "server_side", False) + elif self.engine.dialect.driver == "psycopg": + return bool(getattr(cursor, "name", False)) + else: + return False + + def _fixture(self, server_side_cursors): + if server_side_cursors: + with testing.expect_deprecated( + "The create_engine.server_side_cursors parameter is " + "deprecated and will be removed in a future release. " + "Please use the Connection.execution_options.stream_results " + "parameter." + ): + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + else: + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + return self.engine + + @testing.combinations( + ("global_string", True, "select 1", True), + ("global_text", True, text("select 1"), True), + ("global_expr", True, select(1), True), + ("global_off_explicit", False, text("select 1"), False), + ( + "stmt_option", + False, + select(1).execution_options(stream_results=True), + True, + ), + ( + "stmt_option_disabled", + True, + select(1).execution_options(stream_results=False), + False, + ), + ("for_update_expr", True, select(1).with_for_update(), True), + # TODO: need a real requirement for this, or dont use this test + ( + "for_update_string", + True, + "SELECT 1 FOR UPDATE", + True, + testing.skip_if(["sqlite", "mssql"]), + ), + ("text_no_ss", False, text("select 42"), False), + ( + "text_ss_option", + False, + text("select 42").execution_options(stream_results=True), + True, + ), + id_="iaaa", + argnames="engine_ss_arg, statement, cursor_ss_status", + ) + def test_ss_cursor_status( + self, engine_ss_arg, statement, cursor_ss_status + ): + engine = self._fixture(engine_ss_arg) + with engine.begin() as conn: + if isinstance(statement, str): + result = conn.exec_driver_sql(statement) + else: + result = conn.execute(statement) + eq_(self._is_server_side(result.cursor), cursor_ss_status) + result.close() + + def test_conn_option(self): + engine = self._fixture(False) + + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) + + # the connection has autobegun, which means at the end of the + # block, we will roll back, which on MySQL at least will fail + # with "Commands out of sync" if the result set + # is not closed, so we close it first. + # + # fun fact! why did we not have this result.close() in this test + # before 2.0? don't we roll back in the connection pool + # unconditionally? yes! and in fact if you run this test in 1.4 + # with stdout shown, there is in fact "Exception during reset or + # similar" with "Commands out sync" emitted a warning! 2.0's + # architecture finds and fixes what was previously an expensive + # silent error condition. + result.close() + + def test_stmt_enabled_conn_option_disabled(self): + engine = self._fixture(False) + + s = select(1).execution_options(stream_results=True) + + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) + + def test_aliases_and_ss(self): + engine = self._fixture(False) + s1 = ( + select(sql.literal_column("1").label("x")) + .execution_options(stream_results=True) + .subquery() + ) + + # options don't propagate out when subquery is used as a FROM clause + with engine.begin() as conn: + result = conn.execute(s1.select()) + assert not self._is_server_side(result.cursor) + result.close() + + s2 = select(1).select_from(s1) + with engine.begin() as conn: + result = conn.execute(s2) + assert not self._is_server_side(result.cursor) + result.close() + + def test_roundtrip_fetchall(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute(test_table.insert(), dict(data="data1")) + connection.execute(test_table.insert(), dict(data="data2")) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2")], + ) + connection.execute( + test_table.update() + .where(test_table.c.id == 2) + .values(data=test_table.c.data + " updated") + ) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) + connection.execute(test_table.delete()) + eq_( + connection.scalar( + select(func.count("*")).select_from(test_table) + ), + 0, + ) + + def test_roundtrip_fetchmany(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute( + test_table.insert(), + [dict(data="data%d" % i) for i in range(1, 20)], + ) + + result = connection.execute( + test_table.select().order_by(test_table.c.id) + ) + + eq_( + result.fetchmany(5), + [(i, "data%d" % i) for i in range(1, 6)], + ) + eq_( + result.fetchmany(10), + [(i, "data%d" % i) for i in range(6, 16)], + ) + eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)]) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py new file mode 100644 index 0000000..a7dbd36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py @@ -0,0 +1,258 @@ +# testing/suite/test_rowcount.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from sqlalchemy import bindparam +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures + + +class RowCountTest(fixtures.TablesTest): + """test rowcount functionality""" + + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + ) + + @classmethod + def insert_data(cls, connection): + cls.data = data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] + + employees_table = cls.tables.employees + connection.execute( + employees_table.insert(), + [ + {"employee_id": i, "name": n, "department": d} + for i, (n, d) in enumerate(data) + ], + ) + + def test_basic(self, connection): + employees_table = self.tables.employees + s = select( + employees_table.c.name, employees_table.c.department + ).order_by(employees_table.c.employee_id) + rows = connection.execute(s).fetchall() + + eq_(rows, self.data) + + @testing.variation("statement", ["update", "delete", "insert", "select"]) + @testing.variation("close_first", [True, False]) + def test_non_rowcount_scenarios_no_raise( + self, connection, statement, close_first + ): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + + if statement.update: + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "Z"}, + ) + elif statement.delete: + r = connection.execute( + employees_table.delete().where(department == "C"), + {"department": "Z"}, + ) + elif statement.insert: + r = connection.execute( + employees_table.insert(), + [ + {"employee_id": 25, "name": "none 1", "department": "X"}, + {"employee_id": 26, "name": "none 2", "department": "Z"}, + {"employee_id": 27, "name": "none 3", "department": "Z"}, + ], + ) + elif statement.select: + s = select( + employees_table.c.name, employees_table.c.department + ).where(employees_table.c.department == "C") + r = connection.execute(s) + r.all() + else: + statement.fail() + + if close_first: + r.close() + + assert r.rowcount in (-1, 3) + + def test_update_rowcount1(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "Z"}, + ) + assert r.rowcount == 3 + + def test_update_rowcount2(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 0 rows changed + department = employees_table.c.department + + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "C"}, + ) + eq_(r.rowcount, 3) + + @testing.variation("implicit_returning", [True, False]) + @testing.variation( + "dml", + [ + ("update", testing.requires.update_returning), + ("delete", testing.requires.delete_returning), + ], + ) + def test_update_delete_rowcount_return_defaults( + self, connection, implicit_returning, dml + ): + """note this test should succeed for all RETURNING backends + as of 2.0. In + Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use + len(rows) when we have implicit returning + + """ + + if implicit_returning: + employees_table = self.tables.employees + else: + employees_table = Table( + "employees", + MetaData(), + Column( + "employee_id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + implicit_returning=False, + ) + + department = employees_table.c.department + + if dml.update: + stmt = ( + employees_table.update() + .where(department == "C") + .values(name=employees_table.c.department + "Z") + .return_defaults() + ) + elif dml.delete: + stmt = ( + employees_table.delete() + .where(department == "C") + .return_defaults() + ) + else: + dml.fail() + + r = connection.execute(stmt) + eq_(r.rowcount, 3) + + def test_raw_sql_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.exec_driver_sql( + "update employees set department='Z' where department='C'" + ) + eq_(result.rowcount, 3) + + def test_text_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.execute( + text("update employees set department='Z' where department='C'") + ) + eq_(result.rowcount, 3) + + def test_delete_rowcount(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows deleted + department = employees_table.c.department + r = connection.execute( + employees_table.delete().where(department == "C") + ) + eq_(r.rowcount, 3) + + @testing.requires.sane_multi_rowcount + def test_multi_update_rowcount(self, connection): + employees_table = self.tables.employees + stmt = ( + employees_table.update() + .where(employees_table.c.name == bindparam("emp_name")) + .values(department="C") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) + + @testing.requires.sane_multi_rowcount + def test_multi_delete_rowcount(self, connection): + employees_table = self.tables.employees + + stmt = employees_table.delete().where( + employees_table.c.name == bindparam("emp_name") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py new file mode 100644 index 0000000..866bf09 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py @@ -0,0 +1,1888 @@ +# testing/suite/test_select.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import collections.abc as collections_abc +import itertools + +from .. import AssertsCompiledSQL +from .. import AssertsExecutionResults +from .. import config +from .. import fixtures +from ..assertions import assert_raises +from ..assertions import eq_ +from ..assertions import in_ +from ..assertsql import CursorSQL +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import case +from ... import column +from ... import Computed +from ... import exists +from ... import false +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import Integer +from ... import literal +from ... import literal_column +from ... import null +from ... import select +from ... import String +from ... import table +from ... import testing +from ... import text +from ... import true +from ... import tuple_ +from ... import TupleType +from ... import union +from ... import values +from ...exc import DatabaseError +from ...exc import ProgrammingError + + +class CollateTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "collate data1"}, + {"id": 2, "data": "collate data2"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + @testing.requires.order_by_collation + def test_collate_order_by(self): + collation = testing.requires.get_order_by_collation(testing.config) + + self._assert_result( + select(self.tables.some_table).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], + ) + + +class OrderByLabelTest(fixtures.TablesTest): + """Test the dialect sends appropriate ORDER BY expressions when + labels are used. + + This essentially exercises the "supports_simple_order_by_label" + setting. + + """ + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + def test_plain(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)]) + + def test_composed_int(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)]) + + def test_composed_multiple(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") + self._assert_result( + select(lx, ly).order_by(lx, ly.desc()), + [(3, "q1p3"), (5, "q2p2"), (7, "q3p1")], + ) + + def test_plain_desc(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)]) + + def test_composed_int_desc(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)]) + + @testing.requires.group_by_complex_expression + def test_group_by_composed(self): + table = self.tables.some_table + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select(func.count(table.c.id), expr).group_by(expr).order_by(expr) + ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) + + +class ValuesExpressionTest(fixtures.TestBase): + __requires__ = ("table_value_constructor",) + + __backend__ = True + + def test_tuples(self, connection): + value_expr = values( + column("id", Integer), column("name", String), name="my_values" + ).data([(1, "name1"), (2, "name2"), (3, "name3")]) + + eq_( + connection.execute(select(value_expr)).all(), + [(1, "name1"), (2, "name2"), (3, "name3")], + ) + + +class FetchLimitOffsetTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + {"id": 5, "x": 4, "y": 6}, + ], + ) + + def _assert_result( + self, connection, select, result, params=(), set_=False + ): + if set_: + query_res = connection.execute(select, params).fetchall() + eq_(len(query_res), len(result)) + eq_(set(query_res), set(result)) + + else: + eq_(connection.execute(select, params).fetchall(), result) + + def _assert_result_str(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.exec_driver_sql(select, params).fetchall(), result) + + def test_simple_limit(self, connection): + table = self.tables.some_table + stmt = select(table).order_by(table.c.id) + self._assert_result( + connection, + stmt.limit(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + stmt.limit(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).limit(1).scalar_subquery() + + u = union(select(stmt), select(stmt)).subquery().select() + + self._assert_result( + connection, + u, + [ + (1,), + ], + ) + + @testing.requires.fetch_first + def test_simple_fetch(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.offset + def test_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.combinations( + ([(2, 0), (2, 1), (3, 2)]), + ([(2, 1), (2, 0), (3, 2)]), + ([(3, 1), (2, 1), (3, 1)]), + argnames="cases", + ) + @testing.requires.offset + def test_simple_limit_offset(self, connection, cases): + table = self.tables.some_table + connection = connection.execution_options(compiled_cache={}) + + assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)] + + for limit, offset in cases: + expected = assert_data[offset : offset + limit] + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(limit).offset(offset), + expected, + ) + + @testing.requires.fetch_first + def test_simple_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2).offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_no_order_by + def test_fetch_offset_no_order(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).fetch(10), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.offset + def test_simple_offset_zero(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(0), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(1), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.offset + def test_limit_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).limit(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.fetch_first + def test_fetch_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).fetch(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.bound_limit_offset + def test_bound_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3)], + params={"l": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + params={"l": 3}, + ) + + @testing.requires.bound_limit_offset + def test_bound_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 1}, + ) + + @testing.requires.bound_limit_offset + def test_bound_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"l": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"l": 3, "o": 2}, + ) + + @testing.requires.fetch_first + def test_bound_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"f": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"f": 3, "o": 2}, + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .offset(literal_column("1") + literal_column("2")), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("2")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.fetch_first + @testing.requires.fetch_expression + def test_expr_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.sql_expression_limit_offset + def test_simple_limit_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(2) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(3) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(2), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + def test_simple_fetch_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties_exact_number(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(3, with_ties=True) + .offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_percent + def test_simple_fetch_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(20, percent=True), + [(1, 1, 2)], + ) + + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(40, percent=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + def test_simple_fetch_percent_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x.desc()) + .fetch(20, percent=True, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(40, percent=True, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) + + +class SameNamedSchemaTableTest(fixtures.TablesTest): + """tests for #7471""" + + __backend__ = True + + __requires__ = ("schemas",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + schema=config.test_schema, + ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "some_table_id", + Integer, + # ForeignKey("%s.some_table.id" % config.test_schema), + nullable=False, + ), + ) + + @classmethod + def insert_data(cls, connection): + some_table, some_table_schema = cls.tables( + "some_table", "%s.some_table" % config.test_schema + ) + connection.execute(some_table_schema.insert(), {"id": 1}) + connection.execute(some_table.insert(), {"id": 1, "some_table_id": 1}) + + def test_simple_join_both_tables(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table, some_table_schema).join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + ).first(), + (1, 1, 1), + ) + + def test_simple_join_whereclause_only(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1), + ) + + def test_subquery(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + subq = ( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + .subquery() + ) + + eq_( + connection.execute( + select(some_table, subq.c.id) + .join_from( + some_table, + subq, + some_table.c.some_table_id == subq.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1, 1), + ) + + +class JoinTest(fixtures.TablesTest): + __backend__ = True + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", ForeignKey("a.id"), nullable=False), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.a.insert(), + [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}], + ) + + connection.execute( + cls.tables.b.insert(), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + {"id": 4, "a_id": 2}, + {"id": 5, "a_id": 3}, + ], + ) + + def test_inner_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + def test_inner_join_true(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, true())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (a, b, c) + for (a,), (b, c) in itertools.product( + [(1,), (2,), (3,), (4,), (5,)], + [(1, 1), (2, 1), (4, 2), (5, 3)], + ) + ], + ) + + def test_inner_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result(stmt, []) + + def test_outer_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.outerjoin(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (1, None, None), + (2, None, None), + (3, None, None), + (4, None, None), + (5, None, None), + ], + ) + + def test_outer_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + +class CompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_select_from_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_in_unions_from_alias(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + # this necessarily has double parens + u1 = union(s1, s2).alias() + self._assert_result( + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + +class PostCompileParamsTest( + AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest +): + __backend__ = True + + __requires__ = ("standard_cursor_sql",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def test_compile(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table " + "WHERE some_table.x = __[POSTCOMPILE_q]", + {}, + ) + + def test_compile_literal_binds(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", 10, literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table WHERE some_table.x = 10", + {}, + literal_binds=True, + ) + + def test_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=10)) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x = 10", + () if config.db.dialect.positional else {}, + ) + ) + + def test_execute_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x.in_(bindparam("q", expanding=True, literal_execute=True)) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[5, 6, 7])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x IN (5, 6, 7)", + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, 10), (12, 18)])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.y) " + "IN (%s(5, 10), (12, 18))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.z) " + "IN (%s(5, 'z1'), (12, 'z3'))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + +class ExpandingBoundInTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_multiple_empty_sets_bindparam(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .where(table.c.y.in_(bindparam("p"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + + def test_multiple_empty_sets_direct(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.y.in_([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)]) + go([], []) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)]) + go([], []) + + def test_bound_in_scalar_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + + def test_bound_in_scalar_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3, 4])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + def test_nonempty_in_plus_empty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3])) + .where(table.c.id.not_in([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,)]) + + def test_empty_in_plus_notempty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.id.not_in([2, 3])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + def test_typed_str_in(self): + """test related to #7292. + + as a type is given to the bound param, there is no ambiguity + to the type of element. + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", type_=String, expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + def test_untyped_str_in(self): + """test related to #7292. + + for untyped expression, we look at the types of elements. + Test for Sequence to detect tuple in. but not strings or bytes! + as always.... + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where( + tuple_(table.c.x, table.c.z).in_( + [(2, "z2"), (3, "z3"), (4, "z4")] + ) + ) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self): + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams( + bindparam( + "q", type_=TupleType(Integer(), String()), expanding=True + ) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + def test_empty_set_against_integer_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_integer_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_integer_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_integer_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_empty_set_against_string_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_string_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_string_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_string_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_null_in_empty_set_is_false_bindparam(self, connection): + stmt = select( + case( + ( + null().in_(bindparam("foo", value=())), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + def test_null_in_empty_set_is_false_direct(self, connection): + stmt = select( + case( + ( + null().in_([]), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + +class LikeFunctionsTest(fixtures.TablesTest): + __backend__ = True + + run_inserts = "once" + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "abcdefg"}, + {"id": 2, "data": "ab/cdefg"}, + {"id": 3, "data": "ab%cdefg"}, + {"id": 4, "data": "ab_cdefg"}, + {"id": 5, "data": "abcde/fg"}, + {"id": 6, "data": "abcde%fg"}, + {"id": 7, "data": "ab#cdefg"}, + {"id": 8, "data": "ab9cdefg"}, + {"id": 9, "data": "abcde#fg"}, + {"id": 10, "data": "abcd9fg"}, + {"id": 11, "data": None}, + ], + ) + + def _test(self, expr, expected): + some_table = self.tables.some_table + + with config.db.connect() as conn: + rows = { + value + for value, in conn.execute(select(some_table.c.id).where(expr)) + } + + eq_(rows, expected) + + def test_startswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + + def test_startswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True), {3}) + + def test_startswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.startswith(literal_column("'ab%c'")), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) + + def test_startswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab##c", escape="#"), {7}) + + def test_startswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3}) + self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7}) + + def test_endswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_endswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) + + def test_endswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True), {6}) + + def test_endswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e##fg", escape="#"), {9}) + + def test_endswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6}) + self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9}) + + def test_contains_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_contains_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde", autoescape=True), {3}) + + def test_contains_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b##cde", escape="#"), {7}) + + def test_contains_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) + self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + @testing.requires.regexp_match + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10}) + + @testing.requires.regexp_replace + def test_regexp_replace(self): + col = self.tables.some_table.c.data + self._test( + col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9} + ) + + @testing.requires.regexp_match + @testing.combinations( + ("a.cde", {1, 5, 6, 9}), + ("abc", {1, 5, 6, 9, 10}), + ("^abc", {1, 5, 6, 9, 10}), + ("9cde", {8}), + ("^a", set(range(1, 11))), + ("(b|c)", set(range(1, 11))), + ("^(b|c)", set()), + ) + def test_regexp_match(self, text, expected): + col = self.tables.some_table.c.data + self._test(col.regexp_match(text), expected) + + +class ComputedColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.square.insert(), + [{"id": 1, "side": 10}, {"id": 10, "side": 42}], + ) + + def test_select_all(self): + with config.db.connect() as conn: + res = conn.execute( + select(text("*")) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)]) + + def test_select_columns(self): + with config.db.connect() as conn: + res = conn.execute( + select( + self.tables.square.c.area, self.tables.square.c.perimeter + ) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(100, 40), (1764, 168)]) + + +class IdentityColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("identity_columns",) + run_inserts = "once" + run_deletes = "once" + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl_a", + metadata, + Column( + "id", + Integer, + Identity( + always=True, start=42, nominvalue=True, nomaxvalue=True + ), + primary_key=True, + ), + Column("desc", String(100)), + ) + Table( + "tbl_b", + metadata, + Column( + "id", + Integer, + Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0), + primary_key=True, + ), + Column("desc", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.tbl_a.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"id": 42, "desc": "c"}], + ) + + def test_select_all(self, connection): + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_a) + .order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42, "a"), (43, "b")]) + + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_b) + .order_by(self.tables.tbl_b.c.id) + ).fetchall() + eq_(res, [(-5, "b"), (0, "a"), (42, "c")]) + + def test_select_columns(self, connection): + res = connection.execute( + select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42,), (43,)]) + + @testing.requires.identity_columns_standard + def test_insert_always_error(self, connection): + def fn(): + connection.execute( + self.tables.tbl_a.insert(), + [{"id": 200, "desc": "a"}], + ) + + assert_raises((DatabaseError, ProgrammingError), fn) + + +class IdentityAutoincrementTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("autoincrement_without_sequence",) + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl", + metadata, + Column( + "id", + Integer, + Identity(), + primary_key=True, + autoincrement=True, + ), + Column("desc", String(100)), + ) + + def test_autoincrement_with_identity(self, connection): + res = connection.execute(self.tables.tbl.insert(), {"desc": "row"}) + res = connection.execute(self.tables.tbl.select()).first() + eq_(res, (1, "row")) + + +class ExistsTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "stuff", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.stuff.insert(), + [ + {"id": 1, "data": "some data"}, + {"id": 2, "data": "some data"}, + {"id": 3, "data": "some data"}, + {"id": 4, "data": "some other data"}, + ], + ) + + def test_select_exists(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "some data") + ) + ).fetchall(), + [(1,)], + ) + + def test_select_exists_false(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "no data") + ) + ).fetchall(), + [], + ) + + +class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest): + __backend__ = True + + @testing.fails_if(testing.requires.supports_distinct_on) + def test_distinct_on(self): + stm = select("*").distinct(column("q")).select_from(table("foo")) + with testing.expect_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + ): + self.assert_compile(stm, "SELECT DISTINCT * FROM foo") + + +class IsOrIsNotDistinctFromTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("supports_is_distinct_from",) + + @classmethod + def define_tables(cls, metadata): + Table( + "is_distinct_test", + metadata, + Column("id", Integer, primary_key=True), + Column("col_a", Integer, nullable=True), + Column("col_b", Integer, nullable=True), + ) + + @testing.combinations( + ("both_int_different", 0, 1, 1), + ("both_int_same", 1, 1, 0), + ("one_null_first", None, 1, 1), + ("one_null_second", 0, None, 1), + ("both_null", None, None, 0), + id_="iaaa", + argnames="col_a_value, col_b_value, expected_row_count_for_is", + ) + def test_is_or_is_not_distinct_from( + self, col_a_value, col_b_value, expected_row_count_for_is, connection + ): + tbl = self.tables.is_distinct_test + + connection.execute( + tbl.insert(), + [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}], + ) + + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is, + ) + + expected_row_count_for_is_not = ( + 1 if expected_row_count_for_is == 0 else 0 + ) + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is_not, + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py new file mode 100644 index 0000000..138616f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py @@ -0,0 +1,317 @@ +# testing/suite/test_sequence.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import config +from .. import fixtures +from ..assertions import eq_ +from ..assertions import is_true +from ..config import requirements +from ..provision import normalize_sequence +from ..schema import Column +from ..schema import Table +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import Sequence +from ... import String +from ... import testing + + +class SequenceTest(fixtures.TablesTest): + __requires__ = ("sequences",) + __backend__ = True + + run_create_tables = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "seq_pk", + metadata, + Column( + "id", + Integer, + normalize_sequence(config, Sequence("tab_id_seq")), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + normalize_sequence( + config, + Sequence("tab_id_seq", data_type=Integer, optional=True), + ), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_no_returning", + metadata, + Column( + "id", + Integer, + normalize_sequence(config, Sequence("noret_id_seq")), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + ) + + if testing.requires.schemas.enabled: + Table( + "seq_no_returning_sch", + metadata, + Column( + "id", + Integer, + normalize_sequence( + config, + Sequence( + "noret_sch_id_seq", schema=config.test_schema + ), + ), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + schema=config.test_schema, + ) + + def test_insert_roundtrip(self, connection): + connection.execute(self.tables.seq_pk.insert(), dict(data="some data")) + self._assert_round_trip(self.tables.seq_pk, connection) + + def test_insert_lastrowid(self, connection): + r = connection.execute( + self.tables.seq_pk.insert(), dict(data="some data") + ) + eq_( + r.inserted_primary_key, (testing.db.dialect.default_sequence_base,) + ) + + def test_nextval_direct(self, connection): + r = connection.scalar(self.tables.seq_pk.c.id.default) + eq_(r, testing.db.dialect.default_sequence_base) + + @requirements.sequences_optional + def test_optional_seq(self, connection): + r = connection.execute( + self.tables.seq_opt_pk.insert(), dict(data="some data") + ) + eq_(r.inserted_primary_key, (1,)) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_(row, (testing.db.dialect.default_sequence_base, "some data")) + + def test_insert_roundtrip_no_implicit_returning(self, connection): + connection.execute( + self.tables.seq_no_returning.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.seq_no_returning, connection) + + @testing.combinations((True,), (False,), argnames="implicit_returning") + @testing.requires.schemas + def test_insert_roundtrip_translate(self, connection, implicit_returning): + seq_no_returning = Table( + "seq_no_returning_sch", + MetaData(), + Column( + "id", + Integer, + normalize_sequence( + config, Sequence("noret_sch_id_seq", schema="alt_schema") + ), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=implicit_returning, + schema="alt_schema", + ) + + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + connection.execute(seq_no_returning.insert(), dict(data="some data")) + self._assert_round_trip(seq_no_returning, connection) + + @testing.requires.schemas + def test_nextval_direct_schema_translate(self, connection): + seq = normalize_sequence( + config, Sequence("noret_sch_id_seq", schema="alt_schema") + ) + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + + r = connection.scalar(seq) + eq_(r, testing.db.dialect.default_sequence_base) + + +class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_literal_binds_inline_compile(self, connection): + table = Table( + "x", + MetaData(), + Column( + "y", Integer, normalize_sequence(config, Sequence("y_seq")) + ), + Column("q", Integer), + ) + + stmt = table.insert().values(q=5) + + seq_nextval = connection.dialect.statement_compiler( + statement=None, dialect=connection.dialect + ).visit_sequence(normalize_sequence(config, Sequence("y_seq"))) + self.assert_compile( + stmt, + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), + literal_binds=True, + dialect=connection.dialect, + ) + + +class HasSequenceTest(fixtures.TablesTest): + run_deletes = None + + __requires__ = ("sequences",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + normalize_sequence(config, Sequence("user_id_seq", metadata=metadata)) + normalize_sequence( + config, + Sequence( + "other_seq", + metadata=metadata, + nomaxvalue=True, + nominvalue=True, + ), + ) + if testing.requires.schemas.enabled: + normalize_sequence( + config, + Sequence( + "user_id_seq", schema=config.test_schema, metadata=metadata + ), + ) + normalize_sequence( + config, + Sequence( + "schema_seq", schema=config.test_schema, metadata=metadata + ), + ) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + def test_has_sequence(self, connection): + eq_(inspect(connection).has_sequence("user_id_seq"), True) + + def test_has_sequence_cache(self, connection, metadata): + insp = inspect(connection) + eq_(insp.has_sequence("user_id_seq"), True) + ss = normalize_sequence(config, Sequence("new_seq", metadata=metadata)) + eq_(insp.has_sequence("new_seq"), False) + ss.create(connection) + try: + eq_(insp.has_sequence("new_seq"), False) + insp.clear_cache() + eq_(insp.has_sequence("new_seq"), True) + finally: + ss.drop(connection) + + def test_has_sequence_other_object(self, connection): + eq_(inspect(connection).has_sequence("user_id_table"), False) + + @testing.requires.schemas + def test_has_sequence_schema(self, connection): + eq_( + inspect(connection).has_sequence( + "user_id_seq", schema=config.test_schema + ), + True, + ) + + def test_has_sequence_neg(self, connection): + eq_(inspect(connection).has_sequence("some_sequence"), False) + + @testing.requires.schemas + def test_has_sequence_schemas_neg(self, connection): + eq_( + inspect(connection).has_sequence( + "some_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_default_not_in_remote(self, connection): + eq_( + inspect(connection).has_sequence( + "other_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_remote_not_in_default(self, connection): + eq_(inspect(connection).has_sequence("schema_seq"), False) + + def test_get_sequence_names(self, connection): + exp = {"other_seq", "user_id_seq"} + + res = set(inspect(connection).get_sequence_names()) + is_true(res.intersection(exp) == exp) + is_true("schema_seq" not in res) + + @testing.requires.schemas + def test_get_sequence_names_no_sequence_schema(self, connection): + eq_( + inspect(connection).get_sequence_names( + schema=config.test_schema_2 + ), + [], + ) + + @testing.requires.schemas + def test_get_sequence_names_sequences_schema(self, connection): + eq_( + sorted( + inspect(connection).get_sequence_names( + schema=config.test_schema + ) + ), + ["schema_seq", "user_id_seq"], + ) + + +class HasSequenceTestEmpty(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_get_sequence_names_no_sequence(self, connection): + eq_( + inspect(connection).get_sequence_names(), + [], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py new file mode 100644 index 0000000..4a7c1f1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py @@ -0,0 +1,2071 @@ +# testing/suite/test_types.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import datetime +import decimal +import json +import re +import uuid + +from .. import config +from .. import engines +from .. import fixtures +from .. import mock +from ..assertions import eq_ +from ..assertions import is_ +from ..assertions import ne_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import and_ +from ... import ARRAY +from ... import BigInteger +from ... import bindparam +from ... import Boolean +from ... import case +from ... import cast +from ... import Date +from ... import DateTime +from ... import Float +from ... import Integer +from ... import Interval +from ... import JSON +from ... import literal +from ... import literal_column +from ... import MetaData +from ... import null +from ... import Numeric +from ... import select +from ... import String +from ... import testing +from ... import Text +from ... import Time +from ... import TIMESTAMP +from ... import type_coerce +from ... import TypeDecorator +from ... import Unicode +from ... import UnicodeText +from ... import UUID +from ... import Uuid +from ...orm import declarative_base +from ...orm import Session +from ...sql import sqltypes +from ...sql.sqltypes import LargeBinary +from ...sql.sqltypes import PickleType + + +class _LiteralRoundTripFixture: + supports_whereclause = True + + @testing.fixture + def literal_round_trip(self, metadata, connection): + """test literal rendering""" + + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as its + # official type; ideally we'd be able to use CAST here + # but MySQL in particular can't CAST fully + + def run( + type_, + input_, + output, + filter_=None, + compare=None, + support_whereclause=True, + ): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + for value in input_: + ins = t.insert().values( + x=literal(value, type_, literal_execute=True) + ) + connection.execute(ins) + + ins = t.insert().values( + x=literal(None, type_, literal_execute=True) + ) + connection.execute(ins) + + if support_whereclause and self.supports_whereclause: + if compare: + stmt = t.select().where( + t.c.x + == literal( + compare, + type_, + literal_execute=True, + ), + t.c.x + == literal( + input_[0], + type_, + literal_execute=True, + ), + ) + else: + stmt = t.select().where( + t.c.x + == literal( + compare if compare is not None else input_[0], + type_, + literal_execute=True, + ) + ) + else: + stmt = t.select().where(t.c.x.is_not(None)) + + rows = connection.execute(stmt).all() + assert rows, "No rows returned" + for row in rows: + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output + + stmt = t.select().where(t.c.x.is_(None)) + rows = connection.execute(stmt).all() + eq_(rows, [(None,)]) + + return run + + +class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("unicode_data",) + + data = ( + "Alors vous imaginez ma 🐍 surprise, au lever du jour, " + "quand une drôle de petite 🐍 voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »" + ) + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) + + def test_round_trip(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": self.data} + ) + + row = connection.execute(select(unicode_table.c.unicode_data)).first() + + eq_(row, (self.data,)) + assert isinstance(row[0], str) + + def test_round_trip_executemany(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(1, 4)], + ) + + rows = connection.execute( + select(unicode_table.c.unicode_data) + ).fetchall() + eq_(rows, [(self.data,) for i in range(1, 4)]) + for row in rows: + assert isinstance(row[0], str) + + def _test_null_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": None} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, (None,)) + + def _test_empty_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": ""} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, ("",)) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip(self.datatype, ["réve🐍 illé"], ["réve🐍 illé"]) + + +class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = ("unicode_data",) + __backend__ = True + + datatype = Unicode(255) + + @requirements.empty_strings_varchar + def test_empty_strings_varchar(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_varchar(self, connection): + self._test_null_strings(connection) + + +class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = "unicode_data", "text_type" + __backend__ = True + + datatype = UnicodeText() + + @requirements.empty_strings_text + def test_empty_strings_text(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_text(self, connection): + self._test_null_strings(connection) + + +class ArrayTest(_LiteralRoundTripFixture, fixtures.TablesTest): + """Add ARRAY test suite, #8138. + + This only works on PostgreSQL right now. + + """ + + __requires__ = ("array_type",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "array_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("single_dim", ARRAY(Integer)), + Column("multi_dim", ARRAY(String, dimensions=2)), + ) + + def test_array_roundtrip(self, connection): + array_table = self.tables.array_table + + connection.execute( + array_table.insert(), + { + "id": 1, + "single_dim": [1, 2, 3], + "multi_dim": [["one", "two"], ["thr'ee", "réve🐍 illé"]], + }, + ) + row = connection.execute( + select(array_table.c.single_dim, array_table.c.multi_dim) + ).first() + eq_(row, ([1, 2, 3], [["one", "two"], ["thr'ee", "réve🐍 illé"]])) + + def test_literal_simple(self, literal_round_trip): + literal_round_trip( + ARRAY(Integer), + ([1, 2, 3],), + ([1, 2, 3],), + support_whereclause=False, + ) + + def test_literal_complex(self, literal_round_trip): + literal_round_trip( + ARRAY(String, dimensions=2), + ([["one", "two"], ["thr'ee", "réve🐍 illé"]],), + ([["one", "two"], ["thr'ee", "réve🐍 illé"]],), + support_whereclause=False, + ) + + +class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "binary_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("binary_data", LargeBinary), + Column("pickle_data", PickleType), + ) + + @testing.combinations(b"this is binary", b"7\xe7\x9f", argnames="data") + def test_binary_roundtrip(self, connection, data): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), {"id": 1, "binary_data": data} + ) + row = connection.execute(select(binary_table.c.binary_data)).first() + eq_(row, (data,)) + + def test_pickle_roundtrip(self, connection): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), + {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}}, + ) + row = connection.execute(select(binary_table.c.pickle_data)).first() + eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},)) + + +class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("text_type",) + __backend__ = True + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) + + def test_text_roundtrip(self, connection): + text_table = self.tables.text_table + + connection.execute( + text_table.insert(), {"id": 1, "text_data": "some text"} + ) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("some text",)) + + @testing.requires.empty_strings_text + def test_text_empty_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": ""}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("",)) + + def test_text_null_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": None}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, (None,)) + + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) + + @requirements.unicode_data_no_special_types + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip(Text, ["réve🐍 illé"], ["réve🐍 illé"]) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(Text, [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(Text, [data], [data]) + + def test_literal_percentsigns(self, literal_round_trip): + data = r"percent % signs %% percent" + literal_round_trip(Text, [data], [data]) + + +class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @requirements.unbounded_varchar + def test_nolength_string(self): + metadata = MetaData() + foo = Table("foo", metadata, Column("one", String)) + + foo.create(config.db) + foo.drop(config.db) + + def test_literal(self, literal_round_trip): + # note that in Python 3, this invokes the Unicode + # datatype for the literal part because all strings are unicode + literal_round_trip(String(40), ["some text"], ["some text"]) + + @requirements.unicode_data_no_special_types + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip(String(40), ["réve🐍 illé"], ["réve🐍 illé"]) + + @testing.combinations( + ("%B%", ["AB", "BC"]), + ("A%C", ["AC"]), + ("A%C%Z", []), + argnames="expr, expected", + ) + def test_dont_truncate_rightside( + self, metadata, connection, expr, expected + ): + t = Table("t", metadata, Column("x", String(2))) + t.create(connection) + + connection.execute(t.insert(), [{"x": "AB"}, {"x": "BC"}, {"x": "AC"}]) + + eq_( + connection.scalars(select(t.c.x).where(t.c.x.like(expr))).all(), + expected, + ) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(String(40), [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(String(40), [data], [data]) + + def test_concatenate_binary(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and visit_concat_op_clauselist() + in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed + for dialects to override the string concatenation operator. + + """ + eq_(connection.scalar(select(literal("a") + "b")), "ab") + + def test_concatenate_clauselist(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and visit_concat_op_clauselist() + in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed + for dialects to override the string concatenation operator. + + """ + eq_( + connection.scalar(select(literal("a") + "b" + "c" + "d" + "e")), + "abcde", + ) + + +class IntervalTest(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("datetime_interval",) + __backend__ = True + + datatype = Interval + data = datetime.timedelta(days=1, seconds=4) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_select_direct_literal_interval(self, connection): + row = connection.execute(select(literal(self.data))).first() + eq_(row, (self.data,)) + + def test_arithmetic_operation_literal_interval(self, connection): + now = datetime.datetime.now().replace(microsecond=0) + # Able to subtract + row = connection.execute( + select(literal(now) - literal(self.data)) + ).scalar() + eq_(row, now - self.data) + + # Able to Add + row = connection.execute( + select(literal(now) + literal(self.data)) + ).scalar() + eq_(row, now + self.data) + + @testing.fixture + def arithmetic_table_fixture(cls, metadata, connection): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + it = Table( + "interval_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("interval_data", cls.datatype), + Column("date_data", DateTime), + Column("decorated_interval_data", Decorated), + ) + it.create(connection) + return it + + def test_arithmetic_operation_table_interval_and_literal_interval( + self, connection, arithmetic_table_fixture + ): + interval_table = arithmetic_table_fixture + data = datetime.timedelta(days=2, seconds=5) + connection.execute( + interval_table.insert(), {"id": 1, "interval_data": data} + ) + # Subtraction Operation + value = connection.execute( + select(interval_table.c.interval_data - literal(self.data)) + ).scalar() + eq_(value, data - self.data) + + # Addition Operation + value = connection.execute( + select(interval_table.c.interval_data + literal(self.data)) + ).scalar() + eq_(value, data + self.data) + + def test_arithmetic_operation_table_date_and_literal_interval( + self, connection, arithmetic_table_fixture + ): + interval_table = arithmetic_table_fixture + now = datetime.datetime.now().replace(microsecond=0) + connection.execute( + interval_table.insert(), {"id": 1, "date_data": now} + ) + # Subtraction Operation + value = connection.execute( + select(interval_table.c.date_data - literal(self.data)) + ).scalar() + eq_(value, (now - self.data)) + + # Addition Operation + value = connection.execute( + select(interval_table.c.date_data + literal(self.data)) + ).scalar() + eq_(value, (now + self.data)) + + +class PrecisionIntervalTest(IntervalTest): + __requires__ = ("datetime_interval",) + __backend__ = True + + datatype = Interval(day_precision=9, second_precision=9) + data = datetime.timedelta(days=103, seconds=4) + + +class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): + compare = None + + @classmethod + def define_tables(cls, metadata): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), + ) + + def test_round_trip(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + + row = connection.execute(select(date_table.c.date_data)).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_round_trip_decorated(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "decorated_date_data": self.data} + ) + + row = connection.execute( + select(date_table.c.decorated_date_data) + ).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_null(self, connection): + date_table = self.tables.date_table + + connection.execute(date_table.insert(), {"id": 1, "date_data": None}) + + row = connection.execute(select(date_table.c.date_data)).first() + eq_(row, (None,)) + + @testing.requires.datetime_literals + def test_literal(self, literal_round_trip): + compare = self.compare or self.data + + literal_round_trip( + self.datatype, [self.data], [compare], compare=compare + ) + + @testing.requires.standalone_null_binds_whereclause + def test_null_bound_comparison(self): + # this test is based on an Oracle issue observed in #4886. + # passing NULL for an expression that needs to be interpreted as + # a certain type, does the DBAPI have the info it needs to do this. + date_table = self.tables.date_table + with config.db.begin() as conn: + result = conn.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + id_ = result.inserted_primary_key[0] + stmt = select(date_table.c.id).where( + case( + ( + bindparam("foo", type_=self.datatype) != None, + bindparam("foo", type_=self.datatype), + ), + else_=date_table.c.date_data, + ) + == date_table.c.date_data + ) + + row = conn.execute(stmt, {"foo": None}).first() + eq_(row[0], id_) + + +class DateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_timezone",) + __backend__ = True + datatype = DateTime(timezone=True) + data = datetime.datetime( + 2012, 10, 15, 12, 57, 18, tzinfo=datetime.timezone.utc + ) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_microseconds",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 39642) + + +class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("timestamp_microseconds",) + __backend__ = True + datatype = TIMESTAMP + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + @testing.requires.timestamp_microseconds_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18) + + @testing.requires.time_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_timezone",) + __backend__ = True + datatype = Time(timezone=True) + data = datetime.time(12, 57, 18, tzinfo=datetime.timezone.utc) + + @testing.requires.time_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_microseconds",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18, 396) + + @testing.requires.time_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date",) + __backend__ = True + datatype = Date + data = datetime.date(2012, 10, 15) + + @testing.requires.date_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): + """this particular suite is testing that datetime parameters get + coerced to dates, which tends to be something DBAPIs do. + + """ + + __requires__ = "date", "date_coerces_from_datetime" + __backend__ = True + datatype = Date + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + compare = datetime.date(2012, 10, 15) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_historic",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(1850, 11, 10, 11, 52, 35) + + @testing.requires.date_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date_historic",) + __backend__ = True + datatype = Date + data = datetime.date(1727, 4, 1) + + @testing.requires.date_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) + + def _huge_ints(): + return testing.combinations( + 2147483649, # 32 bits + 2147483648, # 32 bits + 2147483647, # 31 bits + 2147483646, # 31 bits + -2147483649, # 32 bits + -2147483648, # 32 interestingly, asyncpg accepts this one as int32 + -2147483647, # 31 + -2147483646, # 31 + 0, + 1376537018368127, + -1376537018368127, + argnames="intvalue", + ) + + @_huge_ints() + def test_huge_int_auto_accommodation(self, connection, intvalue): + """test #7909""" + + eq_( + connection.scalar( + select(intvalue).where(literal(intvalue) == intvalue) + ), + intvalue, + ) + + @_huge_ints() + def test_huge_int(self, integer_round_trip, intvalue): + integer_round_trip(BigInteger, intvalue) + + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) + + metadata.create_all(config.db) + + connection.execute( + int_table.insert(), {"id": 1, "integer_data": data} + ) + + row = connection.execute(select(int_table.c.integer_data)).first() + + eq_(row, (data,)) + + assert isinstance(row[0], int) + + return run + + +class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def string_as_int(self): + class StringAsInt(TypeDecorator): + impl = String(50) + cache_ok = True + + def column_expression(self, col): + return cast(col, Integer) + + def bind_expression(self, col): + return cast(type_coerce(col, Integer), String(50)) + + return StringAsInt() + + def test_special_type(self, metadata, connection, string_as_int): + type_ = string_as_int + + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]]) + + result = {row[0] for row in connection.execute(t.select())} + eq_(result, {1, 2, 3}) + + result = { + row[0] for row in connection.execute(t.select().where(t.c.x == 2)) + } + eq_(result, {2}) + + +class TrueDivTest(fixtures.TestBase): + __backend__ = True + + @testing.combinations( + ("15", "10", 1.5), + ("-15", "10", -1.5), + argnames="left, right, expected", + ) + def test_truediv_integer(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Integer()) + / literal_column(right, type_=Integer()) + ) + ), + expected, + ) + + @testing.combinations( + ("15", "10", 1), ("-15", "5", -3), argnames="left, right, expected" + ) + def test_floordiv_integer(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Integer()) + // literal_column(right, type_=Integer()) + ) + ), + expected, + ) + + @testing.combinations( + ("5.52", "2.4", "2.3"), argnames="left, right, expected" + ) + def test_truediv_numeric(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Numeric(10, 2)) + / literal_column(right, type_=Numeric(10, 2)) + ) + ), + decimal.Decimal(expected), + ) + + @testing.combinations( + ("5.52", "2.4", 2.3), argnames="left, right, expected" + ) + def test_truediv_float(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Float()) + / literal_column(right, type_=Float()) + ) + ), + expected, + ) + + @testing.combinations( + ("5.52", "2.4", "2.0"), argnames="left, right, expected" + ) + def test_floordiv_numeric(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Numeric()) + // literal_column(right, type_=Numeric()) + ) + ), + decimal.Decimal(expected), + ) + + def test_truediv_integer_bound(self, connection): + """test #4926""" + + eq_( + connection.scalar(select(literal(15) / literal(10))), + 1.5, + ) + + def test_floordiv_integer_bound(self, connection): + """test #4926""" + + eq_( + connection.scalar(select(literal(15) // literal(10))), + 1, + ) + + +class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def do_numeric_test(self, metadata, connection): + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = {filter_(x) for x in result} + output = {filter_(x) for x in output} + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + connection.execute(t.delete()) + + # test that this is actually a number! + # note we have tiny scale here as we have tests with very + # small scale Numeric types. PostgreSQL will raise an error + # if you use values outside the available scale. + if type_.asdecimal: + test_value = decimal.Decimal("2.9") + add_value = decimal.Decimal("37.12") + else: + test_value = 2.9 + add_value = 37.12 + + connection.execute(t.insert(), {"x": test_value}) + assert_we_are_a_number = connection.scalar( + select(type_coerce(t.c.x + add_value, type_)) + ) + eq_( + round(assert_we_are_a_number, 3), + round(test_value + add_value, 3), + ) + + return run + + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( + Float(), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + support_whereclause=False, + ) + + @testing.requires.precision_generic_float_type + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( + Float(None, decimal_return_scale=7, asdecimal=True), + [15.7563827, decimal.Decimal("15.7563827")], + [decimal.Decimal("15.7563827")], + check_scale=True, + ) + + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + @testing.requires.infinity_floats + def test_infinity_floats(self, do_numeric_test): + """test for #977, #7283""" + + do_numeric_test( + Float(None), + [float("inf")], + [float("inf")], + ) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] + ) + + @testing.requires.floats_to_four_decimals + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( + Float(asdecimal=True), + [15.756, decimal.Decimal("15.756"), None], + [decimal.Decimal("15.756"), None], + filter_=lambda n: n is not None and round(n, 4) or None, + ) + + def test_float_as_float(self, do_numeric_test): + do_numeric_test( + Float(), + [15.756, decimal.Decimal("15.756")], + [15.756], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + @testing.requires.literal_float_coercion + def test_float_coerce_round_trip(self, connection): + expr = 15.7563 + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + # this does not work in MySQL, see #4036, however we choose not + # to render CAST unconditionally since this is kind of an edge case. + + @testing.requires.implicit_decimal_binds + def test_decimal_coerce_round_trip(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + def test_decimal_coerce_round_trip_w_cast(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(cast(expr, Numeric(10, 4)))) + eq_(val, expr) + + @testing.requires.precision_numerics_general + def test_precision_decimal(self, do_numeric_test): + numbers = { + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + } + + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal(self, do_numeric_test): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + """ + + numbers = { + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + } + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal_large(self, do_numeric_test): + """test exceedingly large decimals.""" + + numbers = { + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + } + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) + + @testing.requires.precision_numerics_many_significant_digits + def test_many_significant_digits(self, do_numeric_test): + numbers = { + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + } + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_retains_significant_digits + def test_numeric_no_decimal(self, do_numeric_test): + numbers = {decimal.Decimal("1.000")} + do_numeric_test( + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True + ) + + @testing.combinations(sqltypes.Float, sqltypes.Double, argnames="cls_") + @testing.requires.float_is_numeric + def test_float_is_not_numeric(self, connection, cls_): + target_type = cls_().dialect_impl(connection.dialect) + numeric_type = sqltypes.Numeric().dialect_impl(connection.dialect) + + ne_(target_type.__visit_name__, numeric_type.__visit_name__) + ne_(target_type.__class__, numeric_type.__class__) + + +class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) + + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) + + def test_round_trip(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": True, "unconstrained_value": False}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (True, False)) + assert isinstance(row[0], bool) + + @testing.requires.nullable_booleans + def test_null(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": None, "unconstrained_value": None}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (None, None)) + + def test_whereclause(self): + # testing "WHERE <column>" renders a compatible expression + boolean_table = self.tables.boolean_table + + with config.db.begin() as conn: + conn.execute( + boolean_table.insert(), + [ + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], + ) + + eq_( + conn.scalar( + select(boolean_table.c.id).where(boolean_table.c.value) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + boolean_table.c.unconstrained_value + ) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where(~boolean_table.c.value) + ), + 2, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + ~boolean_table.c.unconstrained_value + ) + ), + 2, + ) + + +class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("json_type",) + __backend__ = True + + datatype = JSON + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype, nullable=False), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def test_round_trip_data1(self, connection): + self._test_round_trip({"key1": "value1", "key2": "value2"}, connection) + + @testing.combinations( + ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia" + ) + @testing.combinations(100, 1999, 3000, 4000, 5000, 9000, argnames="length") + def test_round_trip_pretty_large_data(self, connection, unicode_, length): + if unicode_: + data = "réve🐍illé" * ((length // 9) + 1) + data = data[0 : (length // 2)] + else: + data = "abcdefg" * ((length // 7) + 1) + data = data[0:length] + + self._test_round_trip({"key1": data, "key2": data}, connection) + + def _test_round_trip(self, data_element, connection): + data_table = self.tables.data_table + + connection.execute( + data_table.insert(), + {"id": 1, "name": "row1", "data": data_element}, + ) + + row = connection.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + + def _index_fixtures(include_comparison): + if include_comparison: + # basically SQL Server and MariaDB can kind of do json + # comparison, MySQL, PG and SQLite can't. not worth it. + json_elements = [] + else: + json_elements = [ + ("json", {"foo": "bar"}), + ("json", ["one", "two", "three"]), + (None, {"foo": "bar"}), + (None, ["one", "two", "three"]), + ] + + elements = [ + ("boolean", True), + ("boolean", False), + ("boolean", None), + ("string", "some string"), + ("string", None), + ("string", "réve illé"), + ( + "string", + "réve🐍 illé", + testing.requires.json_index_supplementary_unicode_element, + ), + ("integer", 15), + ("integer", 1), + ("integer", 0), + ("integer", None), + ("float", 28.5), + ("float", None), + ("float", 1234567.89, testing.requires.literal_float_coercion), + ("numeric", 1234567.89), + # this one "works" because the float value you see here is + # lost immediately to floating point stuff + ( + "numeric", + 99998969694839.983485848, + ), + ("numeric", 99939.983485848), + ("_decimal", decimal.Decimal("1234567.89")), + ( + "_decimal", + decimal.Decimal("99998969694839.983485848"), + # fails on SQLite and MySQL (non-mariadb) + requirements.cast_precision_numerics_many_significant_digits, + ), + ( + "_decimal", + decimal.Decimal("99939.983485848"), + ), + ] + json_elements + + def decorate(fn): + fn = testing.combinations(id_="sa", *elements)(fn) + + return fn + + return decorate + + def _json_value_insert(self, connection, datatype, value, data_element): + data_table = self.tables.data_table + if datatype == "_decimal": + # Python's builtin json serializer basically doesn't support + # Decimal objects without implicit float conversion period. + # users can otherwise use simplejson which supports + # precision decimals + + # https://bugs.python.org/issue16535 + + # inserting as strings to avoid a new fixture around the + # dialect which would have idiosyncrasies for different + # backends. + + class DecimalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, decimal.Decimal): + return str(o) + return super().default(o) + + json_data = json.dumps(data_element, cls=DecimalEncoder) + + # take the quotes out. yup, there is *literally* no other + # way to get Python's json.dumps() to put all the digits in + # the string + json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data) + + datatype = "numeric" + + connection.execute( + data_table.insert().values( + name="row1", + # to pass the string directly to every backend, including + # PostgreSQL which needs the value to be CAST as JSON + # both in the SQL as well as at the prepared statement + # level for asyncpg, while at the same time MySQL + # doesn't even support CAST for JSON, here we are + # sending the string embedded in the SQL without using + # a parameter. + data=bindparam(None, json_data, literal_execute=True), + nulldata=bindparam(None, json_data, literal_execute=True), + ), + ) + else: + connection.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + p_s = None + + if datatype: + if datatype == "numeric": + a, b = str(value).split(".") + s = len(b) + p = len(a) + s + + if isinstance(value, decimal.Decimal): + compare_value = value + else: + compare_value = decimal.Decimal(str(value)) + + p_s = (p, s) + else: + compare_value = value + else: + compare_value = value + + return datatype, compare_value, p_s + + @_index_fixtures(False) + def test_index_typed_access(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + roundtrip = conn.scalar(select(expr)) + eq_(roundtrip, compare_value) + is_(type(roundtrip), type(compare_value)) + + @_index_fixtures(True) + def test_index_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @_index_fixtures(True) + def test_path_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": {"subkey1": value}} + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data[("key1", "subkey1")] + + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @testing.combinations( + (True,), + (False,), + (None,), + (15,), + (0,), + (-1,), + (-1.0,), + (15.052,), + ("a string",), + ("réve illé",), + ("réve🐍 illé",), + ) + def test_single_element_round_trip(self, element): + data_table = self.tables.data_table + data_element = element + with config.db.begin() as conn: + conn.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + row = conn.execute( + select(data_table.c.data, data_table.c.nulldata) + ).first() + + eq_(row, (data_element, data_element)) + + def test_round_trip_custom_json(self): + data_table = self.tables.data_table + data_element = {"key1": "data1"} + + js = mock.Mock(side_effect=json.dumps) + jd = mock.Mock(side_effect=json.loads) + engine = engines.testing_engine( + options=dict(json_serializer=js, json_deserializer=jd) + ) + + # support sqlite :memory: database... + data_table.create(engine, checkfirst=True) + with engine.begin() as conn: + conn.execute( + data_table.insert(), {"name": "row1", "data": data_element} + ) + row = conn.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + eq_(js.mock_calls, [mock.call(data_element)]) + if testing.requires.json_deserializer_binary.enabled: + eq_( + jd.mock_calls, + [mock.call(json.dumps(data_element).encode())], + ) + else: + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + ("omit",), + argnames="insert_type", + ) + def test_round_trip_none_as_sql_null(self, connection, insert_type): + col = self.tables.data_table.c["nulldata"] + + conn = connection + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "nulldata": None, + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "nulldata": None, "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values( + name="r1", + nulldata=None, + data=None, + ), + {}, + ) + elif insert_type == "omit": + stmt, params = ( + self.tables.data_table.insert(), + {"name": "r1", "data": None}, + ) + + else: + assert False + + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where(col.is_(null())) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_round_trip_json_null_as_json_null(self, connection): + col = self.tables.data_table.c["data"] + + conn = connection + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": JSON.NULL}, + ) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + argnames="insert_type", + ) + def test_round_trip_none_as_json_null(self, connection, insert_type): + col = self.tables.data_table.c["data"] + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values(name="r1", data=None), + {}, + ) + else: + assert False + + conn = connection + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_unicode_round_trip(self): + # note we include Unicode supplementary characters as well + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + { + "name": "r1", + "data": { + "réve🐍 illé": "réve🐍 illé", + "data": {"k1": "drôl🐍e"}, + }, + }, + ) + + eq_( + conn.scalar(select(self.tables.data_table.c.data)), + { + "réve🐍 illé": "réve🐍 illé", + "data": {"k1": "drôl🐍e"}, + }, + ) + + def test_eval_none_flag_orm(self, connection): + Base = declarative_base() + + class Data(Base): + __table__ = self.tables.data_table + + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() + + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +class JSONLegacyStringCastIndexTest( + _LiteralRoundTripFixture, fixtures.TablesTest +): + """test JSON index access with "cast to string", which we have documented + for a long time as how to compare JSON values, but is ultimately not + reliable in all cases. The "as_XYZ()" comparators should be used + instead. + + """ + + __requires__ = ("json_type", "legacy_unconditional_json_extract") + __backend__ = True + + datatype = JSON + + data1 = {"key1": "value1", "key2": "value2"} + + data2 = { + "Key 'One'": "value1", + "key two": "value2", + "key three": "value ' three '", + } + + data3 = { + "key1": [1, 2, 3], + "key2": ["one", "two", "three"], + "key3": [{"four": "five"}, {"six": "seven"}], + } + + data4 = ["one", "two", "three"] + + data5 = { + "nested": { + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, + } + } + + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def _criteria_fixture(self): + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], + ) + + def _test_index_criteria(self, crit, expected, test_literal=True): + self._criteria_fixture() + with config.db.connect() as conn: + stmt = select(self.tables.data_table.c.name).where(crit) + + eq_(conn.scalar(stmt), expected) + + if test_literal: + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) + + eq_(conn.exec_driver_sql(literal_sql).scalar(), expected) + + def test_string_cast_crit_spaces_in_key(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract field from a non-object", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name.in_(["r1", "r2", "r3"]), + cast(col["key two"], String) == '"value2"', + ), + "r2", + ) + + @config.requirements.json_array_indexes + def test_string_cast_crit_simple_int(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract array element from a non-array", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name == "r4", + cast(col[1], String) == '"two"', + ), + "r4", + ) + + def test_string_cast_crit_mixed_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("key3", 1, "six")], String) == '"seven"', + "r3", + ) + + def test_string_cast_crit_string_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("nested", "elem2", "elem3", "elem4")], String) + == '"elem5"', + "r5", + ) + + def test_string_cast_crit_against_string_basic(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + self._test_index_criteria( + and_( + name == "r6", + cast(col["b"], String) == '"some value"', + ), + "r6", + ) + + +class UuidTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + datatype = Uuid + + @classmethod + def define_tables(cls, metadata): + Table( + "uuid_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("uuid_data", cls.datatype), + Column("uuid_text_data", cls.datatype(as_uuid=False)), + Column("uuid_data_nonnative", Uuid(native_uuid=False)), + Column( + "uuid_text_data_nonnative", + Uuid(as_uuid=False, native_uuid=False), + ), + ) + + def test_uuid_round_trip(self, connection): + data = uuid.uuid4() + uuid_table = self.tables.uuid_table + + connection.execute( + uuid_table.insert(), + {"id": 1, "uuid_data": data, "uuid_data_nonnative": data}, + ) + row = connection.execute( + select( + uuid_table.c.uuid_data, uuid_table.c.uuid_data_nonnative + ).where( + uuid_table.c.uuid_data == data, + uuid_table.c.uuid_data_nonnative == data, + ) + ).first() + eq_(row, (data, data)) + + def test_uuid_text_round_trip(self, connection): + data = str(uuid.uuid4()) + uuid_table = self.tables.uuid_table + + connection.execute( + uuid_table.insert(), + { + "id": 1, + "uuid_text_data": data, + "uuid_text_data_nonnative": data, + }, + ) + row = connection.execute( + select( + uuid_table.c.uuid_text_data, + uuid_table.c.uuid_text_data_nonnative, + ).where( + uuid_table.c.uuid_text_data == data, + uuid_table.c.uuid_text_data_nonnative == data, + ) + ).first() + eq_((row[0].lower(), row[1].lower()), (data, data)) + + def test_literal_uuid(self, literal_round_trip): + data = uuid.uuid4() + literal_round_trip(self.datatype, [data], [data]) + + def test_literal_text(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip( + self.datatype(as_uuid=False), + [data], + [data], + filter_=lambda x: x.lower(), + ) + + def test_literal_nonnative_uuid(self, literal_round_trip): + data = uuid.uuid4() + literal_round_trip(Uuid(native_uuid=False), [data], [data]) + + def test_literal_nonnative_text(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip( + Uuid(as_uuid=False, native_uuid=False), + [data], + [data], + filter_=lambda x: x.lower(), + ) + + @testing.requires.insert_returning + def test_uuid_returning(self, connection): + data = uuid.uuid4() + str_data = str(data) + uuid_table = self.tables.uuid_table + + result = connection.execute( + uuid_table.insert().returning( + uuid_table.c.uuid_data, + uuid_table.c.uuid_text_data, + uuid_table.c.uuid_data_nonnative, + uuid_table.c.uuid_text_data_nonnative, + ), + { + "id": 1, + "uuid_data": data, + "uuid_text_data": str_data, + "uuid_data_nonnative": data, + "uuid_text_data_nonnative": str_data, + }, + ) + row = result.first() + + eq_(row, (data, str_data, data, str_data)) + + +class NativeUUIDTest(UuidTest): + __requires__ = ("uuid_data_type",) + + datatype = UUID + + +__all__ = ( + "ArrayTest", + "BinaryTest", + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "JSONLegacyStringCastIndexTest", + "DateTest", + "DateTimeTest", + "DateTimeTZTest", + "TextTest", + "NumericTest", + "IntegerTest", + "IntervalTest", + "PrecisionIntervalTest", + "CastTypeDecoratorTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "TimeTZTest", + "TrueDivTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", + "UuidTest", + "NativeUUIDTest", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py new file mode 100644 index 0000000..1f15ab5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py @@ -0,0 +1,189 @@ +# testing/suite/test_unicode_ddl.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from sqlalchemy import desc +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import testing +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + + +class UnicodeSchemaTest(fixtures.TablesTest): + __requires__ = ("unicode_ddl",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + global t1, t2, t3 + + t1 = Table( + "unitable1", + metadata, + Column("méil", Integer, primary_key=True), + Column("\u6e2c\u8a66", Integer), + test_needs_fk=True, + ) + t2 = Table( + "Unitéble2", + metadata, + Column("méil", Integer, primary_key=True, key="a"), + Column( + "\u6e2c\u8a66", + Integer, + ForeignKey("unitable1.méil"), + key="b", + ), + test_needs_fk=True, + ) + + # Few DBs support Unicode foreign keys + if testing.against("sqlite"): + t3 = Table( + "\u6e2c\u8a66", + metadata, + Column( + "\u6e2c\u8a66_id", + Integer, + primary_key=True, + autoincrement=False, + ), + Column( + "unitable1_\u6e2c\u8a66", + Integer, + ForeignKey("unitable1.\u6e2c\u8a66"), + ), + Column("Unitéble2_b", Integer, ForeignKey("Unitéble2.b")), + Column( + "\u6e2c\u8a66_self", + Integer, + ForeignKey("\u6e2c\u8a66.\u6e2c\u8a66_id"), + ), + test_needs_fk=True, + ) + else: + t3 = Table( + "\u6e2c\u8a66", + metadata, + Column( + "\u6e2c\u8a66_id", + Integer, + primary_key=True, + autoincrement=False, + ), + Column("unitable1_\u6e2c\u8a66", Integer), + Column("Unitéble2_b", Integer), + Column("\u6e2c\u8a66_self", Integer), + test_needs_fk=True, + ) + + def test_insert(self, connection): + connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5}) + connection.execute(t2.insert(), {"a": 1, "b": 1}) + connection.execute( + t3.insert(), + { + "\u6e2c\u8a66_id": 1, + "unitable1_\u6e2c\u8a66": 5, + "Unitéble2_b": 1, + "\u6e2c\u8a66_self": 1, + }, + ) + + eq_(connection.execute(t1.select()).fetchall(), [(1, 5)]) + eq_(connection.execute(t2.select()).fetchall(), [(1, 1)]) + eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)]) + + def test_col_targeting(self, connection): + connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5}) + connection.execute(t2.insert(), {"a": 1, "b": 1}) + connection.execute( + t3.insert(), + { + "\u6e2c\u8a66_id": 1, + "unitable1_\u6e2c\u8a66": 5, + "Unitéble2_b": 1, + "\u6e2c\u8a66_self": 1, + }, + ) + + row = connection.execute(t1.select()).first() + eq_(row._mapping[t1.c["méil"]], 1) + eq_(row._mapping[t1.c["\u6e2c\u8a66"]], 5) + + row = connection.execute(t2.select()).first() + eq_(row._mapping[t2.c["a"]], 1) + eq_(row._mapping[t2.c["b"]], 1) + + row = connection.execute(t3.select()).first() + eq_(row._mapping[t3.c["\u6e2c\u8a66_id"]], 1) + eq_(row._mapping[t3.c["unitable1_\u6e2c\u8a66"]], 5) + eq_(row._mapping[t3.c["Unitéble2_b"]], 1) + eq_(row._mapping[t3.c["\u6e2c\u8a66_self"]], 1) + + def test_reflect(self, connection): + connection.execute(t1.insert(), {"méil": 2, "\u6e2c\u8a66": 7}) + connection.execute(t2.insert(), {"a": 2, "b": 2}) + connection.execute( + t3.insert(), + { + "\u6e2c\u8a66_id": 2, + "unitable1_\u6e2c\u8a66": 7, + "Unitéble2_b": 2, + "\u6e2c\u8a66_self": 2, + }, + ) + + meta = MetaData() + tt1 = Table(t1.name, meta, autoload_with=connection) + tt2 = Table(t2.name, meta, autoload_with=connection) + tt3 = Table(t3.name, meta, autoload_with=connection) + + connection.execute(tt1.insert(), {"méil": 1, "\u6e2c\u8a66": 5}) + connection.execute(tt2.insert(), {"méil": 1, "\u6e2c\u8a66": 1}) + connection.execute( + tt3.insert(), + { + "\u6e2c\u8a66_id": 1, + "unitable1_\u6e2c\u8a66": 5, + "Unitéble2_b": 1, + "\u6e2c\u8a66_self": 1, + }, + ) + + eq_( + connection.execute(tt1.select().order_by(desc("méil"))).fetchall(), + [(2, 7), (1, 5)], + ) + eq_( + connection.execute(tt2.select().order_by(desc("méil"))).fetchall(), + [(2, 2), (1, 1)], + ) + eq_( + connection.execute( + tt3.select().order_by(desc("\u6e2c\u8a66_id")) + ).fetchall(), + [(2, 7, 2, 2), (1, 5, 1, 1)], + ) + + def test_repr(self): + meta = MetaData() + t = Table("\u6e2c\u8a66", meta, Column("\u6e2c\u8a66_id", Integer)) + eq_( + repr(t), + ( + "Table('測試', MetaData(), " + "Column('測試_id', Integer(), " + "table=<測試>), " + "schema=None)" + ), + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py new file mode 100644 index 0000000..fd4757f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py @@ -0,0 +1,139 @@ +# testing/suite/test_update_delete.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import String +from ... import testing + + +class SimpleUpdateDeleteTest(fixtures.TablesTest): + run_deletes = "each" + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute( + t.update().where(t.c.id == 2), dict(data="d2_new") + ) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + @testing.variation("criteria", ["rows", "norows", "emptyin"]) + @testing.requires.update_returning + def test_update_returning(self, connection, criteria): + t = self.tables.plain_pk + + stmt = t.update().returning(t.c.id, t.c.data) + + if criteria.norows: + stmt = stmt.where(t.c.id == 10) + elif criteria.rows: + stmt = stmt.where(t.c.id == 2) + elif criteria.emptyin: + stmt = stmt.where(t.c.id.in_([])) + else: + criteria.fail() + + r = connection.execute(stmt, dict(data="d2_new")) + assert not r.is_insert + assert r.returns_rows + eq_(r.keys(), ["id", "data"]) + + if criteria.rows: + eq_(r.all(), [(2, "d2_new")]) + else: + eq_(r.all(), []) + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + ( + [(1, "d1"), (2, "d2_new"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), + ) + + @testing.variation("criteria", ["rows", "norows", "emptyin"]) + @testing.requires.delete_returning + def test_delete_returning(self, connection, criteria): + t = self.tables.plain_pk + + stmt = t.delete().returning(t.c.id, t.c.data) + + if criteria.norows: + stmt = stmt.where(t.c.id == 10) + elif criteria.rows: + stmt = stmt.where(t.c.id == 2) + elif criteria.emptyin: + stmt = stmt.where(t.c.id.in_([])) + else: + criteria.fail() + + r = connection.execute(stmt) + assert not r.is_insert + assert r.returns_rows + eq_(r.keys(), ["id", "data"]) + + if criteria.rows: + eq_(r.all(), [(2, "d2")]) + else: + eq_(r.all(), []) + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + ( + [(1, "d1"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), + ) + + +__all__ = ("SimpleUpdateDeleteTest",) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/util.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/util.py new file mode 100644 index 0000000..a6ce6ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/util.py @@ -0,0 +1,519 @@ +# testing/util.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +from collections import deque +import decimal +import gc +from itertools import chain +import random +import sys +from sys import getsizeof +import types + +from . import config +from . import mock +from .. import inspect +from ..engine import Connection +from ..schema import Column +from ..schema import DropConstraint +from ..schema import DropTable +from ..schema import ForeignKeyConstraint +from ..schema import MetaData +from ..schema import Table +from ..sql import schema +from ..sql.sqltypes import Integer +from ..util import decorator +from ..util import defaultdict +from ..util import has_refcount_gc +from ..util import inspect_getfullargspec + + +if not has_refcount_gc: + + def non_refcount_gc_collect(*args): + gc.collect() + gc.collect() + + gc_collect = lazy_gc = non_refcount_gc_collect +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + + def lazy_gc(): + pass + + +def picklers(): + picklers = set() + import pickle + + picklers.add(pickle) + + # yes, this thing needs this much testing + for pickle_ in picklers: + for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1): + yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) + + +def random_choices(population, k=1): + return random.choices(population, k=k) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec)).to_integral( + decimal.ROUND_FLOOR + ) / pow(10, prec) + + +class RandomSet(set): + def __iter__(self): + l = list(set.__iter__(self)) + random.shuffle(l) + return iter(l) + + def pop(self): + index = random.randint(0, len(self) - 1) + item = list(set.__iter__(self))[index] + self.remove(item) + return item + + def union(self, other): + return RandomSet(set.union(self, other)) + + def difference(self, other): + return RandomSet(set.difference(self, other)) + + def intersection(self, other): + return RandomSet(set.intersection(self, other)) + + def copy(self): + return RandomSet(self) + + +def conforms_partial_ordering(tuples, sorted_elements): + """True if the given sorting conforms to the given partial ordering.""" + + deps = defaultdict(set) + for parent, child in tuples: + deps[parent].add(child) + for i, node in enumerate(sorted_elements): + for n in sorted_elements[i:]: + if node in deps[n]: + return False + else: + return True + + +def all_partial_orderings(tuples, elements): + edges = defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + def _all_orderings(elements): + if len(elements) == 1: + yield list(elements) + else: + for elem in elements: + subset = set(elements).difference([elem]) + if not subset.intersection(edges[elem]): + for sub_ordering in _all_orderings(subset): + yield [elem] + sub_ordering + + return iter(_all_orderings(elements)) + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + This function should be phased out as much as possible + in favor of @decorator. Tests that "generate" many named tests + should be modernized. + + """ + try: + fn.__name__ = name + except TypeError: + fn = types.FunctionType( + fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__ + ) + return fn + + +def run_as_contextmanager(ctx, fn, *arg, **kw): + """Run the given function under the given contextmanager, + simulating the behavior of 'with' to support older + Python versions. + + This is not necessary anymore as we have placed 2.6 + as minimum Python version, however some tests are still using + this structure. + + """ + + obj = ctx.__enter__() + try: + result = fn(obj, *arg, **kw) + ctx.__exit__(None, None, None) + return result + except: + exc_info = sys.exc_info() + raise_ = ctx.__exit__(*exc_info) + if not raise_: + raise + else: + return raise_ + + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return {tuple(row) for row in results} + + +def fail(msg): + assert False, msg + + +@decorator +def provide_metadata(fn, *args, **kw): + """Provide bound MetaData for a single test, dropping afterwards. + + Legacy; use the "metadata" pytest fixture. + + """ + + from . import fixtures + + metadata = schema.MetaData() + self = args[0] + prev_meta = getattr(self, "metadata", None) + self.metadata = metadata + try: + return fn(*args, **kw) + finally: + # close out some things that get in the way of dropping tables. + # when using the "metadata" fixture, there is a set ordering + # of things that makes sure things are cleaned up in order, however + # the simple "decorator" nature of this legacy function means + # we have to hardcode some of that cleanup ahead of time. + + # close ORM sessions + fixtures.close_all_sessions() + + # integrate with the "connection" fixture as there are many + # tests where it is used along with provide_metadata + cfc = fixtures.base._connection_fixture_connection + if cfc: + # TODO: this warning can be used to find all the places + # this is used with connection fixture + # warn("mixing legacy provide metadata with connection fixture") + drop_all_tables_from_metadata(metadata, cfc) + # as the provide_metadata fixture is often used with "testing.db", + # when we do the drop we have to commit the transaction so that + # the DB is actually updated as the CREATE would have been + # committed + cfc.get_transaction().commit() + else: + drop_all_tables_from_metadata(metadata, config.db) + self.metadata = prev_meta + + +def flag_combinations(*combinations): + """A facade around @testing.combinations() oriented towards boolean + keyword-based arguments. + + Basically generates a nice looking identifier based on the keywords + and also sets up the argument names. + + E.g.:: + + @testing.flag_combinations( + dict(lazy=False, passive=False), + dict(lazy=True, passive=False), + dict(lazy=False, passive=True), + dict(lazy=False, passive=True, raiseload=True), + ) + + + would result in:: + + @testing.combinations( + ('', False, False, False), + ('lazy', True, False, False), + ('lazy_passive', True, True, False), + ('lazy_passive', True, True, True), + id_='iaaa', + argnames='lazy,passive,raiseload' + ) + + """ + + keys = set() + + for d in combinations: + keys.update(d) + + keys = sorted(keys) + + return config.combinations( + *[ + ("_".join(k for k in keys if d.get(k, False)),) + + tuple(d.get(k, False) for k in keys) + for d in combinations + ], + id_="i" + ("a" * len(keys)), + argnames=",".join(keys), + ) + + +def lambda_combinations(lambda_arg_sets, **kw): + args = inspect_getfullargspec(lambda_arg_sets) + + arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]]) + + def create_fixture(pos): + def fixture(**kw): + return lambda_arg_sets(**kw)[pos] + + fixture.__name__ = "fixture_%3.3d" % pos + return fixture + + return config.combinations( + *[(create_fixture(i),) for i in range(len(arg_sets))], **kw + ) + + +def resolve_lambda(__fn, **kw): + """Given a no-arg lambda and a namespace, return a new lambda that + has all the values filled in. + + This is used so that we can have module-level fixtures that + refer to instance-level variables using lambdas. + + """ + + pos_args = inspect_getfullargspec(__fn)[0] + pass_pos_args = {arg: kw.pop(arg) for arg in pos_args} + glb = dict(__fn.__globals__) + glb.update(kw) + new_fn = types.FunctionType(__fn.__code__, glb) + return new_fn(**pass_pos_args) + + +def metadata_fixture(ddl="function"): + """Provide MetaData for a pytest fixture.""" + + def decorate(fn): + def run_ddl(self): + metadata = self.metadata = schema.MetaData() + try: + result = fn(self, metadata) + metadata.create_all(config.db) + # TODO: + # somehow get a per-function dml erase fixture here + yield result + finally: + metadata.drop_all(config.db) + + return config.fixture(scope=ddl)(run_ddl) + + return decorate + + +def force_drop_names(*names): + """Force the given table names to be dropped after test complete, + isolating for foreign key cycles + + """ + + @decorator + def go(fn, *args, **kw): + try: + return fn(*args, **kw) + finally: + drop_all_tables(config.db, inspect(config.db), include_names=names) + + return go + + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def __call__(self, *keys): + return tuple([self[key] for key in keys]) + + get_all = __call__ + + +def drop_all_tables_from_metadata(metadata, engine_or_connection): + from . import engines + + def go(connection): + engines.testing_reaper.prepare_for_drop_tables(connection) + + if not connection.dialect.supports_alter: + from . import assertions + + with assertions.expect_warnings( + "Can't sort tables", assert_=False + ): + metadata.drop_all(connection) + else: + metadata.drop_all(connection) + + if not isinstance(engine_or_connection, Connection): + with engine_or_connection.begin() as connection: + go(connection) + else: + go(engine_or_connection) + + +def drop_all_tables( + engine, + inspector, + schema=None, + consider_schemas=(None,), + include_names=None, +): + if include_names is not None: + include_names = set(include_names) + + if schema is not None: + assert consider_schemas == ( + None, + ), "consider_schemas and schema are mutually exclusive" + consider_schemas = (schema,) + + with engine.begin() as conn: + for table_key, fkcs in reversed( + inspector.sort_tables_on_foreign_key_dependency( + consider_schemas=consider_schemas + ) + ): + if table_key: + if ( + include_names is not None + and table_key[1] not in include_names + ): + continue + conn.execute( + DropTable( + Table(table_key[1], MetaData(), schema=table_key[0]) + ) + ) + elif fkcs: + if not engine.dialect.supports_alter: + continue + for t_key, fkc in fkcs: + if ( + include_names is not None + and t_key[1] not in include_names + ): + continue + tb = Table( + t_key[1], + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=t_key[0], + ) + conn.execute( + DropConstraint( + ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc) + ) + ) + + +def teardown_events(event_cls): + @decorator + def decorate(fn, *arg, **kw): + try: + return fn(*arg, **kw) + finally: + event_cls._clear() + + return decorate + + +def total_size(o): + """Returns the approximate memory footprint an object and all of its + contents. + + source: https://code.activestate.com/recipes/577504/ + + + """ + + def dict_handler(d): + return chain.from_iterable(d.items()) + + all_handlers = { + tuple: iter, + list: iter, + deque: iter, + dict: dict_handler, + set: iter, + frozenset: iter, + } + seen = set() # track which object id's have already been seen + default_size = getsizeof(0) # estimate sizeof object without __sizeof__ + + def sizeof(o): + if id(o) in seen: # do not double count the same object + return 0 + seen.add(id(o)) + s = getsizeof(o, default_size) + + for typ, handler in all_handlers.items(): + if isinstance(o, typ): + s += sum(map(sizeof, handler(o))) + break + return s + + return sizeof(o) + + +def count_cache_key_tuples(tup): + """given a cache key tuple, counts how many instances of actual + tuples are found. + + used to alert large jumps in cache key complexity. + + """ + stack = [tup] + + sentinel = object() + num_elements = 0 + + while stack: + elem = stack.pop(0) + if elem is sentinel: + num_elements += 1 + elif isinstance(elem, tuple): + if elem: + stack = list(elem) + [sentinel] + stack + return num_elements diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/warnings.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/warnings.py new file mode 100644 index 0000000..baef037 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/warnings.py @@ -0,0 +1,52 @@ +# testing/warnings.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import warnings + +from . import assertions +from .. import exc +from .. import exc as sa_exc +from ..exc import SATestSuiteWarning +from ..util.langhelpers import _warnings_warn + + +def warn_test_suite(message): + _warnings_warn(message, category=SATestSuiteWarning) + + +def setup_filters(): + """hook for setting up warnings filters. + + SQLAlchemy-specific classes must only be here and not in pytest config, + as we need to delay importing SQLAlchemy until conftest.py has been + processed. + + NOTE: filters on subclasses of DeprecationWarning or + PendingDeprecationWarning have no effect if added here, since pytest + will add at each test the following filters + ``always::PendingDeprecationWarning`` and ``always::DeprecationWarning`` + that will take precedence over any added here. + + """ + warnings.filterwarnings("error", category=exc.SAWarning) + warnings.filterwarnings("always", category=exc.SATestSuiteWarning) + + +def assert_warnings(fn, warning_msgs, regex=False): + """Assert that each of the given warnings are emitted by fn. + + Deprecated. Please use assertions.expect_warnings(). + + """ + + with assertions._expect_warnings( + sa_exc.SAWarning, warning_msgs, regex=regex + ): + return fn() |