diff options
| author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
|---|---|---|
| committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
| commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
| tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py | |
| parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) | |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py')
| -rw-r--r-- | venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py | 472 | 
1 files changed, 472 insertions, 0 deletions
| diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py new file mode 100644 index 0000000..7cae807 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py @@ -0,0 +1,472 @@ +# testing/engines.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import collections +import re +import typing +from typing import Any +from typing import Dict +from typing import Optional +import warnings +import weakref + +from . import config +from .util import decorator +from .util import gc_collect +from .. import event +from .. import pool +from ..util import await_only +from ..util.typing import Literal + + +if typing.TYPE_CHECKING: +    from ..engine import Engine +    from ..engine.url import URL +    from ..ext.asyncio import AsyncEngine + + +class ConnectionKiller: +    def __init__(self): +        self.proxy_refs = weakref.WeakKeyDictionary() +        self.testing_engines = collections.defaultdict(set) +        self.dbapi_connections = set() + +    def add_pool(self, pool): +        event.listen(pool, "checkout", self._add_conn) +        event.listen(pool, "checkin", self._remove_conn) +        event.listen(pool, "close", self._remove_conn) +        event.listen(pool, "close_detached", self._remove_conn) +        # note we are keeping "invalidated" here, as those are still +        # opened connections we would like to roll back + +    def _add_conn(self, dbapi_con, con_record, con_proxy): +        self.dbapi_connections.add(dbapi_con) +        self.proxy_refs[con_proxy] = True + +    def _remove_conn(self, dbapi_conn, *arg): +        self.dbapi_connections.discard(dbapi_conn) + +    def add_engine(self, engine, scope): +        self.add_pool(engine.pool) + +        assert scope in ("class", "global", "function", "fixture") +        self.testing_engines[scope].add(engine) + +    def _safe(self, fn): +        try: +            fn() +        except Exception as e: +            warnings.warn( +                "testing_reaper couldn't rollback/close connection: %s" % e +            ) + +    def rollback_all(self): +        for rec in list(self.proxy_refs): +            if rec is not None and rec.is_valid: +                self._safe(rec.rollback) + +    def checkin_all(self): +        # run pool.checkin() for all ConnectionFairy instances we have +        # tracked. + +        for rec in list(self.proxy_refs): +            if rec is not None and rec.is_valid: +                self.dbapi_connections.discard(rec.dbapi_connection) +                self._safe(rec._checkin) + +        # for fairy refs that were GCed and could not close the connection, +        # such as asyncio, roll back those remaining connections +        for con in self.dbapi_connections: +            self._safe(con.rollback) +        self.dbapi_connections.clear() + +    def close_all(self): +        self.checkin_all() + +    def prepare_for_drop_tables(self, connection): +        # don't do aggressive checks for third party test suites +        if not config.bootstrapped_as_sqlalchemy: +            return + +        from . import provision + +        provision.prepare_for_drop_tables(connection.engine.url, connection) + +    def _drop_testing_engines(self, scope): +        eng = self.testing_engines[scope] +        for rec in list(eng): +            for proxy_ref in list(self.proxy_refs): +                if proxy_ref is not None and proxy_ref.is_valid: +                    if ( +                        proxy_ref._pool is not None +                        and proxy_ref._pool is rec.pool +                    ): +                        self._safe(proxy_ref._checkin) + +            if hasattr(rec, "sync_engine"): +                await_only(rec.dispose()) +            else: +                rec.dispose() +        eng.clear() + +    def after_test(self): +        self._drop_testing_engines("function") + +    def after_test_outside_fixtures(self, test): +        # don't do aggressive checks for third party test suites +        if not config.bootstrapped_as_sqlalchemy: +            return + +        if test.__class__.__leave_connections_for_teardown__: +            return + +        self.checkin_all() + +        # on PostgreSQL, this will test for any "idle in transaction" +        # connections.   useful to identify tests with unusual patterns +        # that can't be cleaned up correctly. +        from . import provision + +        with config.db.connect() as conn: +            provision.prepare_for_drop_tables(conn.engine.url, conn) + +    def stop_test_class_inside_fixtures(self): +        self.checkin_all() +        self._drop_testing_engines("function") +        self._drop_testing_engines("class") + +    def stop_test_class_outside_fixtures(self): +        # ensure no refs to checked out connections at all. + +        if pool.base._strong_ref_connection_records: +            gc_collect() + +            if pool.base._strong_ref_connection_records: +                ln = len(pool.base._strong_ref_connection_records) +                pool.base._strong_ref_connection_records.clear() +                assert ( +                    False +                ), "%d connection recs not cleared after test suite" % (ln) + +    def final_cleanup(self): +        self.checkin_all() +        for scope in self.testing_engines: +            self._drop_testing_engines(scope) + +    def assert_all_closed(self): +        for rec in self.proxy_refs: +            if rec.is_valid: +                assert False + + +testing_reaper = ConnectionKiller() + + +@decorator +def assert_conns_closed(fn, *args, **kw): +    try: +        fn(*args, **kw) +    finally: +        testing_reaper.assert_all_closed() + + +@decorator +def rollback_open_connections(fn, *args, **kw): +    """Decorator that rolls back all open connections after fn execution.""" + +    try: +        fn(*args, **kw) +    finally: +        testing_reaper.rollback_all() + + +@decorator +def close_first(fn, *args, **kw): +    """Decorator that closes all connections before fn execution.""" + +    testing_reaper.checkin_all() +    fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): +    """Decorator that closes all connections after fn execution.""" +    try: +        fn(*args, **kw) +    finally: +        testing_reaper.checkin_all() + + +def all_dialects(exclude=None): +    import sqlalchemy.dialects as d + +    for name in d.__all__: +        # TEMPORARY +        if exclude and name in exclude: +            continue +        mod = getattr(d, name, None) +        if not mod: +            mod = getattr( +                __import__("sqlalchemy.dialects.%s" % name).dialects, name +            ) +        yield mod.dialect() + + +class ReconnectFixture: +    def __init__(self, dbapi): +        self.dbapi = dbapi +        self.connections = [] +        self.is_stopped = False + +    def __getattr__(self, key): +        return getattr(self.dbapi, key) + +    def connect(self, *args, **kwargs): +        conn = self.dbapi.connect(*args, **kwargs) +        if self.is_stopped: +            self._safe(conn.close) +            curs = conn.cursor()  # should fail on Oracle etc. +            # should fail for everything that didn't fail +            # above, connection is closed +            curs.execute("select 1") +            assert False, "simulated connect failure didn't work" +        else: +            self.connections.append(conn) +            return conn + +    def _safe(self, fn): +        try: +            fn() +        except Exception as e: +            warnings.warn("ReconnectFixture couldn't close connection: %s" % e) + +    def shutdown(self, stop=False): +        # TODO: this doesn't cover all cases +        # as nicely as we'd like, namely MySQLdb. +        # would need to implement R. Brewer's +        # proxy server idea to get better +        # coverage. +        self.is_stopped = stop +        for c in list(self.connections): +            self._safe(c.close) +        self.connections = [] + +    def restart(self): +        self.is_stopped = False + + +def reconnecting_engine(url=None, options=None): +    url = url or config.db.url +    dbapi = config.db.dialect.dbapi +    if not options: +        options = {} +    options["module"] = ReconnectFixture(dbapi) +    engine = testing_engine(url, options) +    _dispose = engine.dispose + +    def dispose(): +        engine.dialect.dbapi.shutdown() +        engine.dialect.dbapi.is_stopped = False +        _dispose() + +    engine.test_shutdown = engine.dialect.dbapi.shutdown +    engine.test_restart = engine.dialect.dbapi.restart +    engine.dispose = dispose +    return engine + + +@typing.overload +def testing_engine( +    url: Optional[URL] = None, +    options: Optional[Dict[str, Any]] = None, +    asyncio: Literal[False] = False, +    transfer_staticpool: bool = False, +) -> Engine: ... + + +@typing.overload +def testing_engine( +    url: Optional[URL] = None, +    options: Optional[Dict[str, Any]] = None, +    asyncio: Literal[True] = True, +    transfer_staticpool: bool = False, +) -> AsyncEngine: ... + + +def testing_engine( +    url=None, +    options=None, +    asyncio=False, +    transfer_staticpool=False, +    share_pool=False, +    _sqlite_savepoint=False, +): +    if asyncio: +        assert not _sqlite_savepoint +        from sqlalchemy.ext.asyncio import ( +            create_async_engine as create_engine, +        ) +    else: +        from sqlalchemy import create_engine +    from sqlalchemy.engine.url import make_url + +    if not options: +        use_reaper = True +        scope = "function" +        sqlite_savepoint = False +    else: +        use_reaper = options.pop("use_reaper", True) +        scope = options.pop("scope", "function") +        sqlite_savepoint = options.pop("sqlite_savepoint", False) + +    url = url or config.db.url + +    url = make_url(url) +    if options is None: +        if config.db is None or url.drivername == config.db.url.drivername: +            options = config.db_opts +        else: +            options = {} +    elif config.db is not None and url.drivername == config.db.url.drivername: +        default_opt = config.db_opts.copy() +        default_opt.update(options) + +    engine = create_engine(url, **options) + +    if sqlite_savepoint and engine.name == "sqlite": +        # apply SQLite savepoint workaround +        @event.listens_for(engine, "connect") +        def do_connect(dbapi_connection, connection_record): +            dbapi_connection.isolation_level = None + +        @event.listens_for(engine, "begin") +        def do_begin(conn): +            conn.exec_driver_sql("BEGIN") + +    if transfer_staticpool: +        from sqlalchemy.pool import StaticPool + +        if config.db is not None and isinstance(config.db.pool, StaticPool): +            use_reaper = False +            engine.pool._transfer_from(config.db.pool) +    elif share_pool: +        engine.pool = config.db.pool + +    if scope == "global": +        if asyncio: +            engine.sync_engine._has_events = True +        else: +            engine._has_events = ( +                True  # enable event blocks, helps with profiling +            ) + +    if ( +        isinstance(engine.pool, pool.QueuePool) +        and "pool" not in options +        and "pool_timeout" not in options +        and "max_overflow" not in options +    ): +        engine.pool._timeout = 0 +        engine.pool._max_overflow = 0 +    if use_reaper: +        testing_reaper.add_engine(engine, scope) + +    return engine + + +def mock_engine(dialect_name=None): +    """Provides a mocking engine based on the current testing.db. + +    This is normally used to test DDL generation flow as emitted +    by an Engine. + +    It should not be used in other cases, as assert_compile() and +    assert_sql_execution() are much better choices with fewer +    moving parts. + +    """ + +    from sqlalchemy import create_mock_engine + +    if not dialect_name: +        dialect_name = config.db.name + +    buffer = [] + +    def executor(sql, *a, **kw): +        buffer.append(sql) + +    def assert_sql(stmts): +        recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] +        assert recv == stmts, recv + +    def print_sql(): +        d = engine.dialect +        return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + +    engine = create_mock_engine(dialect_name + "://", executor) +    assert not hasattr(engine, "mock") +    engine.mock = buffer +    engine.assert_sql = assert_sql +    engine.print_sql = print_sql +    return engine + + +class DBAPIProxyCursor: +    """Proxy a DBAPI cursor. + +    Tests can provide subclasses of this to intercept +    DBAPI-level cursor operations. + +    """ + +    def __init__(self, engine, conn, *args, **kwargs): +        self.engine = engine +        self.connection = conn +        self.cursor = conn.cursor(*args, **kwargs) + +    def execute(self, stmt, parameters=None, **kw): +        if parameters: +            return self.cursor.execute(stmt, parameters, **kw) +        else: +            return self.cursor.execute(stmt, **kw) + +    def executemany(self, stmt, params, **kw): +        return self.cursor.executemany(stmt, params, **kw) + +    def __iter__(self): +        return iter(self.cursor) + +    def __getattr__(self, key): +        return getattr(self.cursor, key) + + +class DBAPIProxyConnection: +    """Proxy a DBAPI connection. + +    Tests can provide subclasses of this to intercept +    DBAPI-level connection operations. + +    """ + +    def __init__(self, engine, conn, cursor_cls): +        self.conn = conn +        self.engine = engine +        self.cursor_cls = cursor_cls + +    def cursor(self, *args, **kwargs): +        return self.cursor_cls(self.engine, self.conn, *args, **kwargs) + +    def close(self): +        self.conn.close() + +    def __getattr__(self, key): +        return getattr(self.conn, key) | 
