summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py')
-rw-r--r--venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py516
1 files changed, 0 insertions, 516 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py
deleted file mode 100644
index ae4d335..0000000
--- a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py
+++ /dev/null
@@ -1,516 +0,0 @@
-# testing/assertsql.py
-# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
-
-from __future__ import annotations
-
-import collections
-import contextlib
-import itertools
-import re
-
-from .. import event
-from ..engine import url
-from ..engine.default import DefaultDialect
-from ..schema import BaseDDLElement
-
-
-class AssertRule:
- is_consumed = False
- errormessage = None
- consume_statement = True
-
- def process_statement(self, execute_observed):
- pass
-
- def no_more_statements(self):
- assert False, (
- "All statements are complete, but pending "
- "assertion rules remain"
- )
-
-
-class SQLMatchRule(AssertRule):
- pass
-
-
-class CursorSQL(SQLMatchRule):
- def __init__(self, statement, params=None, consume_statement=True):
- self.statement = statement
- self.params = params
- self.consume_statement = consume_statement
-
- def process_statement(self, execute_observed):
- stmt = execute_observed.statements[0]
- if self.statement != stmt.statement or (
- self.params is not None and self.params != stmt.parameters
- ):
- self.consume_statement = True
- self.errormessage = (
- "Testing for exact SQL %s parameters %s received %s %s"
- % (
- self.statement,
- self.params,
- stmt.statement,
- stmt.parameters,
- )
- )
- else:
- execute_observed.statements.pop(0)
- self.is_consumed = True
- if not execute_observed.statements:
- self.consume_statement = True
-
-
-class CompiledSQL(SQLMatchRule):
- def __init__(
- self, statement, params=None, dialect="default", enable_returning=True
- ):
- self.statement = statement
- self.params = params
- self.dialect = dialect
- self.enable_returning = enable_returning
-
- def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r"[\n\t]", "", self.statement)
- return received_statement == stmt
-
- def _compile_dialect(self, execute_observed):
- if self.dialect == "default":
- dialect = DefaultDialect()
- # this is currently what tests are expecting
- # dialect.supports_default_values = True
- dialect.supports_default_metavalue = True
-
- if self.enable_returning:
- dialect.insert_returning = dialect.update_returning = (
- dialect.delete_returning
- ) = True
- dialect.use_insertmanyvalues = True
- dialect.supports_multivalues_insert = True
- dialect.update_returning_multifrom = True
- dialect.delete_returning_multifrom = True
- # dialect.favor_returning_over_lastrowid = True
- # dialect.insert_null_pk_still_autoincrements = True
-
- # this is calculated but we need it to be True for this
- # to look like all the current RETURNING dialects
- assert dialect.insert_executemany_returning
-
- return dialect
- else:
- return url.URL.create(self.dialect).get_dialect()()
-
- def _received_statement(self, execute_observed):
- """reconstruct the statement and params in terms
- of a target dialect, which for CompiledSQL is just DefaultDialect."""
-
- context = execute_observed.context
- compare_dialect = self._compile_dialect(execute_observed)
-
- # received_statement runs a full compile(). we should not need to
- # consider extracted_parameters; if we do this indicates some state
- # is being sent from a previous cached query, which some misbehaviors
- # in the ORM can cause, see #6881
- cache_key = None # execute_observed.context.compiled.cache_key
- extracted_parameters = (
- None # execute_observed.context.extracted_parameters
- )
-
- if "schema_translate_map" in context.execution_options:
- map_ = context.execution_options["schema_translate_map"]
- else:
- map_ = None
-
- if isinstance(execute_observed.clauseelement, BaseDDLElement):
- compiled = execute_observed.clauseelement.compile(
- dialect=compare_dialect,
- schema_translate_map=map_,
- )
- else:
- compiled = execute_observed.clauseelement.compile(
- cache_key=cache_key,
- dialect=compare_dialect,
- column_keys=context.compiled.column_keys,
- for_executemany=context.compiled.for_executemany,
- schema_translate_map=map_,
- )
- _received_statement = re.sub(r"[\n\t]", "", str(compiled))
- parameters = execute_observed.parameters
-
- if not parameters:
- _received_parameters = [
- compiled.construct_params(
- extracted_parameters=extracted_parameters
- )
- ]
- else:
- _received_parameters = [
- compiled.construct_params(
- m, extracted_parameters=extracted_parameters
- )
- for m in parameters
- ]
-
- return _received_statement, _received_parameters
-
- def process_statement(self, execute_observed):
- context = execute_observed.context
-
- _received_statement, _received_parameters = self._received_statement(
- execute_observed
- )
- params = self._all_params(context)
-
- equivalent = self._compare_sql(execute_observed, _received_statement)
-
- if equivalent:
- if params is not None:
- all_params = list(params)
- all_received = list(_received_parameters)
- while all_params and all_received:
- param = dict(all_params.pop(0))
-
- for idx, received in enumerate(list(all_received)):
- # do a positive compare only
- for param_key in param:
- # a key in param did not match current
- # 'received'
- if (
- param_key not in received
- or received[param_key] != param[param_key]
- ):
- break
- else:
- # all keys in param matched 'received';
- # onto next param
- del all_received[idx]
- break
- else:
- # param did not match any entry
- # in all_received
- equivalent = False
- break
- if all_params or all_received:
- equivalent = False
-
- if equivalent:
- self.is_consumed = True
- self.errormessage = None
- else:
- self.errormessage = self._failure_message(
- execute_observed, params
- ) % {
- "received_statement": _received_statement,
- "received_parameters": _received_parameters,
- }
-
- def _all_params(self, context):
- if self.params:
- if callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
- return params
- else:
- return None
-
- def _failure_message(self, execute_observed, expected_params):
- return (
- "Testing for compiled statement\n%r partial params %s, "
- "received\n%%(received_statement)r with params "
- "%%(received_parameters)r"
- % (
- self.statement.replace("%", "%%"),
- repr(expected_params).replace("%", "%%"),
- )
- )
-
-
-class RegexSQL(CompiledSQL):
- def __init__(
- self, regex, params=None, dialect="default", enable_returning=False
- ):
- SQLMatchRule.__init__(self)
- self.regex = re.compile(regex)
- self.orig_regex = regex
- self.params = params
- self.dialect = dialect
- self.enable_returning = enable_returning
-
- def _failure_message(self, execute_observed, expected_params):
- return (
- "Testing for compiled statement ~%r partial params %s, "
- "received %%(received_statement)r with params "
- "%%(received_parameters)r"
- % (
- self.orig_regex.replace("%", "%%"),
- repr(expected_params).replace("%", "%%"),
- )
- )
-
- def _compare_sql(self, execute_observed, received_statement):
- return bool(self.regex.match(received_statement))
-
-
-class DialectSQL(CompiledSQL):
- def _compile_dialect(self, execute_observed):
- return execute_observed.context.dialect
-
- def _compare_no_space(self, real_stmt, received_stmt):
- stmt = re.sub(r"[\n\t]", "", real_stmt)
- return received_stmt == stmt
-
- def _received_statement(self, execute_observed):
- received_stmt, received_params = super()._received_statement(
- execute_observed
- )
-
- # TODO: why do we need this part?
- for real_stmt in execute_observed.statements:
- if self._compare_no_space(real_stmt.statement, received_stmt):
- break
- else:
- raise AssertionError(
- "Can't locate compiled statement %r in list of "
- "statements actually invoked" % received_stmt
- )
-
- return received_stmt, execute_observed.context.compiled_parameters
-
- def _dialect_adjusted_statement(self, dialect):
- paramstyle = dialect.paramstyle
- stmt = re.sub(r"[\n\t]", "", self.statement)
-
- # temporarily escape out PG double colons
- stmt = stmt.replace("::", "!!")
-
- if paramstyle == "pyformat":
- stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
- else:
- # positional params
- repl = None
- if paramstyle == "qmark":
- repl = "?"
- elif paramstyle == "format":
- repl = r"%s"
- elif paramstyle.startswith("numeric"):
- counter = itertools.count(1)
-
- num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
-
- def repl(m):
- return f"{num_identifier}{next(counter)}"
-
- stmt = re.sub(r":([\w_]+)", repl, stmt)
-
- # put them back
- stmt = stmt.replace("!!", "::")
-
- return stmt
-
- def _compare_sql(self, execute_observed, received_statement):
- stmt = self._dialect_adjusted_statement(
- execute_observed.context.dialect
- )
- return received_statement == stmt
-
- def _failure_message(self, execute_observed, expected_params):
- return (
- "Testing for compiled statement\n%r partial params %s, "
- "received\n%%(received_statement)r with params "
- "%%(received_parameters)r"
- % (
- self._dialect_adjusted_statement(
- execute_observed.context.dialect
- ).replace("%", "%%"),
- repr(expected_params).replace("%", "%%"),
- )
- )
-
-
-class CountStatements(AssertRule):
- def __init__(self, count):
- self.count = count
- self._statement_count = 0
-
- def process_statement(self, execute_observed):
- self._statement_count += 1
-
- def no_more_statements(self):
- if self.count != self._statement_count:
- assert False, "desired statement count %d does not match %d" % (
- self.count,
- self._statement_count,
- )
-
-
-class AllOf(AssertRule):
- def __init__(self, *rules):
- self.rules = set(rules)
-
- def process_statement(self, execute_observed):
- for rule in list(self.rules):
- rule.errormessage = None
- rule.process_statement(execute_observed)
- if rule.is_consumed:
- self.rules.discard(rule)
- if not self.rules:
- self.is_consumed = True
- break
- elif not rule.errormessage:
- # rule is not done yet
- self.errormessage = None
- break
- else:
- self.errormessage = list(self.rules)[0].errormessage
-
-
-class EachOf(AssertRule):
- def __init__(self, *rules):
- self.rules = list(rules)
-
- def process_statement(self, execute_observed):
- if not self.rules:
- self.is_consumed = True
- self.consume_statement = False
-
- while self.rules:
- rule = self.rules[0]
- rule.process_statement(execute_observed)
- if rule.is_consumed:
- self.rules.pop(0)
- elif rule.errormessage:
- self.errormessage = rule.errormessage
- if rule.consume_statement:
- break
-
- if not self.rules:
- self.is_consumed = True
-
- def no_more_statements(self):
- if self.rules and not self.rules[0].is_consumed:
- self.rules[0].no_more_statements()
- elif self.rules:
- super().no_more_statements()
-
-
-class Conditional(EachOf):
- def __init__(self, condition, rules, else_rules):
- if condition:
- super().__init__(*rules)
- else:
- super().__init__(*else_rules)
-
-
-class Or(AllOf):
- def process_statement(self, execute_observed):
- for rule in self.rules:
- rule.process_statement(execute_observed)
- if rule.is_consumed:
- self.is_consumed = True
- break
- else:
- self.errormessage = list(self.rules)[0].errormessage
-
-
-class SQLExecuteObserved:
- def __init__(self, context, clauseelement, multiparams, params):
- self.context = context
- self.clauseelement = clauseelement
-
- if multiparams:
- self.parameters = multiparams
- elif params:
- self.parameters = [params]
- else:
- self.parameters = []
- self.statements = []
-
- def __repr__(self):
- return str(self.statements)
-
-
-class SQLCursorExecuteObserved(
- collections.namedtuple(
- "SQLCursorExecuteObserved",
- ["statement", "parameters", "context", "executemany"],
- )
-):
- pass
-
-
-class SQLAsserter:
- def __init__(self):
- self.accumulated = []
-
- def _close(self):
- self._final = self.accumulated
- del self.accumulated
-
- def assert_(self, *rules):
- rule = EachOf(*rules)
-
- observed = list(self._final)
- while observed:
- statement = observed.pop(0)
- rule.process_statement(statement)
- if rule.is_consumed:
- break
- elif rule.errormessage:
- assert False, rule.errormessage
- if observed:
- assert False, "Additional SQL statements remain:\n%s" % observed
- elif not rule.is_consumed:
- rule.no_more_statements()
-
-
-@contextlib.contextmanager
-def assert_engine(engine):
- asserter = SQLAsserter()
-
- orig = []
-
- @event.listens_for(engine, "before_execute")
- def connection_execute(
- conn, clauseelement, multiparams, params, execution_options
- ):
- # grab the original statement + params before any cursor
- # execution
- orig[:] = clauseelement, multiparams, params
-
- @event.listens_for(engine, "after_cursor_execute")
- def cursor_execute(
- conn, cursor, statement, parameters, context, executemany
- ):
- if not context:
- return
- # then grab real cursor statements and associate them all
- # around a single context
- if (
- asserter.accumulated
- and asserter.accumulated[-1].context is context
- ):
- obs = asserter.accumulated[-1]
- else:
- obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
- asserter.accumulated.append(obs)
- obs.statements.append(
- SQLCursorExecuteObserved(
- statement, parameters, context, executemany
- )
- )
-
- try:
- yield asserter
- finally:
- event.remove(engine, "after_cursor_execute", cursor_execute)
- event.remove(engine, "before_execute", connection_execute)
- asserter._close()