From 12cf076118570eebbff08c6b3090e0d4798447a1 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:17:55 -0400 Subject: no venv --- .../site-packages/sqlalchemy/testing/assertsql.py | 516 --------------------- 1 file changed, 516 deletions(-) delete mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py') diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py deleted file mode 100644 index ae4d335..0000000 --- a/venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py +++ /dev/null @@ -1,516 +0,0 @@ -# testing/assertsql.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - - -from __future__ import annotations - -import collections -import contextlib -import itertools -import re - -from .. import event -from ..engine import url -from ..engine.default import DefaultDialect -from ..schema import BaseDDLElement - - -class AssertRule: - is_consumed = False - errormessage = None - consume_statement = True - - def process_statement(self, execute_observed): - pass - - def no_more_statements(self): - assert False, ( - "All statements are complete, but pending " - "assertion rules remain" - ) - - -class SQLMatchRule(AssertRule): - pass - - -class CursorSQL(SQLMatchRule): - def __init__(self, statement, params=None, consume_statement=True): - self.statement = statement - self.params = params - self.consume_statement = consume_statement - - def process_statement(self, execute_observed): - stmt = execute_observed.statements[0] - if self.statement != stmt.statement or ( - self.params is not None and self.params != stmt.parameters - ): - self.consume_statement = True - self.errormessage = ( - "Testing for exact SQL %s parameters %s received %s %s" - % ( - self.statement, - self.params, - stmt.statement, - stmt.parameters, - ) - ) - else: - execute_observed.statements.pop(0) - self.is_consumed = True - if not execute_observed.statements: - self.consume_statement = True - - -class CompiledSQL(SQLMatchRule): - def __init__( - self, statement, params=None, dialect="default", enable_returning=True - ): - self.statement = statement - self.params = params - self.dialect = dialect - self.enable_returning = enable_returning - - def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r"[\n\t]", "", self.statement) - return received_statement == stmt - - def _compile_dialect(self, execute_observed): - if self.dialect == "default": - dialect = DefaultDialect() - # this is currently what tests are expecting - # dialect.supports_default_values = True - dialect.supports_default_metavalue = True - - if self.enable_returning: - dialect.insert_returning = dialect.update_returning = ( - dialect.delete_returning - ) = True - dialect.use_insertmanyvalues = True - dialect.supports_multivalues_insert = True - dialect.update_returning_multifrom = True - dialect.delete_returning_multifrom = True - # dialect.favor_returning_over_lastrowid = True - # dialect.insert_null_pk_still_autoincrements = True - - # this is calculated but we need it to be True for this - # to look like all the current RETURNING dialects - assert dialect.insert_executemany_returning - - return dialect - else: - return url.URL.create(self.dialect).get_dialect()() - - def _received_statement(self, execute_observed): - """reconstruct the statement and params in terms - of a target dialect, which for CompiledSQL is just DefaultDialect.""" - - context = execute_observed.context - compare_dialect = self._compile_dialect(execute_observed) - - # received_statement runs a full compile(). we should not need to - # consider extracted_parameters; if we do this indicates some state - # is being sent from a previous cached query, which some misbehaviors - # in the ORM can cause, see #6881 - cache_key = None # execute_observed.context.compiled.cache_key - extracted_parameters = ( - None # execute_observed.context.extracted_parameters - ) - - if "schema_translate_map" in context.execution_options: - map_ = context.execution_options["schema_translate_map"] - else: - map_ = None - - if isinstance(execute_observed.clauseelement, BaseDDLElement): - compiled = execute_observed.clauseelement.compile( - dialect=compare_dialect, - schema_translate_map=map_, - ) - else: - compiled = execute_observed.clauseelement.compile( - cache_key=cache_key, - dialect=compare_dialect, - column_keys=context.compiled.column_keys, - for_executemany=context.compiled.for_executemany, - schema_translate_map=map_, - ) - _received_statement = re.sub(r"[\n\t]", "", str(compiled)) - parameters = execute_observed.parameters - - if not parameters: - _received_parameters = [ - compiled.construct_params( - extracted_parameters=extracted_parameters - ) - ] - else: - _received_parameters = [ - compiled.construct_params( - m, extracted_parameters=extracted_parameters - ) - for m in parameters - ] - - return _received_statement, _received_parameters - - def process_statement(self, execute_observed): - context = execute_observed.context - - _received_statement, _received_parameters = self._received_statement( - execute_observed - ) - params = self._all_params(context) - - equivalent = self._compare_sql(execute_observed, _received_statement) - - if equivalent: - if params is not None: - all_params = list(params) - all_received = list(_received_parameters) - while all_params and all_received: - param = dict(all_params.pop(0)) - - for idx, received in enumerate(list(all_received)): - # do a positive compare only - for param_key in param: - # a key in param did not match current - # 'received' - if ( - param_key not in received - or received[param_key] != param[param_key] - ): - break - else: - # all keys in param matched 'received'; - # onto next param - del all_received[idx] - break - else: - # param did not match any entry - # in all_received - equivalent = False - break - if all_params or all_received: - equivalent = False - - if equivalent: - self.is_consumed = True - self.errormessage = None - else: - self.errormessage = self._failure_message( - execute_observed, params - ) % { - "received_statement": _received_statement, - "received_parameters": _received_parameters, - } - - def _all_params(self, context): - if self.params: - if callable(self.params): - params = self.params(context) - else: - params = self.params - if not isinstance(params, list): - params = [params] - return params - else: - return None - - def _failure_message(self, execute_observed, expected_params): - return ( - "Testing for compiled statement\n%r partial params %s, " - "received\n%%(received_statement)r with params " - "%%(received_parameters)r" - % ( - self.statement.replace("%", "%%"), - repr(expected_params).replace("%", "%%"), - ) - ) - - -class RegexSQL(CompiledSQL): - def __init__( - self, regex, params=None, dialect="default", enable_returning=False - ): - SQLMatchRule.__init__(self) - self.regex = re.compile(regex) - self.orig_regex = regex - self.params = params - self.dialect = dialect - self.enable_returning = enable_returning - - def _failure_message(self, execute_observed, expected_params): - return ( - "Testing for compiled statement ~%r partial params %s, " - "received %%(received_statement)r with params " - "%%(received_parameters)r" - % ( - self.orig_regex.replace("%", "%%"), - repr(expected_params).replace("%", "%%"), - ) - ) - - def _compare_sql(self, execute_observed, received_statement): - return bool(self.regex.match(received_statement)) - - -class DialectSQL(CompiledSQL): - def _compile_dialect(self, execute_observed): - return execute_observed.context.dialect - - def _compare_no_space(self, real_stmt, received_stmt): - stmt = re.sub(r"[\n\t]", "", real_stmt) - return received_stmt == stmt - - def _received_statement(self, execute_observed): - received_stmt, received_params = super()._received_statement( - execute_observed - ) - - # TODO: why do we need this part? - for real_stmt in execute_observed.statements: - if self._compare_no_space(real_stmt.statement, received_stmt): - break - else: - raise AssertionError( - "Can't locate compiled statement %r in list of " - "statements actually invoked" % received_stmt - ) - - return received_stmt, execute_observed.context.compiled_parameters - - def _dialect_adjusted_statement(self, dialect): - paramstyle = dialect.paramstyle - stmt = re.sub(r"[\n\t]", "", self.statement) - - # temporarily escape out PG double colons - stmt = stmt.replace("::", "!!") - - if paramstyle == "pyformat": - stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) - else: - # positional params - repl = None - if paramstyle == "qmark": - repl = "?" - elif paramstyle == "format": - repl = r"%s" - elif paramstyle.startswith("numeric"): - counter = itertools.count(1) - - num_identifier = "$" if paramstyle == "numeric_dollar" else ":" - - def repl(m): - return f"{num_identifier}{next(counter)}" - - stmt = re.sub(r":([\w_]+)", repl, stmt) - - # put them back - stmt = stmt.replace("!!", "::") - - return stmt - - def _compare_sql(self, execute_observed, received_statement): - stmt = self._dialect_adjusted_statement( - execute_observed.context.dialect - ) - return received_statement == stmt - - def _failure_message(self, execute_observed, expected_params): - return ( - "Testing for compiled statement\n%r partial params %s, " - "received\n%%(received_statement)r with params " - "%%(received_parameters)r" - % ( - self._dialect_adjusted_statement( - execute_observed.context.dialect - ).replace("%", "%%"), - repr(expected_params).replace("%", "%%"), - ) - ) - - -class CountStatements(AssertRule): - def __init__(self, count): - self.count = count - self._statement_count = 0 - - def process_statement(self, execute_observed): - self._statement_count += 1 - - def no_more_statements(self): - if self.count != self._statement_count: - assert False, "desired statement count %d does not match %d" % ( - self.count, - self._statement_count, - ) - - -class AllOf(AssertRule): - def __init__(self, *rules): - self.rules = set(rules) - - def process_statement(self, execute_observed): - for rule in list(self.rules): - rule.errormessage = None - rule.process_statement(execute_observed) - if rule.is_consumed: - self.rules.discard(rule) - if not self.rules: - self.is_consumed = True - break - elif not rule.errormessage: - # rule is not done yet - self.errormessage = None - break - else: - self.errormessage = list(self.rules)[0].errormessage - - -class EachOf(AssertRule): - def __init__(self, *rules): - self.rules = list(rules) - - def process_statement(self, execute_observed): - if not self.rules: - self.is_consumed = True - self.consume_statement = False - - while self.rules: - rule = self.rules[0] - rule.process_statement(execute_observed) - if rule.is_consumed: - self.rules.pop(0) - elif rule.errormessage: - self.errormessage = rule.errormessage - if rule.consume_statement: - break - - if not self.rules: - self.is_consumed = True - - def no_more_statements(self): - if self.rules and not self.rules[0].is_consumed: - self.rules[0].no_more_statements() - elif self.rules: - super().no_more_statements() - - -class Conditional(EachOf): - def __init__(self, condition, rules, else_rules): - if condition: - super().__init__(*rules) - else: - super().__init__(*else_rules) - - -class Or(AllOf): - def process_statement(self, execute_observed): - for rule in self.rules: - rule.process_statement(execute_observed) - if rule.is_consumed: - self.is_consumed = True - break - else: - self.errormessage = list(self.rules)[0].errormessage - - -class SQLExecuteObserved: - def __init__(self, context, clauseelement, multiparams, params): - self.context = context - self.clauseelement = clauseelement - - if multiparams: - self.parameters = multiparams - elif params: - self.parameters = [params] - else: - self.parameters = [] - self.statements = [] - - def __repr__(self): - return str(self.statements) - - -class SQLCursorExecuteObserved( - collections.namedtuple( - "SQLCursorExecuteObserved", - ["statement", "parameters", "context", "executemany"], - ) -): - pass - - -class SQLAsserter: - def __init__(self): - self.accumulated = [] - - def _close(self): - self._final = self.accumulated - del self.accumulated - - def assert_(self, *rules): - rule = EachOf(*rules) - - observed = list(self._final) - while observed: - statement = observed.pop(0) - rule.process_statement(statement) - if rule.is_consumed: - break - elif rule.errormessage: - assert False, rule.errormessage - if observed: - assert False, "Additional SQL statements remain:\n%s" % observed - elif not rule.is_consumed: - rule.no_more_statements() - - -@contextlib.contextmanager -def assert_engine(engine): - asserter = SQLAsserter() - - orig = [] - - @event.listens_for(engine, "before_execute") - def connection_execute( - conn, clauseelement, multiparams, params, execution_options - ): - # grab the original statement + params before any cursor - # execution - orig[:] = clauseelement, multiparams, params - - @event.listens_for(engine, "after_cursor_execute") - def cursor_execute( - conn, cursor, statement, parameters, context, executemany - ): - if not context: - return - # then grab real cursor statements and associate them all - # around a single context - if ( - asserter.accumulated - and asserter.accumulated[-1].context is context - ): - obs = asserter.accumulated[-1] - else: - obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) - asserter.accumulated.append(obs) - obs.statements.append( - SQLCursorExecuteObserved( - statement, parameters, context, executemany - ) - ) - - try: - yield asserter - finally: - event.remove(engine, "after_cursor_execute", cursor_execute) - event.remove(engine, "before_execute", connection_execute) - asserter._close() -- cgit v1.2.3