diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin')
8 files changed, 1704 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py new file mode 100644 index 0000000..0f98777 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py @@ -0,0 +1,6 @@ +# testing/plugin/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..163dc80 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9b70281 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c65e39e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6eb39ab --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py new file mode 100644 index 0000000..d0d3754 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py @@ -0,0 +1,51 @@ +# testing/plugin/bootstrap.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +""" +Bootstrapper for test framework plugins. + +The entire rationale for this system is to get the modules in plugin/ +imported without importing all of the supporting library, so that we can +set up things for testing before coverage starts. + +The rationale for all of plugin/ being *in* the supporting library in the +first place is so that the testing and plugin suite is available to other +libraries, mainly external SQLAlchemy and Alembic dialects, to make use +of the same test environment and standard suites available to +SQLAlchemy/Alembic themselves without the need to ship/install a separate +package outside of SQLAlchemy. + + +""" + +import importlib.util +import os +import sys + + +bootstrap_file = locals()["bootstrap_file"] +to_bootstrap = locals()["to_bootstrap"] + + +def load_file_as_module(name): + path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) + + spec = importlib.util.spec_from_file_location(name, path) + assert spec is not None + assert spec.loader is not None + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +if to_bootstrap == "pytest": + sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True + sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") +else: + raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py new file mode 100644 index 0000000..a642668 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py @@ -0,0 +1,779 @@ +# testing/plugin/plugin_base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import abc +from argparse import Namespace +import configparser +import logging +import os +from pathlib import Path +import re +import sys +from typing import Any + +from sqlalchemy.testing import asyncio + +"""Testing extensions. + +this module is designed to work as a testing-framework-agnostic library, +created so that multiple test frameworks can be supported at once +(mostly so that we can migrate to new ones). The current target +is pytest. + +""" + +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False + +log = logging.getLogger("sqlalchemy.testing.plugin_base") + +# late imports +fixtures = None +engines = None +exclusions = None +warnings = None +profiling = None +provision = None +assertions = None +requirements = None +config = None +testing = None +util = None +file_config = None + +logging = None +include_tags = set() +exclude_tags = set() +options: Namespace = None # type: ignore + + +def setup_options(make_option): + make_option( + "--log-info", + action="callback", + type=str, + callback=_log, + help="turn on info logging for <LOG> (multiple OK)", + ) + make_option( + "--log-debug", + action="callback", + type=str, + callback=_log, + help="turn on debug logging for <LOG> (multiple OK)", + ) + make_option( + "--db", + action="append", + type=str, + dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.", + ) + make_option( + "--dbs", + action="callback", + zeroarg_callback=_list_dbs, + help="List available prefab dbs", + ) + make_option( + "--dburi", + action="append", + type=str, + dest="dburi", + help="Database uri. Multiple OK, first one is run by default.", + ) + make_option( + "--dbdriver", + action="append", + type=str, + dest="dbdriver", + help="Additional database drivers to include in tests. " + "These are linked to the existing database URLs by the " + "provisioning system.", + ) + make_option( + "--dropfirst", + action="store_true", + dest="dropfirst", + help="Drop all tables in the target database first", + ) + make_option( + "--disable-asyncio", + action="store_true", + help="disable test / fixtures / provisoning running in asyncio", + ) + make_option( + "--backend-only", + action="callback", + zeroarg_callback=_set_tag_include("backend"), + help=( + "Run only tests marked with __backend__ or __sparse_backend__; " + "this is now equivalent to the pytest -m backend mark expression" + ), + ) + make_option( + "--nomemory", + action="callback", + zeroarg_callback=_set_tag_exclude("memory_intensive"), + help="Don't run memory profiling tests; " + "this is now equivalent to the pytest -m 'not memory_intensive' " + "mark expression", + ) + make_option( + "--notimingintensive", + action="callback", + zeroarg_callback=_set_tag_exclude("timing_intensive"), + help="Don't run timing intensive tests; " + "this is now equivalent to the pytest -m 'not timing_intensive' " + "mark expression", + ) + make_option( + "--nomypy", + action="callback", + zeroarg_callback=_set_tag_exclude("mypy"), + help="Don't run mypy typing tests; " + "this is now equivalent to the pytest -m 'not mypy' mark expression", + ) + make_option( + "--profile-sort", + type=str, + default="cumulative", + dest="profilesort", + help="Type of sort for profiling standard output", + ) + make_option( + "--profile-dump", + type=str, + dest="profiledump", + help="Filename where a single profile run will be dumped", + ) + make_option( + "--low-connections", + action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS", + ) + make_option( + "--write-idents", + type=str, + dest="write_idents", + help="write out generated follower idents to <file>, " + "when -n<num> is used", + ) + make_option( + "--requirements", + action="callback", + type=str, + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg", + ) + make_option( + "--include-tag", + action="callback", + callback=_include_tag, + type=str, + help="Include tests with tag <tag>; " + "legacy, use pytest -m 'tag' instead", + ) + make_option( + "--exclude-tag", + action="callback", + callback=_exclude_tag, + type=str, + help="Exclude tests with tag <tag>; " + "legacy, use pytest -m 'not tag' instead", + ) + make_option( + "--write-profiles", + action="store_true", + dest="write_profiles", + default=False, + help="Write/update failing profiling data.", + ) + make_option( + "--force-write-profiles", + action="store_true", + dest="force_write_profiles", + default=False, + help="Unconditionally write/update profiling data.", + ) + make_option( + "--dump-pyannotate", + type=str, + dest="dump_pyannotate", + help="Run pyannotate and dump json info to given file", + ) + make_option( + "--mypy-extra-test-path", + type=str, + action="append", + default=[], + dest="mypy_extra_test_paths", + help="Additional test directories to add to the mypy tests. " + "This is used only when running mypy tests. Multiple OK", + ) + # db specific options + make_option( + "--postgresql-templatedb", + type=str, + help="name of template database to use for PostgreSQL " + "CREATE DATABASE (defaults to current database)", + ) + make_option( + "--oracledb-thick-mode", + action="store_true", + help="enables the 'thick mode' when testing with oracle+oracledb", + ) + + +def configure_follower(follower_ident): + """Configure required state for a follower. + + This invokes in the parent process and typically includes + database creation. + + """ + from sqlalchemy.testing import provision + + provision.FOLLOWER_IDENT = follower_ident + + +def memoize_important_follower_config(dict_): + """Store important configuration we will need to send to a follower. + + This invokes in the parent process after normal config is set up. + + Hook is currently not used. + + """ + + +def restore_important_follower_config(dict_): + """Restore important configuration needed by a follower. + + This invokes in the follower process. + + Hook is currently not used. + + """ + + +def read_config(root_path): + global file_config + file_config = configparser.ConfigParser() + file_config.read( + [str(root_path / "setup.cfg"), str(root_path / "test.cfg")] + ) + + +def pre_begin(opt): + """things to set up early, before coverage might be setup.""" + global options + options = opt + for fn in pre_configure: + fn(options, file_config) + + +def set_coverage_flag(value): + options.has_coverage = value + + +def post_begin(): + """things to set up later, once we know coverage is running.""" + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + # late imports, has to happen after config. + global util, fixtures, engines, exclusions, assertions, provision + global warnings, profiling, config, testing + from sqlalchemy import testing # noqa + from sqlalchemy.testing import fixtures, engines, exclusions # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import config, provision # noqa + from sqlalchemy import util # noqa + + warnings.setup_filters() + + +def _log(opt_str, value, parser): + global logging + if not logging: + import logging + + logging.basicConfig() + + if opt_str.endswith("-info"): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith("-debug"): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + if file_config is None: + # assume the current working directory is the one containing the + # setup file + read_config(Path.cwd()) + print("Available --db options (use --dburi to override)") + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) + sys.exit(0) + + +def _requirements_opt(opt_str, value, parser): + _setup_requirements(value) + + +def _set_tag_include(tag): + def _do_include_tag(opt_str, value, parser): + _include_tag(opt_str, tag, parser) + + return _do_include_tag + + +def _set_tag_exclude(tag): + def _do_exclude_tag(opt_str, value, parser): + _exclude_tag(opt_str, tag, parser) + + return _do_exclude_tag + + +def _exclude_tag(opt_str, value, parser): + exclude_tags.add(value.replace("-", "_")) + + +def _include_tag(opt_str, value, parser): + include_tags.add(value.replace("-", "_")) + + +pre_configure = [] +post_configure = [] + + +def pre(fn): + pre_configure.append(fn) + return fn + + +def post(fn): + post_configure.append(fn) + return fn + + +@pre +def _setup_options(opt, file_config): + global options + options = opt + + +@pre +def _register_sqlite_numeric_dialect(opt, file_config): + from sqlalchemy.dialects import registry + + registry.register( + "sqlite.pysqlite_numeric", + "sqlalchemy.dialects.sqlite.pysqlite", + "_SQLiteDialect_pysqlite_numeric", + ) + registry.register( + "sqlite.pysqlite_dollar", + "sqlalchemy.dialects.sqlite.pysqlite", + "_SQLiteDialect_pysqlite_dollar", + ) + + +@post +def __ensure_cext(opt, file_config): + if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1": + from sqlalchemy.util import has_compiled_ext + + try: + has_compiled_ext(raise_=True) + except ImportError as err: + raise AssertionError( + "REQUIRE_SQLALCHEMY_CEXT is set but can't import the " + "cython extensions" + ) from err + + +@post +def _init_symbols(options, file_config): + from sqlalchemy.testing import config + + config._fixture_functions = _fixture_fn_class() + + +@pre +def _set_disable_asyncio(opt, file_config): + if opt.disable_asyncio: + asyncio.ENABLE_ASYNCIO = False + + +@post +def _engine_uri(options, file_config): + from sqlalchemy import testing + from sqlalchemy.testing import config + from sqlalchemy.testing import provision + from sqlalchemy.engine import url as sa_url + + if options.dburi: + db_urls = list(options.dburi) + else: + db_urls = [] + + extra_drivers = options.dbdriver or [] + + if options.db: + for db_token in options.db: + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): + raise RuntimeError( + "Unknown URI specifier '%s'. " + "Specify --dbs for known uris." % db + ) + else: + db_urls.append(file_config.get("db", db)) + + if not db_urls: + db_urls.append(file_config.get("db", "default")) + + config._current = None + + if options.write_idents and provision.FOLLOWER_IDENT: + for db_url in [sa_url.make_url(db_url) for db_url in db_urls]: + with open(options.write_idents, "a") as file_: + file_.write( + f"{provision.FOLLOWER_IDENT} " + f"{db_url.render_as_string(hide_password=False)}\n" + ) + + expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers)) + + for db_url in expanded_urls: + log.info("Adding database URL: %s", db_url) + + cfg = provision.setup_config( + db_url, options, file_config, provision.FOLLOWER_IDENT + ) + if not config._current: + cfg.set_as_current(cfg, testing) + + +@post +def _requirements(options, file_config): + requirement_cls = file_config.get("sqla_testing", "requirement_cls") + _setup_requirements(requirement_cls) + + +def _setup_requirements(argument): + from sqlalchemy.testing import config + from sqlalchemy import testing + + modname, clsname = argument.split(":") + + # importlib.import_module() only introduced in 2.7, a little + # late + mod = __import__(modname) + for component in modname.split(".")[1:]: + mod = getattr(mod, component) + req_cls = getattr(mod, clsname) + + config.requirements = testing.requires = req_cls() + + config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy + + +@post +def _prep_testing_database(options, file_config): + from sqlalchemy.testing import config + + if options.dropfirst: + from sqlalchemy.testing import provision + + for cfg in config.Config.all_configs(): + provision.drop_all_schema_objects(cfg, cfg.db) + + +@post +def _post_setup_options(opt, file_config): + from sqlalchemy.testing import config + + config.options = options + config.file_config = file_config + + +@post +def _setup_profiling(options, file_config): + from sqlalchemy.testing import profiling + + profiling._profile_stats = profiling.ProfileStatsFile( + file_config.get("sqla_testing", "profile_file"), + sort=options.profilesort, + dump=options.profiledump, + ) + + +def want_class(name, cls): + if not issubclass(cls, fixtures.TestBase): + return False + elif name.startswith("_"): + return False + else: + return True + + +def want_method(cls, fn): + if not fn.__name__.startswith("test_"): + return False + elif fn.__module__ is None: + return False + else: + return True + + +def generate_sub_tests(cls, module, markers): + if "backend" in markers or "sparse_backend" in markers: + sparse = "sparse_backend" in markers + for cfg in _possible_configs_for_cls(cls, sparse=sparse): + orig_name = cls.__name__ + + # we can have special chars in these names except for the + # pytest junit plugin, which is tripped up by the brackets + # and periods, so sanitize + + alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub(r"_+$", "", alpha_name) + name = "%s_%s" % (cls.__name__, alpha_name) + subcls = type( + name, + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, + ) + setattr(module, name, subcls) + yield subcls + else: + yield cls + + +def start_test_class_outside_fixtures(cls): + _do_skips(cls) + _setup_engine(cls) + + +def stop_test_class(cls): + # close sessions, immediate connections, etc. + fixtures.stop_test_class_inside_fixtures(cls) + + # close outstanding connection pool connections, dispose of + # additional engines + engines.testing_reaper.stop_test_class_inside_fixtures() + + +def stop_test_class_outside_fixtures(cls): + engines.testing_reaper.stop_test_class_outside_fixtures() + provision.stop_test_class_outside_fixtures(config, config.db, cls) + try: + if not options.low_connections: + assertions.global_cleanup_assertions() + finally: + _restore_engine() + + +def _restore_engine(): + if config._current: + config._current.reset(testing) + + +def final_process_cleanup(): + engines.testing_reaper.final_cleanup() + assertions.global_cleanup_assertions() + _restore_engine() + + +def _setup_engine(cls): + if getattr(cls, "__engine_options__", None): + opts = dict(cls.__engine_options__) + opts["scope"] = "class" + eng = engines.testing_engine(options=opts) + config._current.push_engine(eng, testing) + + +def before_test(test, test_module_name, test_class, test_name): + # format looks like: + # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" + + name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__) + + id_ = "%s.%s.%s" % (test_module_name, name, test_name) + + profiling._start_current_test(id_) + + +def after_test(test): + fixtures.after_test() + engines.testing_reaper.after_test() + + +def after_test_fixtures(test): + engines.testing_reaper.after_test_outside_fixtures(test) + + +def _possible_configs_for_cls(cls, reasons=None, sparse=False): + all_configs = set(config.Config.all_configs()) + + if cls.__unsupported_on__: + spec = exclusions.db_spec(*cls.__unsupported_on__) + for config_obj in list(all_configs): + if spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, "__only_on__", None): + spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) + for config_obj in list(all_configs): + if not spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, "__only_on_config__", None): + all_configs.intersection_update([cls.__only_on_config__]) + + if hasattr(cls, "__requires__"): + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__requires__: + check = getattr(requirements, requirement) + + skip_reasons = check.matching_config_reasons(config_obj) + if skip_reasons: + all_configs.remove(config_obj) + if reasons is not None: + reasons.extend(skip_reasons) + break + + if hasattr(cls, "__prefer_requires__"): + non_preferred = set() + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__prefer_requires__: + check = getattr(requirements, requirement) + + if not check.enabled_for_config(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if sparse: + # pick only one config from each base dialect + # sorted so we get the same backend each time selecting the highest + # server version info. + per_dialect = {} + for cfg in reversed( + sorted( + all_configs, + key=lambda cfg: ( + cfg.db.name, + cfg.db.driver, + cfg.db.dialect.server_version_info, + ), + ) + ): + db = cfg.db.name + if db not in per_dialect: + per_dialect[db] = cfg + return per_dialect.values() + + return all_configs + + +def _do_skips(cls): + reasons = [] + all_configs = _possible_configs_for_cls(cls, reasons) + + if getattr(cls, "__skip_if__", False): + for c in getattr(cls, "__skip_if__"): + if c(): + config.skip_test( + "'%s' skipped by %s" % (cls.__name__, c.__name__) + ) + + if not all_configs: + msg = "'%s.%s' unsupported on any DB implementation %s%s" % ( + cls.__module__, + cls.__name__, + ", ".join( + "'%s(%s)+%s'" + % ( + config_obj.db.name, + ".".join( + str(dig) + for dig in exclusions._server_version(config_obj.db) + ), + config_obj.db.driver, + ) + for config_obj in config.Config.all_configs() + ), + ", ".join(reasons), + ) + config.skip_test(msg) + elif hasattr(cls, "__prefer_backends__"): + non_preferred = set() + spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) + for config_obj in all_configs: + if not spec(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if config._current not in all_configs: + _setup_config(all_configs.pop(), cls) + + +def _setup_config(config_obj, ctx): + config._current.push(config_obj, testing) + + +class FixtureFunctions(abc.ABC): + @abc.abstractmethod + def skip_test_exception(self, *arg, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def combinations(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def param_ident(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def fixture(self, *arg, **kw): + raise NotImplementedError() + + def get_current_test_name(self): + raise NotImplementedError() + + @abc.abstractmethod + def mark_base_test_class(self) -> Any: + raise NotImplementedError() + + @abc.abstractproperty + def add_to_marker(self): + raise NotImplementedError() + + +_fixture_fn_class = None + + +def set_fixture_functions(fixture_fn_class): + global _fixture_fn_class + _fixture_fn_class = fixture_fn_class diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py new file mode 100644 index 0000000..1a4d4bb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py @@ -0,0 +1,868 @@ +# testing/plugin/pytestplugin.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import argparse +import collections +from functools import update_wrapper +import inspect +import itertools +import operator +import os +import re +import sys +from typing import TYPE_CHECKING +import uuid + +import pytest + +try: + # installed by bootstrap.py + if not TYPE_CHECKING: + import sqla_plugin_base as plugin_base +except ImportError: + # assume we're a package, use traditional import + from . import plugin_base + + +def pytest_addoption(parser): + group = parser.getgroup("sqlalchemy") + + def make_option(name, **kw): + callback_ = kw.pop("callback", None) + if callback_: + + class CallableAction(argparse.Action): + def __call__( + self, parser, namespace, values, option_string=None + ): + callback_(option_string, values, parser) + + kw["action"] = CallableAction + + zeroarg_callback = kw.pop("zeroarg_callback", None) + if zeroarg_callback: + + class CallableAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, # noqa + ): + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): + zeroarg_callback(option_string, values, parser) + + kw["action"] = CallableAction + + group.addoption(name, **kw) + + plugin_base.setup_options(make_option) + + +def pytest_configure(config: pytest.Config): + plugin_base.read_config(config.rootpath) + if plugin_base.exclude_tags or plugin_base.include_tags: + new_expr = " and ".join( + list(plugin_base.include_tags) + + [f"not {tag}" for tag in plugin_base.exclude_tags] + ) + + if config.option.markexpr: + config.option.markexpr += f" and {new_expr}" + else: + config.option.markexpr = new_expr + + if config.pluginmanager.hasplugin("xdist"): + config.pluginmanager.register(XDistHooks()) + + if hasattr(config, "workerinput"): + plugin_base.restore_important_follower_config(config.workerinput) + plugin_base.configure_follower(config.workerinput["follower_ident"]) + else: + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): + os.remove(config.option.write_idents) + + plugin_base.pre_begin(config.option) + + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) + + plugin_base.set_fixture_functions(PytestFixtureFunctions) + + if config.option.dump_pyannotate: + global DUMP_PYANNOTATE + DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): + if DUMP_PYANNOTATE: + from pyannotate_runtime import collect_types + + collect_types.start() + yield + if DUMP_PYANNOTATE: + collect_types.stop() + + +def _log_sqlalchemy_info(session): + import sqlalchemy + from sqlalchemy import __version__ + from sqlalchemy.util import has_compiled_ext + from sqlalchemy.util._has_cy import _CYEXTENSION_MSG + + greet = "sqlalchemy installation" + site = "no user site" if sys.flags.no_user_site else "user site loaded" + msgs = [ + f"SQLAlchemy {__version__} ({site})", + f"Path: {sqlalchemy.__file__}", + ] + + if has_compiled_ext(): + from sqlalchemy.cyextension import util + + msgs.append(f"compiled extension enabled, e.g. {util.__file__} ") + else: + msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}") + + pm = session.config.pluginmanager.get_plugin("terminalreporter") + if pm: + pm.write_sep("=", greet) + for m in msgs: + pm.write_line(m) + else: + # fancy pants reporter not found, fallback to plain print + print("=" * 25, greet, "=" * 25) + for m in msgs: + print(m) + + +def pytest_sessionstart(session): + from sqlalchemy.testing import asyncio + + _log_sqlalchemy_info(session) + asyncio._assume_async(plugin_base.post_begin) + + +def pytest_sessionfinish(session): + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup) + + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_unconfigure(config): + from sqlalchemy.testing import asyncio + + asyncio._shutdown() + + +def pytest_collection_finish(session): + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + + def _filter(filename): + filename = os.path.normpath(os.path.abspath(filename)) + if "lib/sqlalchemy" not in os.path.commonpath( + [filename, lib_sqlalchemy] + ): + return None + if "testing" in filename: + return None + + return filename + + collect_types.init_types_collection(filter_filename=_filter) + + +class XDistHooks: + def pytest_configure_node(self, node): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + + # the master for each node fills workerinput dictionary + # which pytest-xdist will transfer to the subprocess + + plugin_base.memoize_important_follower_config(node.workerinput) + + node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] + + asyncio._maybe_async_provisioning( + provision.create_follower_db, node.workerinput["follower_ident"] + ) + + def pytest_testnodedown(self, node, error): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning( + provision.drop_follower_db, node.workerinput["follower_ident"] + ) + + +def pytest_collection_modifyitems(session, config, items): + # look for all those classes that specify __backend__ and + # expand them out into per-database test cases. + + # this is much easier to do within pytest_pycollect_makeitem, however + # pytest is iterating through cls.__dict__ as makeitem is + # called which causes a "dictionary changed size" error on py3k. + # I'd submit a pullreq for them to turn it into a list first, but + # it's to suit the rather odd use case here which is that we are adding + # new classes to a module on the fly. + + from sqlalchemy.testing import asyncio + + rebuilt_items = collections.defaultdict( + lambda: collections.defaultdict(list) + ) + + items[:] = [ + item + for item in items + if item.getparent(pytest.Class) is not None + and not item.getparent(pytest.Class).name.startswith("_") + ] + + test_classes = {item.getparent(pytest.Class) for item in items} + + def collect(element): + for inst_or_fn in element.collect(): + if isinstance(inst_or_fn, pytest.Collector): + yield from collect(inst_or_fn) + else: + yield inst_or_fn + + def setup_test_classes(): + for test_class in test_classes: + # transfer legacy __backend__ and __sparse_backend__ symbols + # to be markers + add_markers = set() + if getattr(test_class.cls, "__backend__", False) or getattr( + test_class.cls, "__only_on__", False + ): + add_markers = {"backend"} + elif getattr(test_class.cls, "__sparse_backend__", False): + add_markers = {"sparse_backend"} + else: + add_markers = frozenset() + + existing_markers = { + mark.name for mark in test_class.iter_markers() + } + add_markers = add_markers - existing_markers + all_markers = existing_markers.union(add_markers) + + for marker in add_markers: + test_class.add_marker(marker) + + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.module, all_markers + ): + if sub_cls is not test_class.cls: + per_cls_dict = rebuilt_items[test_class.cls] + + module = test_class.getparent(pytest.Module) + + new_cls = pytest.Class.from_parent( + name=sub_cls.__name__, parent=module + ) + for marker in add_markers: + new_cls.add_marker(marker) + + for fn in collect(new_cls): + per_cls_dict[fn.name].append(fn) + + # class requirements will sometimes need to access the DB to check + # capabilities, so need to do this for async + asyncio._maybe_async_provisioning(setup_test_classes) + + newitems = [] + for item in items: + cls_ = item.cls + if cls_ in rebuilt_items: + newitems.extend(rebuilt_items[cls_][item.name]) + else: + newitems.append(item) + + # seems like the functions attached to a test class aren't sorted already? + # is that true and why's that? (when using unittest, they're sorted) + items[:] = sorted( + newitems, + key=lambda item: ( + item.getparent(pytest.Module).name, + item.getparent(pytest.Class).name, + item.name, + ), + ) + + +def pytest_pycollect_makeitem(collector, name, obj): + if inspect.isclass(obj) and plugin_base.want_class(name, obj): + from sqlalchemy.testing import config + + if config.any_async: + obj = _apply_maybe_async(obj) + + return [ + pytest.Class.from_parent( + name=parametrize_cls.__name__, parent=collector + ) + for parametrize_cls in _parametrize_cls(collector.module, obj) + ] + elif ( + inspect.isfunction(obj) + and collector.cls is not None + and plugin_base.want_method(collector.cls, obj) + ): + # None means, fall back to default logic, which includes + # method-level parametrize + return None + else: + # empty list means skip this item + return [] + + +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + +def _apply_maybe_async(obj, recurse=True): + from sqlalchemy.testing import asyncio + + for name, value in vars(obj).items(): + if ( + (callable(value) or isinstance(value, classmethod)) + and not getattr(value, "_maybe_async_applied", False) + and (name.startswith("test_")) + and not _is_wrapped_coroutine_function(value) + ): + is_classmethod = False + if isinstance(value, classmethod): + value = value.__func__ + is_classmethod = True + + @_pytest_fn_decorator + def make_async(fn, *args, **kwargs): + return asyncio._maybe_async(fn, *args, **kwargs) + + do_async = make_async(value) + if is_classmethod: + do_async = classmethod(do_async) + do_async._maybe_async_applied = True + + setattr(obj, name, do_async) + if recurse: + for cls in obj.mro()[1:]: + if cls != object: + _apply_maybe_async(cls, False) + return obj + + +def _parametrize_cls(module, cls): + """implement a class-based version of pytest parametrize.""" + + if "_sa_parametrize" not in cls.__dict__: + return [cls] + + _sa_parametrize = cls._sa_parametrize + classes = [] + for full_param_set in itertools.product( + *[params for argname, params in _sa_parametrize] + ): + cls_variables = {} + + for argname, param in zip( + [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set + ): + if not argname: + raise TypeError("need argnames for class-based combinations") + argname_split = re.split(r",\s*", argname) + for arg, val in zip(argname_split, param.values): + cls_variables[arg] = val + parametrized_name = "_".join( + re.sub(r"\W", "", token) + for param in full_param_set + for token in param.id.split("-") + ) + name = "%s_%s" % (cls.__name__, parametrized_name) + newcls = type.__new__(type, name, (cls,), cls_variables) + setattr(module, name, newcls) + classes.append(newcls) + return classes + + +_current_class = None + + +def pytest_runtest_setup(item): + from sqlalchemy.testing import asyncio + + # pytest_runtest_setup runs *before* pytest fixtures with scope="class". + # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest + # for the whole class and has to run things that are across all current + # databases, so we run this outside of the pytest fixture system altogether + # and ensure asyncio greenlet if any engines are async + + global _current_class + + if isinstance(item, pytest.Function) and _current_class is None: + asyncio._maybe_async_provisioning( + plugin_base.start_test_class_outside_fixtures, + item.cls, + ) + _current_class = item.getparent(pytest.Class) + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item, nextitem): + # runs inside of pytest function fixture scope + # after test function runs + + from sqlalchemy.testing import asyncio + + asyncio._maybe_async(plugin_base.after_test, item) + + yield + # this is now after all the fixture teardown have run, the class can be + # finalized. Since pytest v7 this finalizer can no longer be added in + # pytest_runtest_setup since the class has not yet been setup at that + # time. + # See https://github.com/pytest-dev/pytest/issues/9343 + global _current_class, _current_report + + if _current_class is not None and ( + # last test or a new class + nextitem is None + or nextitem.getparent(pytest.Class) is not _current_class + ): + _current_class = None + + try: + asyncio._maybe_async_provisioning( + plugin_base.stop_test_class_outside_fixtures, item.cls + ) + except Exception as e: + # in case of an exception during teardown attach the original + # error to the exception message, otherwise it will get lost + if _current_report.failed: + if not e.args: + e.args = ( + "__Original test failure__:\n" + + _current_report.longreprtext, + ) + elif e.args[-1] and isinstance(e.args[-1], str): + args = list(e.args) + args[-1] += ( + "\n__Original test failure__:\n" + + _current_report.longreprtext + ) + e.args = tuple(args) + else: + e.args += ( + "__Original test failure__", + _current_report.longreprtext, + ) + raise + finally: + _current_report = None + + +def pytest_runtest_call(item): + # runs inside of pytest function fixture scope + # before test function runs + + from sqlalchemy.testing import asyncio + + asyncio._maybe_async( + plugin_base.before_test, + item, + item.module.__name__, + item.cls, + item.name, + ) + + +_current_report = None + + +def pytest_runtest_logreport(report): + global _current_report + if report.when == "call": + _current_report = report + + +@pytest.fixture(scope="class") +def setup_class_methods(request): + from sqlalchemy.testing import asyncio + + cls = request.cls + + if hasattr(cls, "setup_test_class"): + asyncio._maybe_async(cls.setup_test_class) + + yield + + if hasattr(cls, "teardown_test_class"): + asyncio._maybe_async(cls.teardown_test_class) + + asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): + from sqlalchemy.testing import asyncio + + # called for each test + + self = request.instance + + # before this fixture runs: + + # 1. function level "autouse" fixtures under py3k (examples: TablesTest + # define tables / data, MappedTest define tables / mappers / data) + + # 2. was for p2k. no longer applies + + # 3. run outer xdist-style setup + if hasattr(self, "setup_test"): + asyncio._maybe_async(self.setup_test) + + # alembic test suite is using setUp and tearDown + # xdist methods; support these in the test suite + # for the near term + if hasattr(self, "setUp"): + asyncio._maybe_async(self.setUp) + + # inside the yield: + # 4. function level fixtures defined on test functions themselves, + # e.g. "connection", "metadata" run next + + # 5. pytest hook pytest_runtest_call then runs + + # 6. test itself runs + + yield + + # yield finishes: + + # 7. function level fixtures defined on test functions + # themselves, e.g. "connection" rolls back the transaction, "metadata" + # emits drop all + + # 8. pytest hook pytest_runtest_teardown hook runs, this is associated + # with fixtures close all sessions, provisioning.stop_test_class(), + # engines.testing_reaper -> ensure all connection pool connections + # are returned, engines created by testing_engine that aren't the + # config engine are disposed + + asyncio._maybe_async(plugin_base.after_test_fixtures, self) + + # 10. run xdist-style teardown + if hasattr(self, "tearDown"): + asyncio._maybe_async(self.tearDown) + + if hasattr(self, "teardown_test"): + asyncio._maybe_async(self.teardown_test) + + # 11. was for p2k. no longer applies + + # 12. function level "autouse" fixtures under py3k (examples: TablesTest / + # MappedTest delete table data, possibly drop tables and clear mappers + # depending on the flags defined by the test class) + + +def _pytest_fn_decorator(target): + """Port of langhelpers.decorator with pytest-specific tricks.""" + + from sqlalchemy.util.langhelpers import format_argspec_plus + from sqlalchemy.util.compat import inspect_getfullargspec + + def _exec_code_in_env(code, env, fn_name): + # note this is affected by "from __future__ import annotations" at + # the top; exec'ed code will use non-evaluated annotations + # which allows us to be more flexible with code rendering + # in format_argpsec_plus() + exec(code, env) + return env[fn_name] + + def decorate(fn, add_positional_parameters=()): + spec = inspect_getfullargspec(fn) + if add_positional_parameters: + spec.args.extend(add_positional_parameters) + + metadata = dict( + __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__ + ) + metadata.update(format_argspec_plus(spec, grouped=False)) + code = ( + """\ +def %(name)s%(grouped_args)s: + return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s) +""" + % metadata + ) + decorated = _exec_code_in_env( + code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__ + ) + if not add_positional_parameters: + decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + else: + # this is the pytest hacky part. don't do a full update wrapper + # because pytest is really being sneaky about finding the args + # for the wrapped function + decorated.__module__ = fn.__module__ + decorated.__name__ = fn.__name__ + if hasattr(fn, "pytestmark"): + decorated.pytestmark = fn.pytestmark + return decorated + + return decorate + + +class PytestFixtureFunctions(plugin_base.FixtureFunctions): + def skip_test_exception(self, *arg, **kw): + return pytest.skip.Exception(*arg, **kw) + + @property + def add_to_marker(self): + return pytest.mark + + def mark_base_test_class(self): + return pytest.mark.usefixtures( + "setup_class_methods", "setup_test_methods" + ) + + _combination_id_fns = { + "i": lambda obj: obj, + "r": repr, + "s": str, + "n": lambda obj: ( + obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__ + ), + } + + def combinations(self, *arg_sets, **kw): + """Facade for pytest.mark.parametrize. + + Automatically derives argument names from the callable which in our + case is always a method on a class with positional arguments. + + ids for parameter sets are derived using an optional template. + + """ + from sqlalchemy.testing import exclusions + + if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"): + arg_sets = list(arg_sets[0]) + + argnames = kw.pop("argnames", None) + + def _filter_exclusions(args): + result = [] + gathered_exclusions = [] + for a in args: + if isinstance(a, exclusions.compound): + gathered_exclusions.append(a) + else: + result.append(a) + + return result, gathered_exclusions + + id_ = kw.pop("id_", None) + + tobuild_pytest_params = [] + has_exclusions = False + if id_: + _combination_id_fns = self._combination_id_fns + + # because itemgetter is not consistent for one argument vs. + # multiple, make it multiple in all cases and use a slice + # to omit the first argument + _arg_getter = operator.itemgetter( + 0, + *[ + idx + for idx, char in enumerate(id_) + if char in ("n", "r", "s", "a") + ], + ) + fns = [ + (operator.itemgetter(idx), _combination_id_fns[char]) + for idx, char in enumerate(id_) + if char in _combination_id_fns + ] + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + parameters = _arg_getter(fn_params)[1:] + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + ( + parameters, + param_exclusions, + "-".join( + comb_fn(getter(arg)) for getter, comb_fn in fns + ), + ) + ) + + else: + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + (fn_params, param_exclusions, None) + ) + + pytest_params = [] + for parameters, param_exclusions, id_ in tobuild_pytest_params: + if has_exclusions: + parameters += (param_exclusions,) + + param = pytest.param(*parameters, id=id_) + pytest_params.append(param) + + def decorate(fn): + if inspect.isclass(fn): + if has_exclusions: + raise NotImplementedError( + "exclusions not supported for class level combinations" + ) + if "_sa_parametrize" not in fn.__dict__: + fn._sa_parametrize = [] + fn._sa_parametrize.append((argnames, pytest_params)) + return fn + else: + _fn_argnames = inspect.getfullargspec(fn).args[1:] + if argnames is None: + _argnames = _fn_argnames + else: + _argnames = re.split(r", *", argnames) + + if has_exclusions: + existing_exl = sum( + 1 for n in _fn_argnames if n.startswith("_exclusions") + ) + current_exclusion_name = f"_exclusions_{existing_exl}" + _argnames += [current_exclusion_name] + + @_pytest_fn_decorator + def check_exclusions(fn, *args, **kw): + _exclusions = args[-1] + if _exclusions: + exlu = exclusions.compound().add(*_exclusions) + fn = exlu(fn) + return fn(*args[:-1], **kw) + + fn = check_exclusions( + fn, add_positional_parameters=(current_exclusion_name,) + ) + + return pytest.mark.parametrize(_argnames, pytest_params)(fn) + + return decorate + + def param_ident(self, *parameters): + ident = parameters[0] + return pytest.param(*parameters[1:], id=ident) + + def fixture(self, *arg, **kw): + from sqlalchemy.testing import config + from sqlalchemy.testing import asyncio + + # wrapping pytest.fixture function. determine if + # decorator was called as @fixture or @fixture(). + if len(arg) > 0 and callable(arg[0]): + # was called as @fixture(), we have the function to wrap. + fn = arg[0] + arg = arg[1:] + else: + # was called as @fixture, don't have the function yet. + fn = None + + # create a pytest.fixture marker. because the fn is not being + # passed, this is always a pytest.FixtureFunctionMarker() + # object (or whatever pytest is calling it when you read this) + # that is waiting for a function. + fixture = pytest.fixture(*arg, **kw) + + # now apply wrappers to the function, including fixture itself + + def wrap(fn): + if config.any_async: + fn = asyncio._maybe_async_wrapper(fn) + # other wrappers may be added here + + # now apply FixtureFunctionMarker + fn = fixture(fn) + + return fn + + if fn: + return wrap(fn) + else: + return wrap + + def get_current_test_name(self): + return os.environ.get("PYTEST_CURRENT_TEST") + + def async_test(self, fn): + from sqlalchemy.testing import asyncio + + @_pytest_fn_decorator + def decorate(fn, *args, **kwargs): + asyncio._run_coroutine_function(fn, *args, **kwargs) + + return decorate(fn) |