From 12cf076118570eebbff08c6b3090e0d4798447a1 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:17:55 -0400 Subject: no venv --- .../site-packages/sqlalchemy/testing/assertions.py | 989 --------------------- 1 file changed, 989 deletions(-) delete mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py') diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py deleted file mode 100644 index baef79d..0000000 --- a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py +++ /dev/null @@ -1,989 +0,0 @@ -# testing/assertions.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - - -from __future__ import annotations - -from collections import defaultdict -import contextlib -from copy import copy -from itertools import filterfalse -import re -import sys -import warnings - -from . import assertsql -from . import config -from . import engines -from . import mock -from .exclusions import db_spec -from .util import fail -from .. import exc as sa_exc -from .. import schema -from .. import sql -from .. import types as sqltypes -from .. import util -from ..engine import default -from ..engine import url -from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL -from ..util import decorator - - -def expect_warnings(*messages, **kw): - """Context manager which expects one or more warnings. - - With no arguments, squelches all SAWarning emitted via - sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise - pass string expressions that will match selected warnings via regex; - all non-matching warnings are sent through. - - The expect version **asserts** that the warnings were in fact seen. - - Note that the test suite sets SAWarning warnings to raise exceptions. - - """ # noqa - return _expect_warnings_sqla_only(sa_exc.SAWarning, messages, **kw) - - -@contextlib.contextmanager -def expect_warnings_on(db, *messages, **kw): - """Context manager which expects one or more warnings on specific - dialects. - - The expect version **asserts** that the warnings were in fact seen. - - """ - spec = db_spec(db) - - if isinstance(db, str) and not spec(config._current): - yield - else: - with expect_warnings(*messages, **kw): - yield - - -def emits_warning(*messages): - """Decorator form of expect_warnings(). - - Note that emits_warning does **not** assert that the warnings - were in fact seen. - - """ - - @decorator - def decorate(fn, *args, **kw): - with expect_warnings(assert_=False, *messages): - return fn(*args, **kw) - - return decorate - - -def expect_deprecated(*messages, **kw): - return _expect_warnings_sqla_only( - sa_exc.SADeprecationWarning, messages, **kw - ) - - -def expect_deprecated_20(*messages, **kw): - return _expect_warnings_sqla_only( - sa_exc.Base20DeprecationWarning, messages, **kw - ) - - -def emits_warning_on(db, *messages): - """Mark a test as emitting a warning on a specific dialect. - - With no arguments, squelches all SAWarning failures. Or pass one or more - strings; these will be matched to the root of the warning description by - warnings.filterwarnings(). - - Note that emits_warning_on does **not** assert that the warnings - were in fact seen. - - """ - - @decorator - def decorate(fn, *args, **kw): - with expect_warnings_on(db, assert_=False, *messages): - return fn(*args, **kw) - - return decorate - - -def uses_deprecated(*messages): - """Mark a test as immune from fatal deprecation warnings. - - With no arguments, squelches all SADeprecationWarning failures. - Or pass one or more strings; these will be matched to the root - of the warning description by warnings.filterwarnings(). - - As a special case, you may pass a function name prefixed with // - and it will be re-written as needed to match the standard warning - verbiage emitted by the sqlalchemy.util.deprecated decorator. - - Note that uses_deprecated does **not** assert that the warnings - were in fact seen. - - """ - - @decorator - def decorate(fn, *args, **kw): - with expect_deprecated(*messages, assert_=False): - return fn(*args, **kw) - - return decorate - - -_FILTERS = None -_SEEN = None -_EXC_CLS = None - - -def _expect_warnings_sqla_only( - exc_cls, - messages, - regex=True, - search_msg=False, - assert_=True, -): - """SQLAlchemy internal use only _expect_warnings(). - - Alembic is using _expect_warnings() directly, and should be updated - to use this new interface. - - """ - return _expect_warnings( - exc_cls, - messages, - regex=regex, - search_msg=search_msg, - assert_=assert_, - raise_on_any_unexpected=True, - ) - - -@contextlib.contextmanager -def _expect_warnings( - exc_cls, - messages, - regex=True, - search_msg=False, - assert_=True, - raise_on_any_unexpected=False, - squelch_other_warnings=False, -): - global _FILTERS, _SEEN, _EXC_CLS - - if regex or search_msg: - filters = [re.compile(msg, re.I | re.S) for msg in messages] - else: - filters = list(messages) - - if _FILTERS is not None: - # nested call; update _FILTERS and _SEEN, return. outer - # block will assert our messages - assert _SEEN is not None - assert _EXC_CLS is not None - _FILTERS.extend(filters) - _SEEN.update(filters) - _EXC_CLS += (exc_cls,) - yield - else: - seen = _SEEN = set(filters) - _FILTERS = filters - _EXC_CLS = (exc_cls,) - - if raise_on_any_unexpected: - - def real_warn(msg, *arg, **kw): - raise AssertionError("Got unexpected warning: %r" % msg) - - else: - real_warn = warnings.warn - - def our_warn(msg, *arg, **kw): - if isinstance(msg, _EXC_CLS): - exception = type(msg) - msg = str(msg) - elif arg: - exception = arg[0] - else: - exception = None - - if not exception or not issubclass(exception, _EXC_CLS): - if not squelch_other_warnings: - return real_warn(msg, *arg, **kw) - else: - return - - if not filters and not raise_on_any_unexpected: - return - - for filter_ in filters: - if ( - (search_msg and filter_.search(msg)) - or (regex and filter_.match(msg)) - or (not regex and filter_ == msg) - ): - seen.discard(filter_) - break - else: - if not squelch_other_warnings: - real_warn(msg, *arg, **kw) - - with mock.patch("warnings.warn", our_warn): - try: - yield - finally: - _SEEN = _FILTERS = _EXC_CLS = None - - if assert_: - assert not seen, "Warnings were not seen: %s" % ", ".join( - "%r" % (s.pattern if regex else s) for s in seen - ) - - -def global_cleanup_assertions(): - """Check things that have to be finalized at the end of a test suite. - - Hardcoded at the moment, a modular system can be built here - to support things like PG prepared transactions, tables all - dropped, etc. - - """ - _assert_no_stray_pool_connections() - - -def _assert_no_stray_pool_connections(): - engines.testing_reaper.assert_all_closed() - - -def int_within_variance(expected, received, variance): - deviance = int(expected * variance) - assert ( - abs(received - expected) < deviance - ), "Given int value %s is not within %d%% of expected value %s" % ( - received, - variance * 100, - expected, - ) - - -def eq_regex(a, b, msg=None): - assert re.match(b, a), msg or "%r !~ %r" % (a, b) - - -def eq_(a, b, msg=None): - """Assert a == b, with repr messaging on failure.""" - assert a == b, msg or "%r != %r" % (a, b) - - -def ne_(a, b, msg=None): - """Assert a != b, with repr messaging on failure.""" - assert a != b, msg or "%r == %r" % (a, b) - - -def le_(a, b, msg=None): - """Assert a <= b, with repr messaging on failure.""" - assert a <= b, msg or "%r != %r" % (a, b) - - -def is_instance_of(a, b, msg=None): - assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b) - - -def is_none(a, msg=None): - is_(a, None, msg=msg) - - -def is_not_none(a, msg=None): - is_not(a, None, msg=msg) - - -def is_true(a, msg=None): - is_(bool(a), True, msg=msg) - - -def is_false(a, msg=None): - is_(bool(a), False, msg=msg) - - -def is_(a, b, msg=None): - """Assert a is b, with repr messaging on failure.""" - assert a is b, msg or "%r is not %r" % (a, b) - - -def is_not(a, b, msg=None): - """Assert a is not b, with repr messaging on failure.""" - assert a is not b, msg or "%r is %r" % (a, b) - - -# deprecated. See #5429 -is_not_ = is_not - - -def in_(a, b, msg=None): - """Assert a in b, with repr messaging on failure.""" - assert a in b, msg or "%r not in %r" % (a, b) - - -def not_in(a, b, msg=None): - """Assert a in not b, with repr messaging on failure.""" - assert a not in b, msg or "%r is in %r" % (a, b) - - -# deprecated. See #5429 -not_in_ = not_in - - -def startswith_(a, fragment, msg=None): - """Assert a.startswith(fragment), with repr messaging on failure.""" - assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, - fragment, - ) - - -def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r"^\s+?|\n", "", a) - a = re.sub(r" {2,}", " ", a) - a = re.sub(r"\t", "", a) - b = re.sub(r"^\s+?|\n", "", b) - b = re.sub(r" {2,}", " ", b) - b = re.sub(r"\t", "", b) - - assert a == b, msg or "%r != %r" % (a, b) - - -def _assert_proper_exception_context(exception): - """assert that any exception we're catching does not have a __context__ - without a __cause__, and that __suppress_context__ is never set. - - Python 3 will report nested as exceptions as "during the handling of - error X, error Y occurred". That's not what we want to do. we want - these exceptions in a cause chain. - - """ - - if ( - exception.__context__ is not exception.__cause__ - and not exception.__suppress_context__ - ): - assert False, ( - "Exception %r was correctly raised but did not set a cause, " - "within context %r as its cause." - % (exception, exception.__context__) - ) - - -def assert_raises(except_cls, callable_, *args, **kw): - return _assert_raises(except_cls, callable_, args, kw, check_context=True) - - -def assert_raises_context_ok(except_cls, callable_, *args, **kw): - return _assert_raises(except_cls, callable_, args, kw) - - -def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): - return _assert_raises( - except_cls, callable_, args, kwargs, msg=msg, check_context=True - ) - - -def assert_warns(except_cls, callable_, *args, **kwargs): - """legacy adapter function for functions that were previously using - assert_raises with SAWarning or similar. - - has some workarounds to accommodate the fact that the callable completes - with this approach rather than stopping at the exception raise. - - - """ - with _expect_warnings_sqla_only(except_cls, [".*"]): - return callable_(*args, **kwargs) - - -def assert_warns_message(except_cls, msg, callable_, *args, **kwargs): - """legacy adapter function for functions that were previously using - assert_raises with SAWarning or similar. - - has some workarounds to accommodate the fact that the callable completes - with this approach rather than stopping at the exception raise. - - Also uses regex.search() to match the given message to the error string - rather than regex.match(). - - """ - with _expect_warnings_sqla_only( - except_cls, - [msg], - search_msg=True, - regex=False, - ): - return callable_(*args, **kwargs) - - -def assert_raises_message_context_ok( - except_cls, msg, callable_, *args, **kwargs -): - return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) - - -def _assert_raises( - except_cls, callable_, args, kwargs, msg=None, check_context=False -): - with _expect_raises(except_cls, msg, check_context) as ec: - callable_(*args, **kwargs) - return ec.error - - -class _ErrorContainer: - error = None - - -@contextlib.contextmanager -def _expect_raises(except_cls, msg=None, check_context=False): - if ( - isinstance(except_cls, type) - and issubclass(except_cls, Warning) - or isinstance(except_cls, Warning) - ): - raise TypeError( - "Use expect_warnings for warnings, not " - "expect_raises / assert_raises" - ) - ec = _ErrorContainer() - if check_context: - are_we_already_in_a_traceback = sys.exc_info()[0] - try: - yield ec - success = False - except except_cls as err: - ec.error = err - success = True - if msg is not None: - # I'm often pdbing here, and "err" above isn't - # in scope, so assign the string explicitly - error_as_string = str(err) - assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % ( - msg, - error_as_string, - ) - if check_context and not are_we_already_in_a_traceback: - _assert_proper_exception_context(err) - print(str(err).encode("utf-8")) - - # it's generally a good idea to not carry traceback objects outside - # of the except: block, but in this case especially we seem to have - # hit some bug in either python 3.10.0b2 or greenlet or both which - # this seems to fix: - # https://github.com/python-greenlet/greenlet/issues/242 - del ec - - # assert outside the block so it works for AssertionError too ! - assert success, "Callable did not raise an exception" - - -def expect_raises(except_cls, check_context=True): - return _expect_raises(except_cls, check_context=check_context) - - -def expect_raises_message(except_cls, msg, check_context=True): - return _expect_raises(except_cls, msg=msg, check_context=check_context) - - -class AssertsCompiledSQL: - def assert_compile( - self, - clause, - result, - params=None, - checkparams=None, - for_executemany=False, - check_literal_execute=None, - check_post_param=None, - dialect=None, - checkpositional=None, - check_prefetch=None, - use_default_dialect=False, - allow_dialect_select=False, - supports_default_values=True, - supports_default_metavalue=True, - literal_binds=False, - render_postcompile=False, - schema_translate_map=None, - render_schema_translate=False, - default_schema_name=None, - from_linting=False, - check_param_order=True, - use_literal_execute_for_simple_int=False, - ): - if use_default_dialect: - dialect = default.DefaultDialect() - dialect.supports_default_values = supports_default_values - dialect.supports_default_metavalue = supports_default_metavalue - elif allow_dialect_select: - dialect = None - else: - if dialect is None: - dialect = getattr(self, "__dialect__", None) - - if dialect is None: - dialect = config.db.dialect - elif dialect == "default" or dialect == "default_qmark": - if dialect == "default": - dialect = default.DefaultDialect() - else: - dialect = default.DefaultDialect("qmark") - dialect.supports_default_values = supports_default_values - dialect.supports_default_metavalue = supports_default_metavalue - elif dialect == "default_enhanced": - dialect = default.StrCompileDialect() - elif isinstance(dialect, str): - dialect = url.URL.create(dialect).get_dialect()() - - if default_schema_name: - dialect.default_schema_name = default_schema_name - - kw = {} - compile_kwargs = {} - - if schema_translate_map: - kw["schema_translate_map"] = schema_translate_map - - if params is not None: - kw["column_keys"] = list(params) - - if literal_binds: - compile_kwargs["literal_binds"] = True - - if render_postcompile: - compile_kwargs["render_postcompile"] = True - - if use_literal_execute_for_simple_int: - compile_kwargs["use_literal_execute_for_simple_int"] = True - - if for_executemany: - kw["for_executemany"] = True - - if render_schema_translate: - kw["render_schema_translate"] = True - - if from_linting or getattr(self, "assert_from_linting", False): - kw["linting"] = sql.FROM_LINTING - - from sqlalchemy import orm - - if isinstance(clause, orm.Query): - stmt = clause._statement_20() - stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL - clause = stmt - - if compile_kwargs: - kw["compile_kwargs"] = compile_kwargs - - class DontAccess: - def __getattribute__(self, key): - raise NotImplementedError( - "compiler accessed .statement; use " - "compiler.current_executable" - ) - - class CheckCompilerAccess: - def __init__(self, test_statement): - self.test_statement = test_statement - self._annotations = {} - self.supports_execution = getattr( - test_statement, "supports_execution", False - ) - - if self.supports_execution: - self._execution_options = test_statement._execution_options - - if hasattr(test_statement, "_returning"): - self._returning = test_statement._returning - if hasattr(test_statement, "_inline"): - self._inline = test_statement._inline - if hasattr(test_statement, "_return_defaults"): - self._return_defaults = test_statement._return_defaults - - @property - def _variant_mapping(self): - return self.test_statement._variant_mapping - - def _default_dialect(self): - return self.test_statement._default_dialect() - - def compile(self, dialect, **kw): - return self.test_statement.compile.__func__( - self, dialect=dialect, **kw - ) - - def _compiler(self, dialect, **kw): - return self.test_statement._compiler.__func__( - self, dialect, **kw - ) - - def _compiler_dispatch(self, compiler, **kwargs): - if hasattr(compiler, "statement"): - with mock.patch.object( - compiler, "statement", DontAccess() - ): - return self.test_statement._compiler_dispatch( - compiler, **kwargs - ) - else: - return self.test_statement._compiler_dispatch( - compiler, **kwargs - ) - - # no construct can assume it's the "top level" construct in all cases - # as anything can be nested. ensure constructs don't assume they - # are the "self.statement" element - c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) - - if isinstance(clause, sqltypes.TypeEngine): - cache_key_no_warnings = clause._static_cache_key - if cache_key_no_warnings: - hash(cache_key_no_warnings) - else: - cache_key_no_warnings = clause._generate_cache_key() - if cache_key_no_warnings: - hash(cache_key_no_warnings[0]) - - param_str = repr(getattr(c, "params", {})) - param_str = param_str.encode("utf-8").decode("ascii", "ignore") - print(("\nSQL String:\n" + str(c) + param_str).encode("utf-8")) - - cc = re.sub(r"[\n\t]", "", str(c)) - - eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) - - if checkparams is not None: - if render_postcompile: - expanded_state = c.construct_expanded_state( - params, escape_names=False - ) - eq_(expanded_state.parameters, checkparams) - else: - eq_(c.construct_params(params), checkparams) - if checkpositional is not None: - if render_postcompile: - expanded_state = c.construct_expanded_state( - params, escape_names=False - ) - eq_( - tuple( - [ - expanded_state.parameters[x] - for x in expanded_state.positiontup - ] - ), - checkpositional, - ) - else: - p = c.construct_params(params, escape_names=False) - eq_(tuple([p[x] for x in c.positiontup]), checkpositional) - if check_prefetch is not None: - eq_(c.prefetch, check_prefetch) - if check_literal_execute is not None: - eq_( - { - c.bind_names[b]: b.effective_value - for b in c.literal_execute_params - }, - check_literal_execute, - ) - if check_post_param is not None: - eq_( - { - c.bind_names[b]: b.effective_value - for b in c.post_compile_params - }, - check_post_param, - ) - if check_param_order and getattr(c, "params", None): - - def get_dialect(paramstyle, positional): - cp = copy(dialect) - cp.paramstyle = paramstyle - cp.positional = positional - return cp - - pyformat_dialect = get_dialect("pyformat", False) - pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) - stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) - - qmark_dialect = get_dialect("qmark", True) - qmark_c = clause.compile(dialect=qmark_dialect, **kw) - values = list(qmark_c.positiontup) - escaped = qmark_c.escaped_bind_names - - for post_param in ( - qmark_c.post_compile_params | qmark_c.literal_execute_params - ): - name = qmark_c.bind_names[post_param] - if name in values: - values = [v for v in values if v != name] - positions = [] - pos_by_value = defaultdict(list) - for v in values: - try: - if v in pos_by_value: - start = pos_by_value[v][-1] - else: - start = 0 - esc = escaped.get(v, v) - pos = stmt.index("%%(%s)s" % (esc,), start) + 2 - positions.append(pos) - pos_by_value[v].append(pos) - except ValueError: - msg = "Expected to find bindparam %r in %r" % (v, stmt) - assert False, msg - - ordered = all( - positions[i - 1] < positions[i] - for i in range(1, len(positions)) - ) - - expected = [v for _, v in sorted(zip(positions, values))] - - msg = ( - "Order of parameters %s does not match the order " - "in the statement %s. Statement %r" % (values, expected, stmt) - ) - - is_true(ordered, msg) - - -class ComparesTables: - def assert_tables_equal( - self, - table, - reflected_table, - strict_types=False, - strict_constraints=True, - ): - assert len(table.c) == len(reflected_table.c) - for c, reflected_c in zip(table.c, reflected_table.c): - eq_(c.name, reflected_c.name) - assert reflected_c is reflected_table.c[c.name] - - if strict_constraints: - eq_(c.primary_key, reflected_c.primary_key) - eq_(c.nullable, reflected_c.nullable) - - if strict_types: - msg = "Type '%s' doesn't correspond to type '%s'" - assert isinstance(reflected_c.type, type(c.type)), msg % ( - reflected_c.type, - c.type, - ) - else: - self.assert_types_base(reflected_c, c) - - if isinstance(c.type, sqltypes.String): - eq_(c.type.length, reflected_c.type.length) - - if strict_constraints: - eq_( - {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys}, - ) - if c.server_default: - assert isinstance( - reflected_c.server_default, schema.FetchedValue - ) - - if strict_constraints: - assert len(table.primary_key) == len(reflected_table.primary_key) - for c in table.primary_key: - assert reflected_table.primary_key.columns[c.name] is not None - - def assert_types_base(self, c1, c2): - assert c1.type._compare_type_affinity( - c2.type - ), "On column %r, type '%s' doesn't correspond to type '%s'" % ( - c1.name, - c1.type, - c2.type, - ) - - -class AssertsExecutionResults: - def assert_result(self, result, class_, *objects): - result = list(result) - print(repr(result)) - self.assert_list(result, class_, objects) - - def assert_list(self, result, class_, list_): - self.assert_( - len(result) == len(list_), - "result list is not the same size as test list, " - + "for class " - + class_.__name__, - ) - for i in range(0, len(list_)): - self.assert_row(class_, result[i], list_[i]) - - def assert_row(self, class_, rowobj, desc): - self.assert_( - rowobj.__class__ is class_, "item class is not " + repr(class_) - ) - for key, value in desc.items(): - if isinstance(value, tuple): - if isinstance(value[1], list): - self.assert_list(getattr(rowobj, key), value[0], value[1]) - else: - self.assert_row(value[0], getattr(rowobj, key), value[1]) - else: - self.assert_( - getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" - % (key, getattr(rowobj, key), value), - ) - - def assert_unordered_result(self, result, cls, *expected): - """As assert_result, but the order of objects is not considered. - - The algorithm is very expensive but not a big deal for the small - numbers of rows that the test suite manipulates. - """ - - class immutabledict(dict): - def __hash__(self): - return id(self) - - found = util.IdentitySet(result) - expected = {immutabledict(e) for e in expected} - - for wrong in filterfalse(lambda o: isinstance(o, cls), found): - fail( - 'Unexpected type "%s", expected "%s"' - % (type(wrong).__name__, cls.__name__) - ) - - if len(found) != len(expected): - fail( - 'Unexpected object count "%s", expected "%s"' - % (len(found), len(expected)) - ) - - NOVALUE = object() - - def _compare_item(obj, spec): - for key, value in spec.items(): - if isinstance(value, tuple): - try: - self.assert_unordered_result( - getattr(obj, key), value[0], *value[1] - ) - except AssertionError: - return False - else: - if getattr(obj, key, NOVALUE) != value: - return False - return True - - for expected_item in expected: - for found_item in found: - if _compare_item(found_item, expected_item): - found.remove(found_item) - break - else: - fail( - "Expected %s instance with attributes %s not found." - % (cls.__name__, repr(expected_item)) - ) - return True - - def sql_execution_asserter(self, db=None): - if db is None: - from . import db as db - - return assertsql.assert_engine(db) - - def assert_sql_execution(self, db, callable_, *rules): - with self.sql_execution_asserter(db) as asserter: - result = callable_() - asserter.assert_(*rules) - return result - - def assert_sql(self, db, callable_, rules): - newrules = [] - for rule in rules: - if isinstance(rule, dict): - newrule = assertsql.AllOf( - *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] - ) - else: - newrule = assertsql.CompiledSQL(*rule) - newrules.append(newrule) - - return self.assert_sql_execution(db, callable_, *newrules) - - def assert_sql_count(self, db, callable_, count): - return self.assert_sql_execution( - db, callable_, assertsql.CountStatements(count) - ) - - @contextlib.contextmanager - def assert_execution(self, db, *rules): - with self.sql_execution_asserter(db) as asserter: - yield - asserter.assert_(*rules) - - def assert_statement_count(self, db, count): - return self.assert_execution(db, assertsql.CountStatements(count)) - - @contextlib.contextmanager - def assert_statement_count_multi_db(self, dbs, counts): - recs = [ - (self.sql_execution_asserter(db), db, count) - for (db, count) in zip(dbs, counts) - ] - asserters = [] - for ctx, db, count in recs: - asserters.append(ctx.__enter__()) - try: - yield - finally: - for asserter, (ctx, db, count) in zip(asserters, recs): - ctx.__exit__(None, None, None) - asserter.assert_(assertsql.CountStatements(count)) - - -class ComparesIndexes: - def compare_table_index_with_expected( - self, table: schema.Table, expected: list, dialect_name: str - ): - eq_(len(table.indexes), len(expected)) - idx_dict = {idx.name: idx for idx in table.indexes} - for exp in expected: - idx = idx_dict[exp["name"]] - eq_(idx.unique, exp["unique"]) - cols = [c for c in exp["column_names"] if c is not None] - eq_(len(idx.columns), len(cols)) - for c in cols: - is_true(c in idx.columns) - exprs = exp.get("expressions") - if exprs: - eq_(len(idx.expressions), len(exprs)) - for idx_exp, expr, col in zip( - idx.expressions, exprs, exp["column_names"] - ): - if col is None: - eq_(idx_exp.text, expr) - if ( - exp.get("dialect_options") - and f"{dialect_name}_include" in exp["dialect_options"] - ): - eq_( - idx.dialect_options[dialect_name]["include"], - exp["dialect_options"][f"{dialect_name}_include"], - ) -- cgit v1.2.3