diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures')
10 files changed, 1426 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py new file mode 100644 index 0000000..5981fb5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py @@ -0,0 +1,28 @@ +# testing/fixtures/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from .base import FutureEngineMixin as FutureEngineMixin +from .base import TestBase as TestBase +from .mypy import MypyTest as MypyTest +from .orm import after_test as after_test +from .orm import close_all_sessions as close_all_sessions +from .orm import DeclarativeMappedTest as DeclarativeMappedTest +from .orm import fixture_session as fixture_session +from .orm import MappedTest as MappedTest +from .orm import ORMTest as ORMTest +from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally +from .orm import ( + stop_test_class_inside_fixtures as stop_test_class_inside_fixtures, +) +from .sql import CacheKeyFixture as CacheKeyFixture +from .sql import ( + ComputedReflectionFixtureTest as ComputedReflectionFixtureTest, +) +from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture +from .sql import NoCache as NoCache +from .sql import RemovesEvents as RemovesEvents +from .sql import TablesTest as TablesTest diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0f95b6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ccff61f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..278a5f6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..71d2d9a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bc870cb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py new file mode 100644 index 0000000..0697f49 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py @@ -0,0 +1,366 @@ +# testing/fixtures/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import sqlalchemy as sa +from .. import assertions +from .. import config +from ..assertions import eq_ +from ..util import drop_all_tables_from_metadata +from ... import Column +from ... import func +from ... import Integer +from ... import select +from ... import Table +from ...orm import DeclarativeBase +from ...orm import MappedAsDataclass +from ...orm import registry + + +@config.mark_base_test_class() +class TestBase: + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + + def assert_(self, val, msg=None): + assert val, msg + + @config.fixture() + def nocache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + @config.fixture() + def connection_no_trans(self): + eng = getattr(self, "bind", None) or config.db + + with eng.connect() as conn: + yield conn + + @config.fixture() + def connection(self): + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db + + conn = eng.connect() + trans = conn.begin() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() + + @config.fixture() + def close_result_when_finished(self): + to_close = [] + to_consume = [] + + def go(result, consume=False): + to_close.append(result) + if consume: + to_consume.append(result) + + yield go + for r in to_consume: + try: + r.all() + except: + pass + for r in to_close: + try: + r.close() + except: + pass + + @config.fixture() + def registry(self, metadata): + reg = registry( + metadata=metadata, + type_annotation_map={ + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + }, + ) + yield reg + reg.dispose() + + @config.fixture + def decl_base(self, metadata): + _md = metadata + + class Base(DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } + + yield Base + Base.registry.dispose() + + @config.fixture + def dc_decl_base(self, metadata): + _md = metadata + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + + @config.fixture() + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection + + @config.fixture() + def future_engine(self): + yield + + @config.fixture() + def testing_engine(self): + from .. import engines + + def gen_testing_engine( + url=None, + options=None, + future=None, + asyncio=False, + transfer_staticpool=False, + share_pool=False, + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, + options=options, + asyncio=asyncio, + transfer_staticpool=transfer_staticpool, + share_pool=share_pool, + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") + + @config.fixture() + def async_testing_engine(self, testing_engine): + def go(**kw): + kw["asyncio"] = True + return testing_engine(**kw) + + return go + + @config.fixture() + def metadata(self, request): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from ...sql import schema + + metadata = schema.MetaData() + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) + + @config.fixture( + params=[ + (rollback, second_operation, begin_nested) + for rollback in (True, False) + for second_operation in ("none", "execute", "begin") + for begin_nested in ( + True, + False, + ) + ] + ) + def trans_ctx_manager_fixture(self, request, metadata): + rollback, second_operation, begin_nested = request.param + + t = Table("test", metadata, Column("data", Integer)) + eng = getattr(self, "bind", None) or config.db + + t.create(eng) + + def run_test(subject, trans_on_subject, execute_on_subject): + with subject.begin() as trans: + if begin_nested: + if not config.requirements.savepoints.enabled: + config.skip_test("savepoints not enabled") + if execute_on_subject: + nested_trans = subject.begin_nested() + else: + nested_trans = trans.begin_nested() + + with nested_trans: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + # for nested trans, we always commit/rollback on the + # "nested trans" object itself. + # only Session(future=False) will affect savepoint + # transaction for session.commit/rollback + + if rollback: + nested_trans.rollback() + else: + nested_trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction " + "inside context " + "manager. Please complete the context " + "manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute( + t.insert(), {"data": 12} + ) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + # outside the nested trans block, but still inside the + # transaction block, we can run SQL, and it will be + # committed + if execute_on_subject: + subject.execute(t.insert(), {"data": 14}) + else: + trans.execute(t.insert(), {"data": 14}) + + else: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + if trans_on_subject: + if rollback: + subject.rollback() + else: + subject.commit() + else: + if rollback: + trans.rollback() + else: + trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction inside " + "context " + "manager. Please complete the context manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute(t.insert(), {"data": 12}) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if hasattr(trans, "begin"): + trans.begin() + else: + subject.begin() + elif second_operation == "begin_nested": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + expected_committed = 0 + if begin_nested: + # begin_nested variant, we inserted a row after the nested + # block + expected_committed += 1 + if not rollback: + # not rollback variant, our row inserted in the target + # block itself would be committed + expected_committed += 1 + + if execute_on_subject: + eq_( + subject.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + else: + with subject.connect() as conn: + eq_( + conn.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + + return run_test + + +_connection_fixture_connection = None + + +class FutureEngineMixin: + """alembic's suite still using this""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py new file mode 100644 index 0000000..149df9f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py @@ -0,0 +1,312 @@ +# testing/fixtures/mypy.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import inspect +import os +from pathlib import Path +import re +import shutil +import sys +import tempfile + +from .base import TestBase +from .. import config +from ..assertions import eq_ +from ... import util + + +@config.add_to_marker.mypy +class MypyTest(TestBase): + __requires__ = ("no_sqlalchemy2_stubs",) + + @config.fixture(scope="function") + def per_func_cachedir(self): + yield from self._cachedir() + + @config.fixture(scope="class") + def cachedir(self): + yield from self._cachedir() + + def _cachedir(self): + # as of mypy 0.971 i think we need to keep mypy_path empty + mypy_path = "" + + with tempfile.TemporaryDirectory() as cachedir: + with open( + Path(cachedir) / "sqla_mypy_config.cfg", "w" + ) as config_file: + config_file.write( + f""" + [mypy]\n + plugins = sqlalchemy.ext.mypy.plugin\n + show_error_codes = True\n + {mypy_path} + disable_error_code = no-untyped-call + + [mypy-sqlalchemy.*] + ignore_errors = True + + """ + ) + with open( + Path(cachedir) / "plain_mypy_config.cfg", "w" + ) as config_file: + config_file.write( + f""" + [mypy]\n + show_error_codes = True\n + {mypy_path} + disable_error_code = var-annotated,no-untyped-call + [mypy-sqlalchemy.*] + ignore_errors = True + + """ + ) + yield cachedir + + @config.fixture() + def mypy_runner(self, cachedir): + from mypy import api + + def run(path, use_plugin=False, use_cachedir=None): + if use_cachedir is None: + use_cachedir = cachedir + args = [ + "--strict", + "--raise-exceptions", + "--cache-dir", + use_cachedir, + "--config-file", + os.path.join( + use_cachedir, + ( + "sqla_mypy_config.cfg" + if use_plugin + else "plain_mypy_config.cfg" + ), + ), + ] + + # mypy as of 0.990 is more aggressively blocking messaging + # for paths that are in sys.path, and as pytest puts currdir, + # test/ etc in sys.path, just copy the source file to the + # tempdir we are working in so that we don't have to try to + # manipulate sys.path and/or guess what mypy is doing + filename = os.path.basename(path) + test_program = os.path.join(use_cachedir, filename) + if path != test_program: + shutil.copyfile(path, test_program) + args.append(test_program) + + # I set this locally but for the suite here needs to be + # disabled + os.environ.pop("MYPY_FORCE_COLOR", None) + + stdout, stderr, exitcode = api.run(args) + return stdout, stderr, exitcode + + return run + + @config.fixture + def mypy_typecheck_file(self, mypy_runner): + def run(path, use_plugin=False): + expected_messages = self._collect_messages(path) + stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin) + self._check_output( + path, expected_messages, stdout, stderr, exitcode + ) + + return run + + @staticmethod + def file_combinations(dirname): + if os.path.isabs(dirname): + path = dirname + else: + caller_path = inspect.stack()[1].filename + path = os.path.join(os.path.dirname(caller_path), dirname) + files = list(Path(path).glob("**/*.py")) + + for extra_dir in config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + files.extend((Path(extra_dir) / dirname).glob("**/*.py")) + return files + + def _collect_messages(self, path): + from sqlalchemy.ext.mypy.util import mypy_14 + + expected_messages = [] + expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") + py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") + with open(path) as file_: + current_assert_messages = [] + for num, line in enumerate(file_, 1): + m = py_ver_re.match(line) + if m: + major, _, minor = m.group(1).partition(".") + if sys.version_info < (int(major), int(minor)): + config.skip_test( + "Requires python >= %s" % (m.group(1)) + ) + continue + + m = expected_re.match(line) + if m: + is_mypy = bool(m.group(1)) + is_re = bool(m.group(2)) + is_type = bool(m.group(3)) + + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) + if is_type: + if not is_re: + # the goal here is that we can cut-and-paste + # from vscode -> pylance into the + # EXPECTED_TYPE: line, then the test suite will + # validate that line against what mypy produces + expected_msg = re.sub( + r"([\[\]])", + lambda m: rf"\{m.group(0)}", + expected_msg, + ) + + # note making sure preceding text matches + # with a dot, so that an expect for "Select" + # does not match "TypedSelect" + expected_msg = re.sub( + r"([\w_]+)", + lambda m: rf"(?:.*\.)?{m.group(1)}\*?", + expected_msg, + ) + + expected_msg = re.sub( + "List", "builtins.list", expected_msg + ) + + expected_msg = re.sub( + r"\b(int|str|float|bool)\b", + lambda m: rf"builtins.{m.group(0)}\*?", + expected_msg, + ) + # expected_msg = re.sub( + # r"(Sequence|Tuple|List|Union)", + # lambda m: fr"typing.{m.group(0)}\*?", + # expected_msg, + # ) + + is_mypy = is_re = True + expected_msg = f'Revealed type is "{expected_msg}"' + + if mypy_14 and util.py39: + # use_lowercase_names, py39 and above + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501 + + # skip first character which could be capitalized + # "List item x not found" type of message + expected_msg = expected_msg[0] + re.sub( + ( + r"\b(List|Tuple|Dict|Set)\b" + if is_type + else r"\b(List|Tuple|Dict|Set|Type)\b" + ), + lambda m: m.group(1).lower(), + expected_msg[1:], + ) + + if mypy_14 and util.py310: + # use_or_syntax, py310 and above + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501 + expected_msg = re.sub( + r"Optional\[(.*?)\]", + lambda m: f"{m.group(1)} | None", + expected_msg, + ) + current_assert_messages.append( + (is_mypy, is_re, expected_msg.strip()) + ) + elif current_assert_messages: + expected_messages.extend( + (num, is_mypy, is_re, expected_msg) + for ( + is_mypy, + is_re, + expected_msg, + ) in current_assert_messages + ) + current_assert_messages[:] = [] + + return expected_messages + + def _check_output(self, path, expected_messages, stdout, stderr, exitcode): + not_located = [] + filename = os.path.basename(path) + if expected_messages: + # mypy 0.990 changed how return codes work, so don't assume a + # 1 or a 0 return code here, could be either depending on if + # errors were generated or not + + output = [] + + raw_lines = stdout.split("\n") + while raw_lines: + e = raw_lines.pop(0) + if re.match(r".+\.py:\d+: error: .*", e): + output.append(("error", e)) + elif re.match( + r".+\.py:\d+: note: +(?:Possible overload|def ).*", e + ): + while raw_lines: + ol = raw_lines.pop(0) + if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + break + elif re.match( + r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I + ): + pass + elif re.match(r".+\.py:\d+: note: .*", e): + output.append(("note", e)) + + for num, is_mypy, is_re, msg in expected_messages: + msg = msg.replace("'", '"') + prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" + for idx, (typ, errmsg) in enumerate(output): + if is_re: + if re.match( + rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}", + errmsg, + ): + break + elif ( + f"{filename}:{num}: {typ}: {prefix}{msg}" + in errmsg.replace("'", '"') + ): + break + else: + not_located.append(msg) + continue + del output[idx] + + if not_located: + missing = "\n".join(not_located) + print("Couldn't locate expected messages:", missing, sep="\n") + if output: + extra = "\n".join(msg for _, msg in output) + print("Remaining messages:", extra, sep="\n") + assert False, "expected messages not found, see stdout" + + if output: + print(f"{len(output)} messages from mypy were not consumed:") + print("\n".join(msg for _, msg in output)) + assert False, "errors and/or notes remain, see stdout" + + else: + if exitcode != 0: + print(stdout, stderr, sep="\n") + + eq_(exitcode, 0, msg=stdout) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py new file mode 100644 index 0000000..5ddd21e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py @@ -0,0 +1,227 @@ +# testing/fixtures/orm.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any + +import sqlalchemy as sa +from .base import TestBase +from .sql import TablesTest +from .. import assertions +from .. import config +from .. import schema +from ..entities import BasicEntity +from ..entities import ComparableEntity +from ..util import adict +from ... import orm +from ...orm import DeclarativeBase +from ...orm import events as orm_events +from ...orm import registry + + +class ORMTest(TestBase): + @config.fixture + def fixture_session(self): + return fixture_session() + + +class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = "once" + + # 'once', 'each', None + run_setup_mappers = "each" + + classes: Any = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + yield + + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_classes() + self._setup_each_mappers() + self._setup_each_inserts() + + yield + + orm.session.close_all_sessions() + self._teardown_each_mappers() + self._teardown_each_classes() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == "once": + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == "once": + cls.mapper_registry, cls.mapper = cls._generate_registry() + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers != "once": + ( + self.__class__.mapper_registry, + self.__class__.mapper, + ) = self._generate_registry() + + if self.run_setup_mappers == "each": + self._with_register_classes(self.setup_mappers) + + def _setup_each_classes(self): + if self.run_setup_classes == "each": + self._with_register_classes(self.setup_classes) + + @classmethod + def _generate_registry(cls): + decl = registry(metadata=cls._tables_metadata) + return decl, decl.map_imperatively + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + + class _Base: + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() + + class Basic(BasicEntity, _Base): + pass + + class Comparable(ComparableEntity, _Base): + pass + + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != "once": + orm.clear_mappers() + + def _teardown_each_classes(self): + if self.run_setup_classes != "once": + self.classes.clear() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = "once" + run_setup_mappers = "once" + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + + class _DeclBase(DeclarativeBase): + __table_cls__ = schema.Table + metadata = cls._tables_metadata + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } + + def __init_subclass__(cls, **kw) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__(**kw) + + cls.DeclarativeBasic = _DeclBase + + # sets up cls.Basic which is helpful for things like composite + # classes + super()._with_register_classes(fn) + + if cls._tables_metadata.tables and cls.run_create_tables: + cls._tables_metadata.create_all(config.db) + + +class RemoveORMEventsGlobally: + @config.fixture(autouse=True) + def _remove_listeners(self): + yield + orm_events.MapperEvents._clear() + orm_events.InstanceEvents._clear() + orm_events.SessionEvents._clear() + orm_events.InstrumentationEvents._clear() + orm_events.QueryEvents._clear() + + +_fixture_sessions = set() + + +def fixture_session(**kw): + kw.setdefault("autoflush", True) + kw.setdefault("expire_on_commit", True) + + bind = kw.pop("bind", config.db) + + sess = orm.Session(bind, **kw) + _fixture_sessions.add(sess) + return sess + + +def close_all_sessions(): + # will close all still-referenced sessions + orm.close_all_sessions() + _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): + close_all_sessions() + orm.clear_mappers() + + +def after_test(): + if _fixture_sessions: + close_all_sessions() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py new file mode 100644 index 0000000..830fa27 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py @@ -0,0 +1,493 @@ +# testing/fixtures/sql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +import itertools +import random +import re +import sys + +import sqlalchemy as sa +from .base import TestBase +from .. import config +from .. import mock +from ..assertions import eq_ +from ..assertions import ne_ +from ..util import adict +from ..util import drop_all_tables_from_metadata +from ... import event +from ... import util +from ...schema import sort_tables_and_constraints +from ...sql import visitors +from ...sql.elements import ClauseElement + + +class TablesTest(TestBase): + # 'once', None + run_setup_bind = "once" + + # 'once', 'each', None + run_define_tables = "once" + + # 'once', 'each', None + run_create_tables = "once" + + # 'once', 'each', None + run_inserts = "each" + + # 'each', None + run_deletes = "each" + + # 'once', None + run_dispose_bind = None + + bind = None + _tables_metadata = None + tables = None + other = None + sequences = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + + @classmethod + def _init_class(cls): + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) + + cls.other = adict() + cls.tables = adict() + cls.sequences = adict() + + cls.bind = cls.setup_bind() + cls._tables_metadata = sa.MetaData() + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == "once": + cls._load_fixtures() + with cls.bind.begin() as conn: + cls.insert_data(conn) + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == "once": + cls.define_tables(cls._tables_metadata) + if cls.run_create_tables == "once": + cls._tables_metadata.create_all(cls.bind) + cls.tables.update(cls._tables_metadata.tables) + cls.sequences.update(cls._tables_metadata._sequences) + + def _setup_each_tables(self): + if self.run_define_tables == "each": + self.define_tables(self._tables_metadata) + if self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + self.tables.update(self._tables_metadata.tables) + self.sequences.update(self._tables_metadata._sequences) + elif self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == "each": + self._load_fixtures() + with self.bind.begin() as conn: + self.insert_data(conn) + + def _teardown_each_tables(self): + if self.run_define_tables == "each": + self.tables.clear() + if self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + self._tables_metadata.clear() + elif self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + + savepoints = getattr(config.requirements, "savepoints", False) + if savepoints: + savepoints = savepoints.enabled + + # no need to run deletes if tables are recreated on setup + if ( + self.run_define_tables != "each" + and self.run_create_tables != "each" + and self.run_deletes == "each" + ): + with self.bind.begin() as conn: + for table in reversed( + [ + t + for (t, fks) in sort_tables_and_constraints( + self._tables_metadata.tables.values() + ) + if t is not None + ] + ): + try: + if savepoints: + with conn.begin_nested(): + conn.execute(table.delete()) + else: + conn.execute(table.delete()) + except sa.exc.DBAPIError as ex: + print( + ("Error emptying table %s: %r" % (table, ex)), + file=sys.stderr, + ) + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) + + if cls.run_dispose_bind == "once": + cls.dispose_bind(cls.bind) + + cls._tables_metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, "dispose"): + bind.dispose() + elif hasattr(bind, "close"): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls, connection): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements): + self.assert_sql(self.bind, callable_, statements) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().items(): + if len(data) < 2: + continue + if isinstance(table, str): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table, fks in sort_tables_and_constraints( + cls._tables_metadata.tables.values() + ): + if table is None: + continue + if table not in headers: + continue + with cls.bind.begin() as conn: + conn.execute( + table.insert(), + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + + +class NoCache: + @config.fixture(autouse=True, scope="function") + def _disable_cache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + +class RemovesEvents: + @util.memoized_property + def _event_fns(self): + return set() + + def event_listen(self, target, name, fn, **kw): + self._event_fns.add((target, name, fn)) + event.listen(target, name, fn, **kw) + + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield + for key in self._event_fns: + event.remove(*key) + + +class ComputedReflectionFixtureTest(TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("computed_columns", "table_reflection") + + regexp = re.compile(r"[\[\]\(\)\s`'\"]*") + + def normalize(self, text): + return self.regexp.sub("", text).lower() + + @classmethod + def define_tables(cls, metadata): + from ... import Integer + from ... import testing + from ...schema import Column + from ...schema import Computed + from ...schema import Table + + Table( + "computed_default_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_col", Integer, Computed("normal + 42")), + Column("with_default", Integer, server_default="42"), + ) + + t = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal + 42")), + ) + + if testing.requires.schemas.enabled: + t2 = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal / 42")), + schema=config.test_schema, + ) + + if testing.requires.computed_columns_virtual.enabled: + t.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal + 2", persisted=False), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal / 2", persisted=False), + ) + ) + if testing.requires.computed_columns_stored.enabled: + t.append_column( + Column( + "computed_stored", + Integer, + Computed("normal - 42", persisted=True), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_stored", + Integer, + Computed("normal * 42", persisted=True), + ) + ) + + +class CacheKeyFixture: + def _compare_equal(self, a, b, compare_values): + a_key = a._generate_cache_key() + b_key = b._generate_cache_key() + + if a_key is None: + assert a._annotations.get("nocache") + + assert b_key is None + else: + eq_(a_key.key, b_key.key) + eq_(hash(a_key.key), hash(b_key.key)) + + for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): + assert a_param.compare(b_param, compare_values=compare_values) + return a_key, b_key + + def _run_cache_key_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key, b_key = self._compare_equal( + case_a[a], case_b[b], compare_values + ) + if a_key is None: + continue + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key is None or b_key is None: + if a_key is None: + assert case_a[a]._annotations.get("nocache") + if b_key is None: + assert case_b[b]._annotations.get("nocache") + continue + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare( + b_param, compare_values=compare_values + ): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. + ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters that aren't marked + # as "literal execute" + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + + for elem in visitors.iterate(case_a[a]): + if elem.__visit_name__ == "bindparam": + assert_a_params.append(elem) + + for elem in visitors.iterate(case_b[b]): + if elem.__visit_name__ == "bindparam": + assert_b_params.append(elem) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + def _run_cache_key_equal_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + self._compare_equal(case_a[a], case_b[b], compare_values) + + +def insertmanyvalues_fixture( + connection, randomize_rows=False, warn_on_downgraded=False +): + dialect = connection.dialect + orig_dialect = dialect._deliver_insertmanyvalues_batches + orig_conn = connection._exec_insertmany_context + + class RandomCursor: + __slots__ = ("cursor",) + + def __init__(self, cursor): + self.cursor = cursor + + # only this method is called by the deliver method. + # by not having the other methods we assert that those aren't being + # used + + @property + def description(self): + return self.cursor.description + + def fetchall(self): + rows = self.cursor.fetchall() + rows = list(rows) + random.shuffle(rows) + return rows + + def _deliver_insertmanyvalues_batches( + cursor, statement, parameters, generic_setinputsizes, context + ): + if randomize_rows: + cursor = RandomCursor(cursor) + for batch in orig_dialect( + cursor, statement, parameters, generic_setinputsizes, context + ): + if warn_on_downgraded and batch.is_downgraded: + util.warn("Batches were downgraded for sorted INSERT") + + yield batch + + def _exec_insertmany_context(dialect, context): + with mock.patch.object( + dialect, + "_deliver_insertmanyvalues_batches", + new=_deliver_insertmanyvalues_batches, + ): + return orig_conn(dialect, context) + + connection._exec_insertmany_context = _exec_insertmany_context |