diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/testing/suite | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/suite')
28 files changed, 10600 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 0000000..a146cb3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py @@ -0,0 +1,19 @@ +# testing/suite/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from .test_cte import * # noqa +from .test_ddl import * # noqa +from .test_deprecations import * # noqa +from .test_dialect import * # noqa +from .test_insert import * # noqa +from .test_reflection import * # noqa +from .test_results import * # noqa +from .test_rowcount import * # noqa +from .test_select import * # noqa +from .test_sequence import * # noqa +from .test_types import * # noqa +from .test_unicode_ddl import * # noqa +from .test_update_delete import * # noqa diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b8ae9f7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6bf56dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..99ddd6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a114600 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..325c9a1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c212fa4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..858574b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..624d1f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4ebccba --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5edde44 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..576fdcf --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..79cb268 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1c8d3ff --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1be9d55 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py new file mode 100644 index 0000000..5d37880 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py @@ -0,0 +1,211 @@ +# testing/suite/test_cte.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import ForeignKey +from ... import Integer +from ... import select +from ... import String +from ... import testing + + +class CTETest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("ctes",) + + run_inserts = "each" + run_deletes = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) + + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "d1", "parent_id": None}, + {"id": 2, "data": "d2", "parent_id": 1}, + {"id": 3, "data": "d3", "parent_id": 1}, + {"id": 4, "data": "d4", "parent_id": 3}, + {"id": 5, "data": "d5", "parent_id": 3}, + ], + ) + + def test_select_nonrecursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + result = connection.execute( + select(cte.c.data).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4",)]) + + def test_select_recursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) + + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select(st1).where(st1.c.id == cte_alias.c.parent_id) + ) + result = connection.execute( + select(cte.c.data) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], + ) + + def test_insert_from_select_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(cte) + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.update_from + def test_update_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.delete_from + def test_delete_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) + + @testing.requires.ctes_with_update_delete + def test_delete_scalar_subq_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data + == select(cte.c.data) + .where(cte.c.id == some_other_table.c.id) + .scalar_subquery() + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 0000000..3d9b8ec --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,389 @@ +# testing/suite/test_ddl.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import random + +from . import testing +from .. import config +from .. import fixtures +from .. import util +from ..assertions import eq_ +from ..assertions import is_false +from ..assertions import is_true +from ..config import requirements +from ..schema import Table +from ... import CheckConstraint +from ... import Column +from ... import ForeignKeyConstraint +from ... import Index +from ... import inspect +from ... import Integer +from ... import schema +from ... import String +from ... import UniqueConstraint + + +class TableDDLTest(fixtures.TestBase): + __backend__ = True + + def _simple_fixture(self, schema=None): + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + schema=schema, + ) + + def _underscore_fixture(self): + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) + + def _table_index_fixture(self, schema=None): + table = self._simple_fixture(schema=schema) + idx = Index("test_index", table.c.data) + return table, idx + + def _simple_roundtrip(self, table): + with config.db.begin() as conn: + conn.execute(table.insert().values((1, "some data"))) + result = conn.execute(table.select()) + eq_(result.first(), (1, "some data")) + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.create_table + @requirements.schemas + @util.provide_metadata + def test_create_table_schema(self): + table = self._simple_fixture(schema=config.test_schema) + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) + + @requirements.create_table + @util.provide_metadata + def test_underscore_names(self): + table = self._underscore_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.comment_reflection + @util.provide_metadata + def test_add_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), + {"text": "a comment"}, + ) + + @requirements.comment_reflection + @util.provide_metadata + def test_drop_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + connection.execute(schema.DropTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), {"text": None} + ) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_create_table_if_not_exists(self, connection): + table = self._simple_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + is_true(inspect(connection).has_table("test_table")) + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_create_index_if_not_exists(self, connection): + table, idx = self._table_index_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + is_true(inspect(connection).has_table("test_table")) + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_drop_table_if_exists(self, connection): + table = self._simple_fixture() + + table.create(connection) + + is_true(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + is_false(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_drop_index_if_exists(self, connection): + table, idx = self._table_index_fixture() + + table.create(connection) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + +class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest): + pass + + +class LongNameBlowoutTest(fixtures.TestBase): + """test the creation of a variety of DDL structures and ensure + label length limits pass on backends + + """ + + __backend__ = True + + def fk(self, metadata, connection): + convention = { + "fk": "foreign_key_%(table_name)s_" + "%(column_0_N_name)s_" + "%(referred_table_name)s_" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(20)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + test_needs_fk=True, + ) + + cons = ForeignKeyConstraint( + ["aid"], ["a_things_with_stuff.id_long_column_name"] + ) + Table( + "b_related_things_of_value", + metadata, + Column( + "aid", + ), + cons, + test_needs_fk=True, + ) + actual_name = cons.name + + metadata.create_all(connection) + + if testing.requires.foreign_key_constraint_name_reflection.enabled: + insp = inspect(connection) + fks = insp.get_foreign_keys("b_related_things_of_value") + reflected_name = fks[0]["name"] + + return actual_name, reflected_name + else: + return actual_name, None + + def pk(self, metadata, connection): + convention = { + "pk": "primary_key_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer, primary_key=True), + ) + cons = a.primary_key + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + pk = insp.get_pk_constraint("a_things_with_stuff") + reflected_name = pk["name"] + return actual_name, reflected_name + + def ix(self, metadata, connection): + convention = { + "ix": "index_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + ) + cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ix = insp.get_indexes("a_things_with_stuff") + reflected_name = ix[0]["name"] + return actual_name, reflected_name + + def uq(self, metadata, connection): + convention = { + "uq": "unique_constraint_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = UniqueConstraint("id_long_column_name", "id_another_long_name") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + uq = insp.get_unique_constraints("a_things_with_stuff") + reflected_name = uq[0]["name"] + return actual_name, reflected_name + + def ck(self, metadata, connection): + convention = { + "ck": "check_constraint_%(table_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = CheckConstraint("some_long_column_name > 5") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("some_long_column_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ck = insp.get_check_constraints("a_things_with_stuff") + reflected_name = ck[0]["name"] + return actual_name, reflected_name + + @testing.combinations( + ("fk",), + ("pk",), + ("ix",), + ("ck", testing.requires.check_constraint_reflection.as_skips()), + ("uq", testing.requires.unique_constraint_reflection.as_skips()), + argnames="type_", + ) + def test_long_convention_name(self, type_, metadata, connection): + actual_name, reflected_name = getattr(self, type_)( + metadata, connection + ) + + assert len(actual_name) > 255 + + if reflected_name is not None: + overlap = actual_name[0 : len(reflected_name)] + if len(overlap) < len(actual_name): + eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5]) + else: + eq_(overlap, reflected_name) + + +__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py new file mode 100644 index 0000000..07970c0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py @@ -0,0 +1,153 @@ +# testing/suite/test_deprecations.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import select +from ... import testing +from ... import union + + +class DeprecatedCompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, conn, select, result, params=()): + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + # note we've had to remove one use case entirely, which is this + # one. the Select gets its FROMS from the WHERE clause and the + # columns clause, but not the ORDER BY, which means the old ".c" system + # allowed you to "order_by(s.c.foo)" to get an unnamed column in the + # ORDER BY without adding the SELECT into the FROM and breaking the + # query. Users will have to adjust for this use case if they were doing + # it before. + def _dont_test_select_from_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py new file mode 100644 index 0000000..6964720 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py @@ -0,0 +1,740 @@ +# testing/suite/test_dialect.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import importlib + +from . import testing +from .. import assert_raises +from .. import config +from .. import engines +from .. import eq_ +from .. import fixtures +from .. import is_not_none +from .. import is_true +from .. import ne_ +from .. import provide_metadata +from ..assertions import expect_raises +from ..assertions import expect_raises_message +from ..config import requirements +from ..provision import set_default_schema_on_connection +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import dialects +from ... import event +from ... import exc +from ... import Integer +from ... import literal_column +from ... import select +from ... import String +from ...sql.compiler import Compiled +from ...util import inspect_getfullargspec + + +class PingTest(fixtures.TestBase): + __backend__ = True + + def test_do_ping(self): + with testing.db.connect() as conn: + is_true( + testing.db.dialect.do_ping(conn.connection.dbapi_connection) + ) + + +class ArgSignatureTest(fixtures.TestBase): + """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have + ``**kw``, for #8988. + + This test uses runtime code inspection. Does not need to be a + ``__backend__`` test as it only needs to run once provided all target + dialects have been imported. + + For third party dialects, the suite would be run with that third + party as a "--dburi", which means its compiler classes will have been + imported by the time this test runs. + + """ + + def _all_subclasses(): # type: ignore # noqa + for d in dialects.__all__: + if not d.startswith("_"): + importlib.import_module("sqlalchemy.dialects.%s" % d) + + stack = [Compiled] + + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + yield cls + + @testing.fixture(params=list(_all_subclasses())) + def all_subclasses(self, request): + yield request.param + + def test_all_visit_methods_accept_kw(self, all_subclasses): + cls = all_subclasses + + for k in cls.__dict__: + if k.startswith("visit_"): + meth = getattr(cls, k) + + insp = inspect_getfullargspec(meth) + is_not_none( + insp.varkw, + f"Compiler visit method {cls.__name__}.{k}() does " + "not accommodate for **kw in its argument signature", + ) + + +class ExceptionTest(fixtures.TablesTest): + """Test basic exception wrapping. + + DBAPIs vary a lot in exception behavior so to actually anticipate + specific exceptions from real round trips, we need to be conservative. + + """ + + run_deletes = "each" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @requirements.duplicate_key_raises_integrity_error + def test_integrity_error(self): + with config.db.connect() as conn: + trans = conn.begin() + conn.execute( + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} + ) + + assert_raises( + exc.IntegrityError, + conn.execute, + self.tables.manual_pk.insert(), + {"id": 1, "data": "d1"}, + ) + + trans.rollback() + + def test_exception_with_non_ascii(self): + with config.db.connect() as conn: + try: + # try to create an error message that likely has non-ascii + # characters in the DBAPI's message string. unfortunately + # there's no way to make this happen with some drivers like + # mysqlclient, pymysql. this at least does produce a non- + # ascii error message for cx_oracle, psycopg2 + conn.execute(select(literal_column("méil"))) + assert False + except exc.DBAPIError as err: + err_str = str(err) + + assert str(err.orig) in str(err) + + assert isinstance(err_str, str) + + +class IsolationLevelTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("isolation_level",) + + def _get_non_default_isolation_level(self): + levels = requirements.get_isolation_levels(config) + + default = levels["default"] + supported = levels["supported"] + + s = set(supported).difference(["AUTOCOMMIT", default]) + if s: + return s.pop() + else: + config.skip_test("no non-default isolation level available") + + def test_default_isolation_level(self): + eq_( + config.db.dialect.default_isolation_level, + requirements.get_isolation_levels(config)["default"], + ) + + def test_non_default_isolation_level(self): + non_default = self._get_non_default_isolation_level() + + with config.db.connect() as conn: + existing = conn.get_isolation_level() + + ne_(existing, non_default) + + conn.execution_options(isolation_level=non_default) + + eq_(conn.get_isolation_level(), non_default) + + conn.dialect.reset_isolation_level( + conn.connection.dbapi_connection + ) + + eq_(conn.get_isolation_level(), existing) + + def test_all_levels(self): + levels = requirements.get_isolation_levels(config) + + all_levels = levels["supported"] + + for level in set(all_levels).difference(["AUTOCOMMIT"]): + with config.db.connect() as conn: + conn.execution_options(isolation_level=level) + + eq_(conn.get_isolation_level(), level) + + trans = conn.begin() + trans.rollback() + + eq_(conn.get_isolation_level(), level) + + with config.db.connect() as conn: + eq_( + conn.get_isolation_level(), + levels["default"], + ) + + @testing.requires.get_isolation_level_values + def test_invalid_level_execution_option(self, connection_no_trans): + """test for the new get_isolation_level_values() method""" + + connection = connection_no_trans + with expect_raises_message( + exc.ArgumentError, + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for '%s' are %s" + % ( + "FOO", + connection.dialect.name, + ", ".join( + requirements.get_isolation_levels(config)["supported"] + ), + ), + ): + connection.execution_options(isolation_level="FOO") + + @testing.requires.get_isolation_level_values + @testing.requires.dialect_level_isolation_level_param + def test_invalid_level_engine_param(self, testing_engine): + """test for the new get_isolation_level_values() method + and support for the dialect-level 'isolation_level' parameter. + + """ + + eng = testing_engine(options=dict(isolation_level="FOO")) + with expect_raises_message( + exc.ArgumentError, + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for '%s' are %s" + % ( + "FOO", + eng.dialect.name, + ", ".join( + requirements.get_isolation_levels(config)["supported"] + ), + ), + ): + eng.connect() + + @testing.requires.independent_readonly_connections + def test_dialect_user_setting_is_restored(self, testing_engine): + levels = requirements.get_isolation_levels(config) + default = levels["default"] + supported = ( + sorted( + set(levels["supported"]).difference([default, "AUTOCOMMIT"]) + ) + )[0] + + e = testing_engine(options={"isolation_level": supported}) + + with e.connect() as conn: + eq_(conn.get_isolation_level(), supported) + + with e.connect() as conn: + conn.execution_options(isolation_level=default) + eq_(conn.get_isolation_level(), default) + + with e.connect() as conn: + eq_(conn.get_isolation_level(), supported) + + +class AutocommitIsolationTest(fixtures.TablesTest): + run_deletes = "each" + + __requires__ = ("autocommit",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) + + def _test_conn_autocommits(self, conn, autocommit): + trans = conn.begin() + conn.execute( + self.tables.some_table.insert(), {"id": 1, "data": "some data"} + ) + trans.rollback() + + eq_( + conn.scalar(select(self.tables.some_table.c.id)), + 1 if autocommit else None, + ) + conn.rollback() + + with conn.begin(): + conn.execute(self.tables.some_table.delete()) + + def test_autocommit_on(self, connection_no_trans): + conn = connection_no_trans + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(c2, True) + + c2.dialect.reset_isolation_level(c2.connection.dbapi_connection) + + self._test_conn_autocommits(conn, False) + + def test_autocommit_off(self, connection_no_trans): + conn = connection_no_trans + self._test_conn_autocommits(conn, False) + + def test_turn_autocommit_off_via_default_iso_level( + self, connection_no_trans + ): + conn = connection_no_trans + conn = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(conn, True) + + conn.execution_options( + isolation_level=requirements.get_isolation_levels(config)[ + "default" + ] + ) + self._test_conn_autocommits(conn, False) + + @testing.requires.independent_readonly_connections + @testing.variation("use_dialect_setting", [True, False]) + def test_dialect_autocommit_is_restored( + self, testing_engine, use_dialect_setting + ): + """test #10147""" + + if use_dialect_setting: + e = testing_engine(options={"isolation_level": "AUTOCOMMIT"}) + else: + e = testing_engine().execution_options( + isolation_level="AUTOCOMMIT" + ) + + levels = requirements.get_isolation_levels(config) + + default = levels["default"] + + with e.connect() as conn: + self._test_conn_autocommits(conn, True) + + with e.connect() as conn: + conn.execution_options(isolation_level=default) + self._test_conn_autocommits(conn, False) + + with e.connect() as conn: + self._test_conn_autocommits(conn, True) + + +class EscapingTest(fixtures.TestBase): + @provide_metadata + def test_percent_sign_round_trip(self): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + m = self.metadata + t = Table("t", m, Column("data", String(50))) + t.create(config.db) + with config.db.begin() as conn: + conn.execute(t.insert(), dict(data="some % value")) + conn.execute(t.insert(), dict(data="some %% other value")) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some % value'") + ) + ), + "some % value", + ) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("default_schema_name_switch",) + + def test_control_case(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + with eng.connect(): + pass + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_wont_work_wo_insert(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + + @event.listens_for(eng, "connect") + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_schema_change_on_connect(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + def test_schema_change_works_w_transactions(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, *arg): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + trans = conn.begin() + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + trans.rollback() + + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + +class FutureWeCanSetDefaultSchemaWEventsTest( + fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest +): + pass + + +class DifficultParametersTest(fixtures.TestBase): + __backend__ = True + + tough_parameters = testing.combinations( + ("boring",), + ("per cent",), + ("per % cent",), + ("%percent",), + ("par(ens)",), + ("percent%(ens)yah",), + ("col:ons",), + ("_starts_with_underscore",), + ("dot.s",), + ("more :: %colons%",), + ("_name",), + ("___name",), + ("[BracketsAndCase]",), + ("42numbers",), + ("percent%signs",), + ("has spaces",), + ("/slashes/",), + ("more/slashes",), + ("q?marks",), + ("1param",), + ("1col:on",), + argnames="paramname", + ) + + @tough_parameters + @config.requirements.unusual_column_name_characters + def test_round_trip_same_named_column( + self, paramname, connection, metadata + ): + name = paramname + + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column(name, String(50), nullable=False), + ) + + # table is created + t.create(connection) + + # automatic param generated by insert + connection.execute(t.insert().values({"id": 1, name: "some name"})) + + # automatic param generated by criteria, plus selecting the column + stmt = select(t.c[name]).where(t.c[name] == "some name") + + eq_(connection.scalar(stmt), "some name") + + # use the name in a param explicitly + stmt = select(t.c[name]).where(t.c[name] == bindparam(name)) + + row = connection.execute(stmt, {name: "some name"}).first() + + # name works as the key from cursor.description + eq_(row._mapping[name], "some name") + + # use expanding IN + stmt = select(t.c[name]).where( + t.c[name].in_(["some name", "some other_name"]) + ) + + row = connection.execute(stmt).first() + + @testing.fixture + def multirow_fixture(self, metadata, connection): + mytable = Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(50)), + Column("desc", String(50)), + ) + + mytable.create(connection) + + connection.execute( + mytable.insert(), + [ + {"myid": 1, "name": "a", "desc": "a_desc"}, + {"myid": 2, "name": "b", "desc": "b_desc"}, + {"myid": 3, "name": "c", "desc": "c_desc"}, + {"myid": 4, "name": "d", "desc": "d_desc"}, + ], + ) + yield mytable + + @tough_parameters + def test_standalone_bindparam_escape( + self, paramname, connection, multirow_fixture + ): + tbl1 = multirow_fixture + stmt = select(tbl1.c.myid).where( + tbl1.c.name == bindparam(paramname, value="x") + ) + res = connection.scalar(stmt, {paramname: "c"}) + eq_(res, 3) + + @tough_parameters + def test_standalone_bindparam_escape_expanding( + self, paramname, connection, multirow_fixture + ): + tbl1 = multirow_fixture + stmt = ( + select(tbl1.c.myid) + .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"]))) + .order_by(tbl1.c.myid) + ) + + res = connection.scalars(stmt, {paramname: ["d", "a"]}).all() + eq_(res, [1, 4]) + + +class ReturningGuardsTest(fixtures.TablesTest): + """test that the various 'returning' flags are set appropriately""" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "t", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @testing.fixture + def run_stmt(self, connection): + t = self.tables.t + + def go(stmt, executemany, id_param_name, expect_success): + stmt = stmt.returning(t.c.id) + + if executemany: + if not expect_success: + # for RETURNING executemany(), we raise our own + # error as this is independent of general RETURNING + # support + with expect_raises_message( + exc.StatementError, + rf"Dialect {connection.dialect.name}\+" + f"{connection.dialect.driver} with " + f"current server capabilities does not support " + f".*RETURNING when executemany is used", + ): + result = connection.execute( + stmt, + [ + {id_param_name: 1, "data": "d1"}, + {id_param_name: 2, "data": "d2"}, + {id_param_name: 3, "data": "d3"}, + ], + ) + else: + result = connection.execute( + stmt, + [ + {id_param_name: 1, "data": "d1"}, + {id_param_name: 2, "data": "d2"}, + {id_param_name: 3, "data": "d3"}, + ], + ) + eq_(result.all(), [(1,), (2,), (3,)]) + else: + if not expect_success: + # for RETURNING execute(), we pass all the way to the DB + # and let it fail + with expect_raises(exc.DBAPIError): + connection.execute( + stmt, {id_param_name: 1, "data": "d1"} + ) + else: + result = connection.execute( + stmt, {id_param_name: 1, "data": "d1"} + ) + eq_(result.all(), [(1,)]) + + return go + + def test_insert_single(self, connection, run_stmt): + t = self.tables.t + + stmt = t.insert() + + run_stmt(stmt, False, "id", connection.dialect.insert_returning) + + def test_insert_many(self, connection, run_stmt): + t = self.tables.t + + stmt = t.insert() + + run_stmt( + stmt, True, "id", connection.dialect.insert_executemany_returning + ) + + def test_update_single(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.update().where(t.c.id == bindparam("b_id")) + + run_stmt(stmt, False, "b_id", connection.dialect.update_returning) + + def test_update_many(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.update().where(t.c.id == bindparam("b_id")) + + run_stmt( + stmt, True, "b_id", connection.dialect.update_executemany_returning + ) + + def test_delete_single(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.delete().where(t.c.id == bindparam("b_id")) + + run_stmt(stmt, False, "b_id", connection.dialect.delete_returning) + + def test_delete_many(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.delete().where(t.c.id == bindparam("b_id")) + + run_stmt( + stmt, True, "b_id", connection.dialect.delete_executemany_returning + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py new file mode 100644 index 0000000..1cff044 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py @@ -0,0 +1,630 @@ +# testing/suite/test_insert.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from decimal import Decimal +import uuid + +from . import testing +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import Double +from ... import Float +from ... import Identity +from ... import Integer +from ... import literal +from ... import literal_column +from ... import Numeric +from ... import select +from ... import String +from ...types import LargeBinary +from ...types import UUID +from ...types import Uuid + + +class LastrowidTest(fixtures.TablesTest): + run_deletes = "each" + + __backend__ = True + + __requires__ = "implements_get_lastrowid", "autoincrement_insert" + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + implicit_returning=False, + ) + + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + implicit_returning=False, + ) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + def test_autoincrement_on_insert(self, connection): + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + @requirements.dbapi_lastrowid + def test_native_lastrowid_autoinc(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + lastrowid = r.lastrowid + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(lastrowid, pk) + + +class InsertBehaviorTest(fixtures.TablesTest): + run_deletes = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "no_implicit_returning", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + implicit_returning=False, + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) + + @testing.variation("style", ["plain", "return_defaults"]) + @testing.variation("executemany", [True, False]) + def test_no_results_for_non_returning_insert( + self, connection, style, executemany + ): + """test another INSERT issue found during #10453""" + + table = self.tables.no_implicit_returning + + stmt = table.insert() + if style.return_defaults: + stmt = stmt.return_defaults() + + if executemany: + data = [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + ] + else: + data = {"data": "d1"} + + r = connection.execute(stmt, data) + assert not r.returns_rows + + @requirements.autoincrement_insert + def test_autoclose_on_insert(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment + # an insert where the PK was taken from a row that the dialect + # selected, as is the case for mssql/pyodbc, will still report + # returns_rows as true because there's a cursor description. in that + # case, the row had to have been consumed at least. + assert not r.returns_rows or r.fetchone() is None + + @requirements.insert_returning + def test_autoclose_on_insert_implicit_returning(self, connection): + r = connection.execute( + # return_defaults() ensures RETURNING will be used, + # new in 2.0 as sqlite/mariadb offer both RETURNING and + # cursor.lastrowid + self.tables.autoinc_pk.insert().return_defaults(), + dict(data="some data"), + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # note we are experimenting with having this be True + # as of I8091919d45421e3f53029b8660427f844fee0228 . + # implicit returning has fetched the row, but it still is a + # "returns rows" + assert r.returns_rows + + # and we should be able to fetchone() on it, we just get no row + eq_(r.fetchone(), None) + + # and the keys, etc. + eq_(r.keys(), ["id"]) + + # but the dialect took in the row already. not really sure + # what the best behavior is. + + @requirements.empty_inserts + def test_empty_insert(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert()) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + eq_(len(r.all()), 1) + + @requirements.empty_inserts_executemany + def test_empty_insert_multiple(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}]) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + + eq_(len(r.all()), 3) + + @requirements.insert_from_select + def test_insert_from_select_autoinc(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + connection.execute( + src_table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + eq_(result.fetchall(), [("data2",), ("data3",)]) + + @requirements.insert_from_select + def test_insert_from_select_autoinc_no_rows(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + + eq_(result.fetchall(), []) + + @requirements.insert_from_select + def test_insert_from_select(self, connection): + table = self.tables.manual_pk + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table.c.data).order_by(table.c.data) + ).fetchall(), + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], + ) + + @requirements.insert_from_select + def test_insert_from_select_with_defaults(self, connection): + table = self.tables.includes_defaults + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table).order_by(table.c.data, table.c.id) + ).fetchall(), + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], + ) + + +class ReturningTest(fixtures.TablesTest): + run_create_tables = "each" + __requires__ = "insert_returning", "autoincrement_insert" + __backend__ = True + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + @requirements.fetch_rows_post_commit + def test_explicit_returning_pk_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_explicit_returning_pk_no_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_autoincrement_on_insert_implicit_returning(self, connection): + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id_implicit_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + @requirements.insert_executemany_returning + def test_insertmanyvalues_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert().returning( + self.tables.autoinc_pk.c.id + ), + [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + ], + ) + rall = r.all() + + pks = connection.execute(select(self.tables.autoinc_pk.c.id)) + + eq_(rall, pks.all()) + + @testing.combinations( + (Double(), 8.5514716, True), + ( + Double(53), + 8.5514716, + True, + testing.requires.float_or_double_precision_behaves_generically, + ), + (Float(), 8.5514, True), + ( + Float(8), + 8.5514, + True, + testing.requires.float_or_double_precision_behaves_generically, + ), + ( + Numeric(precision=15, scale=12, asdecimal=False), + 8.5514716, + True, + testing.requires.literal_float_coercion, + ), + ( + Numeric(precision=15, scale=12, asdecimal=True), + Decimal("8.5514716"), + False, + ), + argnames="type_,value,do_rounding", + ) + @testing.variation("sort_by_parameter_order", [True, False]) + @testing.variation("multiple_rows", [True, False]) + def test_insert_w_floats( + self, + connection, + metadata, + sort_by_parameter_order, + type_, + value, + do_rounding, + multiple_rows, + ): + """test #9701. + + this tests insertmanyvalues as well as decimal / floating point + RETURNING types + + """ + + t = Table( + # Oracle backends seems to be getting confused if + # this table is named the same as the one + # in test_imv_returning_datatypes. use a different name + "f_t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("value", type_), + ) + + t.create(connection) + + result = connection.execute( + t.insert().returning( + t.c.id, + t.c.value, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), + ) + + if multiple_rows: + i_range = range(1, 11) + else: + i_range = range(1, 2) + + # we want to test only that we are getting floating points back + # with some degree of the original value maintained, that it is not + # being truncated to an integer. there's too much variation in how + # drivers return floats, which should not be relied upon to be + # exact, for us to just compare as is (works for PG drivers but not + # others) so we use rounding here. There's precedent for this + # in suite/test_types.py::NumericTest as well + + if do_rounding: + eq_( + {(id_, round(val_, 5)) for id_, val_ in result}, + {(id_, round(value, 5)) for id_ in i_range}, + ) + + eq_( + { + round(val_, 5) + for val_ in connection.scalars(select(t.c.value)) + }, + {round(value, 5)}, + ) + else: + eq_( + set(result), + {(id_, value) for id_ in i_range}, + ) + + eq_( + set(connection.scalars(select(t.c.value))), + {value}, + ) + + @testing.combinations( + ( + "non_native_uuid", + Uuid(native_uuid=False), + uuid.uuid4(), + ), + ( + "non_native_uuid_str", + Uuid(as_uuid=False, native_uuid=False), + str(uuid.uuid4()), + ), + ( + "generic_native_uuid", + Uuid(native_uuid=True), + uuid.uuid4(), + testing.requires.uuid_data_type, + ), + ( + "generic_native_uuid_str", + Uuid(as_uuid=False, native_uuid=True), + str(uuid.uuid4()), + testing.requires.uuid_data_type, + ), + ("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type), + ( + "LargeBinary1", + LargeBinary(), + b"this is binary", + ), + ("LargeBinary2", LargeBinary(), b"7\xe7\x9f"), + argnames="type_,value", + id_="iaa", + ) + @testing.variation("sort_by_parameter_order", [True, False]) + @testing.variation("multiple_rows", [True, False]) + @testing.requires.insert_returning + def test_imv_returning_datatypes( + self, + connection, + metadata, + sort_by_parameter_order, + type_, + value, + multiple_rows, + ): + """test #9739, #9808 (similar to #9701). + + this tests insertmanyvalues in conjunction with various datatypes. + + These tests are particularly for the asyncpg driver which needs + most types to be explicitly cast for the new IMV format + + """ + t = Table( + "d_t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("value", type_), + ) + + t.create(connection) + + result = connection.execute( + t.insert().returning( + t.c.id, + t.c.value, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), + ) + + if multiple_rows: + i_range = range(1, 11) + else: + i_range = range(1, 2) + + eq_( + set(result), + {(id_, value) for id_ in i_range}, + ) + + eq_( + set(connection.scalars(select(t.c.value))), + {value}, + ) + + +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 0000000..f257d2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py @@ -0,0 +1,3128 @@ +# testing/suite/test_reflection.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import operator +import re + +import sqlalchemy as sa +from .. import config +from .. import engines +from .. import eq_ +from .. import expect_raises +from .. import expect_raises_message +from .. import expect_warnings +from .. import fixtures +from .. import is_ +from ..provision import get_temp_table_name +from ..provision import temp_table_keyword_args +from ..schema import Column +from ..schema import Table +from ... import event +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import String +from ... import testing +from ... import types as sql_types +from ...engine import Inspector +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...exc import NoSuchTableError +from ...exc import UnreflectableTableError +from ...schema import DDL +from ...schema import Index +from ...sql.elements import quoted_name +from ...sql.schema import BLANK_SCHEMA +from ...testing import ComparesIndexes +from ...testing import ComparesTables +from ...testing import is_false +from ...testing import is_true +from ...testing import mock + + +metadata, users = None, None + + +class OneConnectionTablesTest(fixtures.TablesTest): + @classmethod + def setup_bind(cls): + # TODO: when temp tables are subject to server reset, + # this will also have to disable that server reset from + # happening + if config.requirements.independent_connections.enabled: + from sqlalchemy import pool + + return engines.testing_engine( + options=dict(poolclass=pool.StaticPool, scope="class"), + ) + else: + return config.db + + +class HasTableTest(OneConnectionTablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + if testing.requires.schemas.enabled: + Table( + "test_table_s", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + + if testing.requires.view_reflection: + cls.define_views(metadata) + if testing.requires.has_temp_table.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def define_views(cls, metadata): + query = "CREATE VIEW vv AS SELECT id, data FROM test_table" + + event.listen(metadata, "after_create", DDL(query)) + event.listen(metadata, "before_drop", DDL("DROP VIEW vv")) + + if testing.requires.schemas.enabled: + query = ( + "CREATE VIEW %s.vv AS SELECT id, data FROM %s.test_table_s" + % ( + config.test_schema, + config.test_schema, + ) + ) + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, + "before_drop", + DDL("DROP VIEW %s.vv" % (config.test_schema)), + ) + + @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + + @classmethod + def define_temp_tables(cls, metadata): + kw = temp_table_keyword_args(config, config.db) + table_name = cls.temp_table_name() + user_tmp = Table( + table_name, + metadata, + Column("id", sa.INT, primary_key=True), + Column("name", sa.VARCHAR(50)), + **kw, + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp_%s" % config.ident + ), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + def test_has_table(self): + with config.db.begin() as conn: + is_true(config.db.dialect.has_table(conn, "test_table")) + is_false(config.db.dialect.has_table(conn, "test_table_s")) + is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + + def test_has_table_cache(self, metadata): + insp = inspect(config.db) + is_true(insp.has_table("test_table")) + nt = Table("new_table", metadata, Column("col", Integer)) + is_false(insp.has_table("new_table")) + nt.create(config.db) + try: + is_false(insp.has_table("new_table")) + insp.clear_cache() + is_true(insp.has_table("new_table")) + finally: + nt.drop(config.db) + + @testing.requires.schemas + def test_has_table_schema(self): + with config.db.begin() as conn: + is_false( + config.db.dialect.has_table( + conn, "test_table", schema=config.test_schema + ) + ) + is_true( + config.db.dialect.has_table( + conn, "test_table_s", schema=config.test_schema + ) + ) + is_false( + config.db.dialect.has_table( + conn, "nonexistent_table", schema=config.test_schema + ) + ) + + @testing.requires.schemas + def test_has_table_nonexistent_schema(self): + with config.db.begin() as conn: + is_false( + config.db.dialect.has_table( + conn, "test_table", schema="nonexistent_schema" + ) + ) + + @testing.requires.views + def test_has_table_view(self, connection): + insp = inspect(connection) + is_true(insp.has_table("vv")) + + @testing.requires.has_temp_table + def test_has_table_temp_table(self, connection): + insp = inspect(connection) + temp_table_name = self.temp_table_name() + is_true(insp.has_table(temp_table_name)) + + @testing.requires.has_temp_table + @testing.requires.view_reflection + @testing.requires.temporary_views + def test_has_table_temp_view(self, connection): + insp = inspect(connection) + is_true(insp.has_table("user_tmp_v")) + + @testing.requires.views + @testing.requires.schemas + def test_has_table_view_schema(self, connection): + insp = inspect(connection) + is_true(insp.has_table("vv", config.test_schema)) + + +class HasIndexTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("data2", String(50)), + ) + Index("my_idx", tt.c.data) + + if testing.requires.schemas.enabled: + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + Index("my_idx_s", tt.c.data) + + kind = testing.combinations("dialect", "inspector", argnames="kind") + + def _has_index(self, kind, conn): + if kind == "dialect": + return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k) + else: + return inspect(conn).has_index + + @kind + def test_has_index(self, kind, connection, metadata): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx") + assert not meth("test_table", "my_idx_s") + assert not meth("nonexistent_table", "my_idx") + assert not meth("test_table", "nonexistent_idx") + + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + idx = Index("my_idx_2", self.tables.test_table.c.data2) + tbl = Table( + "test_table_2", + metadata, + Column("foo", Integer), + Index("my_idx_3", "foo"), + ) + idx.create(connection) + tbl.create(connection) + try: + if kind == "inspector": + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + meth.__self__.clear_cache() + assert meth("test_table", "my_idx_2") is True + assert meth("test_table_2", "my_idx_3") is True + finally: + tbl.drop(connection) + idx.drop(connection) + + @testing.requires.schemas + @kind + def test_has_index_schema(self, kind, connection): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx_s", schema=config.test_schema) + assert not meth("test_table", "my_idx", schema=config.test_schema) + assert not meth( + "nonexistent_table", "my_idx_s", schema=config.test_schema + ) + assert not meth( + "test_table", "nonexistent_idx_s", schema=config.test_schema + ) + + +class BizarroCharacterFKResolutionTest(fixtures.TestBase): + """tests for #10275""" + + __backend__ = True + + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("per % cent",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref( + self, connection, metadata, use_composite, tablename, columnname + ): + tt = Table( + tablename, + metadata, + Column(columnname, Integer, key="id", primary_key=True), + test_needs_fk=True, + ) + if use_composite: + tt.append_column(Column("id2", Integer, primary_key=True)) + + if use_composite: + Table( + "other", + metadata, + Column("id", Integer, primary_key=True), + Column("ref", Integer), + Column("ref2", Integer), + sa.ForeignKeyConstraint(["ref", "ref2"], [tt.c.id, tt.c.id2]), + test_needs_fk=True, + ) + else: + Table( + "other", + metadata, + Column("id", Integer, primary_key=True), + Column("ref", ForeignKey(tt.c.id)), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + m2 = MetaData() + + o2 = Table("other", m2, autoload_with=connection) + t1 = m2.tables[tablename] + + assert o2.c.ref.references(t1.c[0]) + if use_composite: + assert o2.c.ref2.references(t1.c[1]) + + +class QuotedNameArgumentTest(fixtures.TablesTest): + run_create_tables = "once" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "quote ' one", + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name="pk quote ' one"), + sa.Index("ix quote ' one", "name"), + sa.UniqueConstraint( + "data", + name="uq quote' one", + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name="fk quote ' one" + ), + sa.CheckConstraint("name != 'foo'", name="ck quote ' one"), + comment=r"""quote ' one comment""", + test_needs_fk=True, + ) + + if testing.requires.symbol_names_w_double_quote.enabled: + Table( + 'quote " two', + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name='pk quote " two'), + sa.Index('ix quote " two', "name"), + sa.UniqueConstraint( + "data", + name='uq quote" two', + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name='fk quote " two' + ), + sa.CheckConstraint("name != 'foo'", name='ck quote " two '), + comment=r"""quote " two comment""", + test_needs_fk=True, + ) + + Table( + "related", + metadata, + Column("id", Integer, primary_key=True), + Column("related", Integer), + test_needs_fk=True, + ) + + if testing.requires.view_column_reflection.enabled: + if testing.requires.symbol_names_w_double_quote.enabled: + names = [ + "quote ' one", + 'quote " two', + ] + else: + names = [ + "quote ' one", + ] + for name in names: + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + config.db.dialect.identifier_preparer.quote( + "view %s" % name + ), + config.db.dialect.identifier_preparer.quote(name), + ) + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, + "before_drop", + DDL( + "DROP VIEW %s" + % config.db.dialect.identifier_preparer.quote( + "view %s" % name + ) + ), + ) + + def quote_fixtures(fn): + return testing.combinations( + ("quote ' one",), + ('quote " two', testing.requires.symbol_names_w_double_quote), + )(fn) + + @quote_fixtures + def test_get_table_options(self, name): + insp = inspect(config.db) + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options(name) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options(name) + + @quote_fixtures + @testing.requires.view_column_reflection + def test_get_view_definition(self, name): + insp = inspect(config.db) + assert insp.get_view_definition("view %s" % name) + + @quote_fixtures + def test_get_columns(self, name): + insp = inspect(config.db) + assert insp.get_columns(name) + + @quote_fixtures + def test_get_pk_constraint(self, name): + insp = inspect(config.db) + assert insp.get_pk_constraint(name) + + @quote_fixtures + def test_get_foreign_keys(self, name): + insp = inspect(config.db) + assert insp.get_foreign_keys(name) + + @quote_fixtures + def test_get_indexes(self, name): + insp = inspect(config.db) + assert insp.get_indexes(name) + + @quote_fixtures + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, name): + insp = inspect(config.db) + assert insp.get_unique_constraints(name) + + @quote_fixtures + @testing.requires.comment_reflection + def test_get_table_comment(self, name): + insp = inspect(config.db) + assert insp.get_table_comment(name) + + @quote_fixtures + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, name): + insp = inspect(config.db) + assert insp.get_check_constraints(name) + + +def _multi_combination(fn): + schema = testing.combinations( + None, + ( + lambda: config.test_schema, + testing.requires.schemas, + ), + argnames="schema", + ) + scope = testing.combinations( + ObjectScope.DEFAULT, + ObjectScope.TEMPORARY, + ObjectScope.ANY, + argnames="scope", + ) + kind = testing.combinations( + ObjectKind.TABLE, + ObjectKind.VIEW, + ObjectKind.MATERIALIZED_VIEW, + ObjectKind.ANY, + ObjectKind.ANY_VIEW, + ObjectKind.TABLE | ObjectKind.VIEW, + ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW, + argnames="kind", + ) + filter_names = testing.combinations(True, False, argnames="use_filter") + + return schema(scope(kind(filter_names(fn)))) + + +class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest): + run_inserts = run_deletes = None + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.define_reflected_tables(metadata, None) + if testing.requires.schemas.enabled: + cls.define_reflected_tables(metadata, testing.config.test_schema) + + @classmethod + def define_reflected_tables(cls, metadata, schema): + if schema: + schema_prefix = schema + "." + else: + schema_prefix = "" + + if testing.requires.self_referential_foreign_keys.enabled: + parent_id_args = ( + ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ) + else: + parent_id_args = () + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(), nullable=False), + Column("parent_user_id", sa.Integer, *parent_id_args), + sa.CheckConstraint( + "test2 > 0", + name="zz_test2_gt_zero", + comment="users check constraint", + ), + sa.CheckConstraint("test2 <= 1000"), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + ForeignKey( + "%semail_addresses.address_id" % schema_prefix, + name="zz_email_add_id_fg", + comment="di fk comment", + ), + ), + Column( + "id_user", + sa.Integer, + ForeignKey("%susers.user_id" % schema_prefix), + ), + Column("data", sa.String(30), unique=True), + sa.CheckConstraint( + "address_id > 0 AND address_id < 1000", + name="address_id_gt_zero", + ), + sa.UniqueConstraint( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + comment="di unique comment", + ), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column("remote_user_id", sa.Integer, ForeignKey(users.c.user_id)), + Column("email_address", sa.String(20), index=True), + sa.PrimaryKeyConstraint( + "address_id", name="email_ad_pk", comment="ea pk comment" + ), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + Column("d3", sa.String(42), comment="Comment\nwith\rescapes"), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + comment="no\nconstraints\rhas\fescaped\vcomment", + ) + + if testing.requires.cross_schema_fk_reflection.enabled: + if schema is None: + Table( + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + Column( + "remote_id", + ForeignKey( + "%s.remote_table_2.id" % testing.config.test_schema + ), + ), + test_needs_fk=True, + schema=config.db.dialect.default_schema_name, + ) + else: + Table( + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column( + "local_id", + ForeignKey( + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), + ), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + Table( + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + + if testing.requires.index_reflection.enabled: + Index("users_t_idx", users.c.test1, users.c.test2, unique=True) + Index( + "users_all_idx", users.c.user_id, users.c.test2, users.c.test1 + ) + + if not schema: + # test_needs_fk is at the moment to force MySQL InnoDB + noncol_idx_test_nopk = Table( + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + noncol_idx_test_pk = Table( + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + if ( + testing.requires.indexes_with_ascdesc.enabled + and testing.requires.reflect_indexes_with_ascdesc.enabled + ): + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) + + if testing.requires.view_column_reflection.enabled: + cls.define_views(metadata, schema) + if not schema and testing.requires.temp_table_reflection.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + + @classmethod + def define_temp_tables(cls, metadata): + kw = temp_table_keyword_args(config, config.db) + table_name = cls.temp_table_name() + user_tmp = Table( + table_name, + metadata, + Column("id", sa.INT, primary_key=True), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + # disambiguate temp table unique constraint names. this is + # pretty arbitrary for a generic dialect however we are doing + # it to suit SQL Server which will produce name conflicts for + # unique constraints created against temp tables in different + # databases. + # https://www.arbinada.com/en/node/1645 + sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"), + sa.Index("user_tmp_ix", "foo"), + **kw, + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp_%s" % config.ident + ), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + @classmethod + def define_views(cls, metadata, schema): + if testing.requires.materialized_views.enabled: + materialized = {"dingalings"} + else: + materialized = set() + for table_name in ("users", "email_addresses", "dingalings"): + fullname = table_name + if schema: + fullname = f"{schema}.{table_name}" + view_name = fullname + "_v" + prefix = "MATERIALIZED " if table_name in materialized else "" + query = ( + f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}" + ) + + event.listen(metadata, "after_create", DDL(query)) + if table_name in materialized: + index_name = "mat_index" + if schema and testing.against("oracle"): + index_name = f"{schema}.{index_name}" + idx = f"CREATE INDEX {index_name} ON {view_name}(data)" + event.listen(metadata, "after_create", DDL(idx)) + event.listen( + metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}") + ) + + def _resolve_kind(self, kind, tables, views, materialized): + res = {} + if ObjectKind.TABLE in kind: + res.update(tables) + if ObjectKind.VIEW in kind: + res.update(views) + if ObjectKind.MATERIALIZED_VIEW in kind: + res.update(materialized) + return res + + def _resolve_views(self, views, materialized): + if not testing.requires.view_column_reflection.enabled: + materialized.clear() + views.clear() + elif not testing.requires.materialized_views.enabled: + views.update(materialized) + materialized.clear() + + def _resolve_names(self, schema, scope, filter_names, values): + scope_filter = lambda _: True # noqa: E731 + if scope is ObjectScope.DEFAULT: + scope_filter = lambda k: "tmp" not in k[1] # noqa: E731 + if scope is ObjectScope.TEMPORARY: + scope_filter = lambda k: "tmp" in k[1] # noqa: E731 + + removed = { + None: {"remote_table", "remote_table_2"}, + testing.config.test_schema: { + "local_table", + "noncol_idx_test_nopk", + "noncol_idx_test_pk", + "user_tmp_v", + self.temp_table_name(), + }, + } + if not testing.requires.cross_schema_fk_reflection.enabled: + removed[None].add("local_table") + removed[testing.config.test_schema].update( + ["remote_table", "remote_table_2"] + ) + if not testing.requires.index_reflection.enabled: + removed[None].update( + ["noncol_idx_test_nopk", "noncol_idx_test_pk"] + ) + if ( + not testing.requires.temp_table_reflection.enabled + or not testing.requires.temp_table_names.enabled + ): + removed[None].update(["user_tmp_v", self.temp_table_name()]) + if not testing.requires.temporary_views.enabled: + removed[None].update(["user_tmp_v"]) + + res = { + k: v + for k, v in values.items() + if scope_filter(k) + and k[1] not in removed[schema] + and (not filter_names or k[1] in filter_names) + } + return res + + def exp_options( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + materialized = {(schema, "dingalings_v"): mock.ANY} + views = { + (schema, "email_addresses_v"): mock.ANY, + (schema, "users_v"): mock.ANY, + (schema, "user_tmp_v"): mock.ANY, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): mock.ANY, + (schema, "dingalings"): mock.ANY, + (schema, "email_addresses"): mock.ANY, + (schema, "comment_test"): mock.ANY, + (schema, "no_constraints"): mock.ANY, + (schema, "local_table"): mock.ANY, + (schema, "remote_table"): mock.ANY, + (schema, "remote_table_2"): mock.ANY, + (schema, "noncol_idx_test_nopk"): mock.ANY, + (schema, "noncol_idx_test_pk"): mock.ANY, + (schema, self.temp_table_name()): mock.ANY, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_comments( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + empty = {"text": None} + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): empty, + (schema, "dingalings"): empty, + (schema, "email_addresses"): empty, + (schema, "comment_test"): { + "text": r"""the test % ' " \ table comment""" + }, + (schema, "no_constraints"): { + "text": "no\nconstraints\rhas\fescaped\vcomment" + }, + (schema, "local_table"): empty, + (schema, "remote_table"): empty, + (schema, "remote_table_2"): empty, + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): empty, + (schema, self.temp_table_name()): empty, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_columns( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def col( + name, auto=False, default=mock.ANY, comment=None, nullable=True + ): + res = { + "name": name, + "autoincrement": auto, + "type": mock.ANY, + "default": default, + "comment": comment, + "nullable": nullable, + } + if auto == "omit": + res.pop("autoincrement") + return res + + def pk(name, **kw): + kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw} + return col(name, **kw) + + materialized = { + (schema, "dingalings_v"): [ + col("dingaling_id", auto="omit", nullable=mock.ANY), + col("address_id"), + col("id_user"), + col("data"), + ] + } + views = { + (schema, "email_addresses_v"): [ + col("address_id", auto="omit", nullable=mock.ANY), + col("remote_user_id"), + col("email_address"), + ], + (schema, "users_v"): [ + col("user_id", auto="omit", nullable=mock.ANY), + col("test1", nullable=mock.ANY), + col("test2", nullable=mock.ANY), + col("parent_user_id"), + ], + (schema, "user_tmp_v"): [ + col("id", auto="omit", nullable=mock.ANY), + col("name"), + col("foo"), + ], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + pk("user_id"), + col("test1", nullable=False), + col("test2", nullable=False), + col("parent_user_id"), + ], + (schema, "dingalings"): [ + pk("dingaling_id"), + col("address_id"), + col("id_user"), + col("data"), + ], + (schema, "email_addresses"): [ + pk("address_id"), + col("remote_user_id"), + col("email_address"), + ], + (schema, "comment_test"): [ + pk("id", comment="id comment"), + col("data", comment="data % comment"), + col( + "d2", + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + col("d3", comment="Comment\nwith\rescapes"), + ], + (schema, "no_constraints"): [col("data")], + (schema, "local_table"): [pk("id"), col("data"), col("remote_id")], + (schema, "remote_table"): [pk("id"), col("local_id"), col("data")], + (schema, "remote_table_2"): [pk("id"), col("data")], + (schema, "noncol_idx_test_nopk"): [col("q")], + (schema, "noncol_idx_test_pk"): [pk("id"), col("q")], + (schema, self.temp_table_name()): [ + pk("id"), + col("name"), + col("foo"), + ], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_column_keys(self): + return {"name", "type", "nullable", "default"} + + def exp_pks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def pk(*cols, name=mock.ANY, comment=None): + return { + "constrained_columns": list(cols), + "name": name, + "comment": comment, + } + + empty = pk(name=None) + if testing.requires.materialized_views_reflect_pk.enabled: + materialized = {(schema, "dingalings_v"): pk("dingaling_id")} + else: + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): pk("user_id"), + (schema, "dingalings"): pk("dingaling_id"), + (schema, "email_addresses"): pk( + "address_id", name="email_ad_pk", comment="ea pk comment" + ), + (schema, "comment_test"): pk("id"), + (schema, "no_constraints"): empty, + (schema, "local_table"): pk("id"), + (schema, "remote_table"): pk("id"), + (schema, "remote_table_2"): pk("id"), + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): pk("id"), + (schema, self.temp_table_name()): pk("id"), + } + if not testing.requires.reflects_pk_names.enabled: + for val in tables.values(): + if val["name"] is not None: + val["name"] = mock.ANY + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_pk_keys(self): + return {"name", "constrained_columns"} + + def exp_fks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt: + def __eq__(self, other): + return ( + other is None + or config.db.dialect.default_schema_name == other + ) + + def fk( + cols, + ref_col, + ref_table, + ref_schema=schema, + name=mock.ANY, + comment=None, + ): + return { + "constrained_columns": cols, + "referred_columns": ref_col, + "name": name, + "options": mock.ANY, + "referred_schema": ( + ref_schema if ref_schema is not None else tt() + ), + "referred_table": ref_table, + "comment": comment, + } + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk") + ], + (schema, "dingalings"): [ + fk(["id_user"], ["user_id"], "users"), + fk( + ["address_id"], + ["address_id"], + "email_addresses", + name="zz_email_add_id_fg", + comment="di fk comment", + ), + ], + (schema, "email_addresses"): [ + fk(["remote_user_id"], ["user_id"], "users") + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + fk( + ["remote_id"], + ["id"], + "remote_table_2", + ref_schema=config.test_schema, + ) + ], + (schema, "remote_table"): [ + fk(["local_id"], ["id"], "local_table", ref_schema=None) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + if not testing.requires.self_referential_foreign_keys.enabled: + tables[(schema, "users")].clear() + if not testing.requires.named_constraints.enabled: + for vals in tables.values(): + for val in vals: + if val["name"] is not mock.ANY: + val["name"] = mock.ANY + + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_fk_keys(self): + return { + "name", + "constrained_columns", + "referred_schema", + "referred_table", + "referred_columns", + } + + def exp_indexes( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def idx( + *cols, + name, + unique=False, + column_sorting=None, + duplicates=False, + fk=False, + ): + fk_req = testing.requires.foreign_keys_reflect_as_index + dup_req = testing.requires.unique_constraints_reflect_as_index + sorting_expression = ( + testing.requires.reflect_indexes_with_ascdesc_as_expression + ) + + if (fk and not fk_req.enabled) or ( + duplicates and not dup_req.enabled + ): + return () + res = { + "unique": unique, + "column_names": list(cols), + "name": name, + "dialect_options": mock.ANY, + "include_columns": [], + } + if column_sorting: + res["column_sorting"] = column_sorting + if sorting_expression.enabled: + res["expressions"] = orig = res["column_names"] + res["column_names"] = [ + None if c in column_sorting else c for c in orig + ] + + if duplicates: + res["duplicates_constraint"] = name + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + if materialized: + materialized[(schema, "dingalings_v")].extend( + idx("data", name="mat_index") + ) + tables = { + (schema, "users"): [ + *idx("parent_user_id", name="user_id_fk", fk=True), + *idx("user_id", "test2", "test1", name="users_all_idx"), + *idx("test1", "test2", name="users_t_idx", unique=True), + ], + (schema, "dingalings"): [ + *idx("data", name=mock.ANY, unique=True, duplicates=True), + *idx("id_user", name=mock.ANY, fk=True), + *idx( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + unique=True, + duplicates=True, + ), + ], + (schema, "email_addresses"): [ + *idx("email_address", name=mock.ANY), + *idx("remote_user_id", name=mock.ANY, fk=True), + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + *idx("remote_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table"): [ + *idx("local_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [ + *idx( + "q", + name="noncol_idx_nopk", + column_sorting={"q": ("desc",)}, + ) + ], + (schema, "noncol_idx_test_pk"): [ + *idx( + "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)} + ) + ], + (schema, self.temp_table_name()): [ + *idx("foo", name="user_tmp_ix"), + *idx( + "name", + name=f"user_tmp_uq_{config.ident}", + duplicates=True, + unique=True, + ), + ], + } + if ( + not testing.requires.indexes_with_ascdesc.enabled + or not testing.requires.reflect_indexes_with_ascdesc.enabled + ): + tables[(schema, "noncol_idx_test_nopk")].clear() + tables[(schema, "noncol_idx_test_pk")].clear() + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_index_keys(self): + return {"name", "column_names", "unique"} + + def exp_ucs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + all_=False, + ): + def uc( + *cols, name, duplicates_index=None, is_index=False, comment=None + ): + req = testing.requires.unique_index_reflect_as_unique_constraints + if is_index and not req.enabled: + return () + res = { + "column_names": list(cols), + "name": name, + "comment": comment, + } + if duplicates_index: + res["duplicates_index"] = duplicates_index + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + *uc( + "test1", + "test2", + name="users_t_idx", + duplicates_index="users_t_idx", + is_index=True, + ) + ], + (schema, "dingalings"): [ + *uc("data", name=mock.ANY, duplicates_index=mock.ANY), + *uc( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + duplicates_index="zz_dingalings_multiple", + comment="di unique comment", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [ + *uc("name", name=f"user_tmp_uq_{config.ident}") + ], + } + if all_: + return {**materialized, **views, **tables} + else: + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_unique_cst_keys(self): + return {"name", "column_names"} + + def exp_ccs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt(str): + def __eq__(self, other): + res = ( + other.lower() + .replace("(", "") + .replace(")", "") + .replace("`", "") + ) + return self in res + + def cc(text, name, comment=None): + return {"sqltext": tt(text), "name": name, "comment": comment} + + # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")}) + # assert 0 + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + cc("test2 <= 1000", mock.ANY), + cc( + "test2 > 0", + "zz_test2_gt_zero", + comment="users check constraint", + ), + ], + (schema, "dingalings"): [ + cc( + "address_id > 0 and address_id < 1000", + name="address_id_gt_zero", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_cc_keys(self): + return {"name", "sqltext"} + + @testing.requires.schema_reflection + def test_get_schema_names(self, connection): + insp = inspect(connection) + + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema(self, connection): + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + def test_get_schema_names_w_translate_map(self, connection): + """test #7300""" + + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema_w_translate_map(self, connection): + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + @testing.requires.schema_create_delete + def test_schema_cache(self, connection): + insp = inspect(connection) + + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + connection.execute(DDL("CREATE SCHEMA foo_bar")) + try: + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + insp.clear_cache() + is_true("foo_bar" in insp.get_schema_names()) + is_true(insp.has_schema("foo_bar")) + finally: + connection.execute(DDL("DROP SCHEMA foo_bar")) + + @testing.requires.schema_reflection + def test_dialect_initialize(self): + engine = engines.testing_engine() + inspect(engine) + assert hasattr(engine.dialect, "default_schema_name") + + @testing.requires.schema_reflection + def test_get_default_schema_name(self, connection): + insp = inspect(connection) + eq_(insp.default_schema_name, connection.dialect.default_schema_name) + + @testing.combinations( + None, + ("foreign_key", testing.requires.foreign_key_constraint_reflection), + argnames="order_by", + ) + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_table_names(self, connection, order_by, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + _ignore_tables = { + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", + "no_constraints", + } + + insp = inspect(connection) + + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] + + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) + + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_view_names(self, connection, use_schema): + insp = inspect(connection) + if use_schema: + schema = config.test_schema + else: + schema = None + table_names = insp.get_view_names(schema) + if testing.requires.materialized_views.enabled: + eq_(sorted(table_names), ["email_addresses_v", "users_v"]) + eq_(insp.get_materialized_view_names(schema), ["dingalings_v"]) + else: + answer = ["dingalings_v", "email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) + + @testing.requires.temp_table_names + def test_get_temp_table_names(self, connection): + insp = inspect(connection) + temp_table_names = insp.get_temp_table_names() + eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"]) + + @testing.requires.view_reflection + @testing.requires.temporary_views + def test_get_temp_view_names(self, connection): + insp = inspect(connection) + temp_table_names = insp.get_temp_view_names() + eq_(sorted(temp_table_names), ["user_tmp_v"]) + + @testing.requires.comment_reflection + def test_get_comments(self, connection): + self._test_get_comments(connection) + + @testing.requires.comment_reflection + @testing.requires.schemas + def test_get_comments_with_schema(self, connection): + self._test_get_comments(connection, testing.config.test_schema) + + def _test_get_comments(self, connection, schema=None): + insp = inspect(connection) + exp = self.exp_comments(schema=schema) + eq_( + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + eq_( + insp.get_table_comment("users", schema=schema), + exp[(schema, "users")], + ) + + eq_( + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + no_cst = self.tables.no_constraints.name + eq_( + insp.get_table_comment(no_cst, schema=schema), + exp[(schema, no_cst)], + ) + + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + if use_views: + table_names = ["users_v", "email_addresses_v", "dingalings_v"] + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) + for table_name, table in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + is_true(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): + ctype_def = sql_types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + is_true( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) + + if not col.primary_key: + assert cols[i]["default"] is None + + # The case of a table with no column + # is tested below in TableNoColumnsTest + + @testing.requires.temp_table_reflection + def test_reflect_table_temp_table(self, connection): + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + + reflected_user_tmp = Table( + table_name, MetaData(), autoload_with=connection + ) + self.assert_tables_equal( + user_tmp, reflected_user_tmp, strict_constraints=False + ) + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self, connection): + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + insp = inspect(connection) + cols = insp.get_columns(table_name) + is_true(len(cols) > 0, len(cols)) + + for i, col in enumerate(user_tmp.columns): + eq_(col.name, cols[i]["name"]) + + @testing.requires.temp_table_reflection + @testing.requires.view_column_reflection + @testing.requires.temporary_views + def test_get_temp_view_columns(self, connection): + insp = inspect(connection) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None + + users, addresses = self.tables.users, self.tables.email_addresses + insp = inspect(connection) + exp = self.exp_pks(schema=schema) + + users_cons = insp.get_pk_constraint(users.name, schema=schema) + self._check_list( + [users_cons], [exp[(schema, users.name)]], self._required_pk_keys + ) + + addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) + exp_cols = exp[(schema, addresses.name)]["constrained_columns"] + eq_(addr_cons["constrained_columns"], exp_cols) + + with testing.requires.reflects_pk_names.fail_if(): + eq_(addr_cons["name"], "email_ad_pk") + + no_cst = self.tables.no_constraints.name + self._check_list( + [insp.get_pk_constraint(no_cst, schema=schema)], + [exp[(schema, no_cst)]], + self._required_pk_keys, + ) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + insp = inspect(connection) + expected_schema = schema + # users + + if testing.requires.self_referential_foreign_keys.enabled: + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) + fkey1 = users_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + eq_(fkey1["name"], "user_id_fk") + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) + + # addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) + fkey1 = addr_fkeys[0] + + with testing.requires.implicitly_named_constraints.fail_if(): + is_true(fkey1["name"] is not None) + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) + + no_cst = self.tables.no_constraints.name + eq_(insp.get_foreign_keys(no_cst, schema=schema), []) + + @testing.requires.cross_schema_fk_reflection + @testing.requires.schemas + def test_get_inter_schema_foreign_keys(self, connection): + local_table, remote_table, remote_table_2 = self.tables( + "%s.local_table" % connection.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, + ) + + insp = inspect(connection) + + local_fkeys = insp.get_foreign_keys(local_table.name) + eq_(len(local_fkeys), 1) + + fkey1 = local_fkeys[0] + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) + + remote_fkeys = insp.get_foreign_keys( + remote_table.name, schema=testing.config.test_schema + ) + eq_(len(remote_fkeys), 1) + + fkey2 = remote_fkeys[0] + + is_true( + fkey2["referred_schema"] + in ( + None, + connection.dialect.default_schema_name, + ) + ) + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.index_reflection + def test_get_indexes(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = inspect(connection) + indexes = insp.get_indexes("users", schema=schema) + exp = self.exp_indexes(schema=schema) + self._check_list( + indexes, exp[(schema, "users")], self._required_index_keys + ) + + no_cst = self.tables.no_constraints.name + self._check_list( + insp.get_indexes(no_cst, schema=schema), + exp[(schema, no_cst)], + self._required_index_keys, + ) + + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) + @testing.requires.index_reflection + @testing.requires.indexes_with_ascdesc + @testing.requires.reflect_indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) + indexes = insp.get_indexes(tname) + # reflecting an index that has "x DESC" in it as the column. + # the DB may or may not give us "x", but make sure we get the index + # back, it has a name, it's connected to the table. + expected_indexes = self.exp_indexes()[(None, tname)] + self._check_list(indexes, expected_indexes, self._required_index_keys) + + t = Table(tname, MetaData(), autoload_with=connection) + eq_(len(t.indexes), 1) + is_(list(t.indexes)[0].table, t) + eq_(list(t.indexes)[0].name, ixname) + + @testing.requires.temp_table_reflection + @testing.requires.unique_constraint_reflection + def test_get_temp_table_unique_constraints(self, connection): + insp = inspect(connection) + name = self.temp_table_name() + reflected = insp.get_unique_constraints(name) + exp = self.exp_ucs(all_=True)[(None, name)] + self._check_list(reflected, exp, self._required_index_keys) + + @testing.requires.temp_table_reflect_indexes + def test_get_temp_table_indexes(self, connection): + insp = inspect(connection) + table_name = self.temp_table_name() + indexes = insp.get_indexes(table_name) + for ind in indexes: + ind.pop("dialect_options", None) + expected = [ + {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_( + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + expected, + ) + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, metadata, connection, use_schema): + # SQLite dialect needs to parse the names of the constraints + # separately from what it gets from PRAGMA index_list(), and + # then matches them up. so same set of column_names in two + # constraints will confuse it. Perhaps we should no longer + # bother with index_list() here since we have the whole + # CREATE TABLE? + + if use_schema: + schema = config.test_schema + else: + schema = None + uniques = sorted( + [ + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, + ], + key=operator.itemgetter("name"), + ) + table = Table( + "testtbl", + metadata, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), + # reserved identifiers + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, + ) + for uc in uniques: + table.append_constraint( + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) + ) + table.create(connection) + + insp = inspect(connection) + reflected = sorted( + insp.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), + ) + + names_that_duplicate_index = set() + + eq_(len(uniques), len(reflected)) + + for orig, refl in zip(uniques, reflected): + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + dupe = refl.pop("duplicates_index", None) + if dupe: + names_that_duplicate_index.add(dupe) + eq_(refl.pop("comment", None), None) + eq_(orig, refl) + + reflected_metadata = MetaData() + reflected = Table( + "testtbl", + reflected_metadata, + autoload_with=connection, + schema=schema, + ) + + # test "deduplicates for index" logic. MySQL and Oracle + # "unique constraints" are actually unique indexes (with possible + # exception of a unique that is a dupe of another one in the case + # of Oracle). make sure # they aren't duplicated. + idx_names = {idx.name for idx in reflected.indexes} + uq_names = { + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + }.difference(["unique_c_a_b"]) + + assert not idx_names.intersection(uq_names) + if names_that_duplicate_index: + eq_(names_that_duplicate_index, idx_names) + eq_(uq_names, set()) + + no_cst = self.tables.no_constraints.name + eq_(insp.get_unique_constraints(no_cst, schema=schema), []) + + @testing.requires.view_reflection + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + for view in ["users_v", "email_addresses_v", "dingalings_v"]: + v = insp.get_view_definition(view, schema=schema) + is_true(bool(v)) + + @testing.requires.view_reflection + def test_get_view_definition_does_not_exist(self, connection): + insp = inspect(connection) + with expect_raises(NoSuchTableError): + insp.get_view_definition("view_does_not_exist") + with expect_raises(NoSuchTableError): + insp.get_view_definition("users") # a table + + @testing.requires.table_reflection + def test_autoincrement_col(self, connection): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(connection) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + def test_get_table_options(self, use_schema): + insp = inspect(config.db) + schema = config.test_schema if use_schema else None + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options("users", schema=schema) + is_true(isinstance(res, dict)) + # NOTE: can't really create a table with no option + res = insp.get_table_options("no_constraints", schema=schema) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options("users", schema=schema) + + @testing.combinations((True, testing.requires.schemas), False) + def test_multi_get_table_options(self, use_schema): + insp = inspect(config.db) + if testing.requires.reflect_table_options.enabled: + schema = config.test_schema if use_schema else None + res = insp.get_multi_table_options(schema=schema) + + exp = { + (schema, table): insp.get_table_options(table, schema=schema) + for table in insp.get_table_names(schema=schema) + } + eq_(res, exp) + else: + with expect_raises(NotImplementedError): + res = insp.get_multi_table_options() + + @testing.fixture + def get_multi_exp(self, connection): + def provide_fixture( + schema, scope, kind, use_filter, single_reflect_fn, exp_method + ): + insp = inspect(connection) + # call the reflection function at least once to avoid + # "Unexpected success" errors if the result is actually empty + # and NotImplementedError is not raised + single_reflect_fn(insp, "email_addresses") + kw = {"scope": scope, "kind": kind} + if schema: + schema = schema() + + filter_names = [] + + if ObjectKind.TABLE in kind: + filter_names.extend( + ["comment_test", "users", "does-not-exist"] + ) + if ObjectKind.VIEW in kind: + filter_names.extend(["email_addresses_v", "does-not-exist"]) + if ObjectKind.MATERIALIZED_VIEW in kind: + filter_names.extend(["dingalings_v", "does-not-exist"]) + + if schema: + kw["schema"] = schema + if use_filter: + kw["filter_names"] = filter_names + + exp = exp_method( + schema=schema, + scope=scope, + kind=kind, + filter_names=kw.get("filter_names"), + ) + kws = [kw] + if scope == ObjectScope.DEFAULT: + nkw = kw.copy() + nkw.pop("scope") + kws.append(nkw) + if kind == ObjectKind.TABLE: + nkw = kw.copy() + nkw.pop("kind") + kws.append(nkw) + + return inspect(connection), kws, exp + + return provide_fixture + + @testing.requires.reflect_table_options + @_multi_combination + def test_multi_get_table_options_tables( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_options, + self.exp_options, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_table_options(**kw) + eq_(result, exp) + + @testing.requires.comment_reflection + @_multi_combination + def test_get_multi_table_comment( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_comment, + self.exp_comments, + ) + for kw in kws: + insp.clear_cache() + eq_(insp.get_multi_table_comment(**kw), exp) + + def _check_expressions(self, result, exp, err_msg): + def _clean(text: str): + return re.sub(r"['\" ]", "", text).lower() + + if isinstance(exp, dict): + eq_({_clean(e): v for e, v in result.items()}, exp, err_msg) + else: + eq_([_clean(e) for e in result], exp, err_msg) + + def _check_list(self, result, exp, req_keys=None, msg=None): + if req_keys is None: + eq_(result, exp, msg) + else: + eq_(len(result), len(exp), msg) + for r, e in zip(result, exp): + for k in set(r) | set(e): + if k in req_keys or (k in r and k in e): + err_msg = f"{msg} - {k} - {r}" + if k in ("expressions", "column_sorting"): + self._check_expressions(r[k], e[k], err_msg) + else: + eq_(r[k], e[k], err_msg) + + def _check_table_dict(self, result, exp, req_keys=None, make_lists=False): + eq_(set(result.keys()), set(exp.keys())) + for k in result: + r, e = result[k], exp[k] + if make_lists: + r, e = [r], [e] + self._check_list(r, e, req_keys, k) + + @_multi_combination + def test_get_multi_columns( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_columns, + self.exp_columns, + ) + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_columns(**kw) + self._check_table_dict(result, exp, self._required_column_keys) + + @testing.requires.primary_key_constraint_reflection + @_multi_combination + def test_get_multi_pk_constraint( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_pk_constraint, + self.exp_pks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_pk_constraint(**kw) + self._check_table_dict( + result, exp, self._required_pk_keys, make_lists=True + ) + + def _adjust_sort(self, result, expected, key): + if not testing.requires.implicitly_named_constraints.enabled: + for obj in [result, expected]: + for val in obj.values(): + if len(val) > 1 and any( + v.get("name") in (None, mock.ANY) for v in val + ): + val.sort(key=key) + + @testing.requires.foreign_key_constraint_reflection + @_multi_combination + def test_get_multi_foreign_keys( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_foreign_keys, + self.exp_fks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_foreign_keys(**kw) + self._adjust_sort( + result, exp, lambda d: tuple(d["constrained_columns"]) + ) + self._check_table_dict(result, exp, self._required_fk_keys) + + @testing.requires.index_reflection + @_multi_combination + def test_get_multi_indexes( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_indexes, + self.exp_indexes, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_indexes(**kw) + self._check_table_dict(result, exp, self._required_index_keys) + + @testing.requires.unique_constraint_reflection + @_multi_combination + def test_get_multi_unique_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_unique_constraints, + self.exp_ucs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_unique_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["column_names"])) + self._check_table_dict(result, exp, self._required_unique_cst_keys) + + @testing.requires.check_constraint_reflection + @_multi_combination + def test_get_multi_check_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_check_constraints, + self.exp_ccs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_check_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"])) + self._check_table_dict(result, exp, self._required_cc_keys) + + @testing.combinations( + ("get_table_options", testing.requires.reflect_table_options), + "get_columns", + ( + "get_pk_constraint", + testing.requires.primary_key_constraint_reflection, + ), + ( + "get_foreign_keys", + testing.requires.foreign_key_constraint_reflection, + ), + ("get_indexes", testing.requires.index_reflection), + ( + "get_unique_constraints", + testing.requires.unique_constraint_reflection, + ), + ( + "get_check_constraints", + testing.requires.check_constraint_reflection, + ), + ("get_table_comment", testing.requires.comment_reflection), + argnames="method", + ) + def test_not_existing_table(self, method, connection): + insp = inspect(connection) + meth = getattr(insp, method) + with expect_raises(NoSuchTableError): + meth("table_does_not_exists") + + def test_unreflectable(self, connection): + mc = Inspector.get_multi_columns + + def patched(*a, **k): + ur = k.setdefault("unreflectable", {}) + ur[(None, "some_table")] = UnreflectableTableError("err") + return mc(*a, **k) + + with mock.patch.object(Inspector, "get_multi_columns", patched): + with expect_raises_message(UnreflectableTableError, "err"): + inspect(connection).reflect_table( + Table("some_table", MetaData()), None + ) + + @testing.combinations(True, False, argnames="use_schema") + @testing.combinations( + (True, testing.requires.views), False, argnames="views" + ) + def test_metadata(self, connection, use_schema, views): + m = MetaData() + schema = config.test_schema if use_schema else None + m.reflect(connection, schema=schema, views=views, resolve_fks=False) + + insp = inspect(connection) + tables = insp.get_table_names(schema) + if views: + tables += insp.get_view_names(schema) + try: + tables += insp.get_materialized_view_names(schema) + except NotImplementedError: + pass + if schema: + tables = [f"{schema}.{t}" for t in tables] + eq_(sorted(m.tables), sorted(tables)) + + @testing.requires.comment_reflection + def test_comments_unicode(self, connection, metadata): + Table( + "unicode_comments", + metadata, + Column("unicode", Integer, comment="é試蛇ẟΩ"), + Column("emoji", Integer, comment="☁️✨"), + comment="試蛇ẟΩ✨", + ) + + metadata.create_all(connection) + + insp = inspect(connection) + tc = insp.get_table_comment("unicode_comments") + eq_(tc, {"text": "試蛇ẟΩ✨"}) + + cols = insp.get_columns("unicode_comments") + value = {c["name"]: c["comment"] for c in cols} + exp = {"unicode": "é試蛇ẟΩ", "emoji": "☁️✨"} + eq_(value, exp) + + @testing.requires.comment_reflection_full_unicode + def test_comments_unicode_full(self, connection, metadata): + Table( + "unicode_comments", + metadata, + Column("emoji", Integer, comment="🐍🧙🝝🧙♂️🧙♀️"), + comment="🎩🁰🝑🤷♀️🤷♂️", + ) + + metadata.create_all(connection) + + insp = inspect(connection) + tc = insp.get_table_comment("unicode_comments") + eq_(tc, {"text": "🎩🁰🝑🤷♀️🤷♂️"}) + c = insp.get_columns("unicode_comments")[0] + eq_({c["name"]: c["comment"]}, {"emoji": "🐍🧙🝝🧙♂️🧙♀️"}) + + +class TableNoColumnsTest(fixtures.TestBase): + __requires__ = ("reflect_tables_no_columns",) + __backend__ = True + + @testing.fixture + def table_no_columns(self, connection, metadata): + Table("empty", metadata) + metadata.create_all(connection) + + @testing.fixture + def view_no_columns(self, connection, metadata): + Table("empty", metadata) + event.listen( + metadata, + "after_create", + DDL("CREATE VIEW empty_v AS SELECT * FROM empty"), + ) + + # for transactional DDL the transaction is rolled back before this + # drop statement is invoked + event.listen( + metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v") + ) + metadata.create_all(connection) + + def test_reflect_table_no_columns(self, connection, table_no_columns): + t2 = Table("empty", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + def test_get_columns_table_no_columns(self, connection, table_no_columns): + insp = inspect(connection) + eq_(insp.get_columns("empty"), []) + multi = insp.get_multi_columns() + eq_(multi, {(None, "empty"): []}) + + def test_reflect_incl_table_no_columns(self, connection, table_no_columns): + m = MetaData() + m.reflect(connection) + assert set(m.tables).intersection(["empty"]) + + @testing.requires.views + def test_reflect_view_no_columns(self, connection, view_no_columns): + t2 = Table("empty_v", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + @testing.requires.views + def test_get_columns_view_no_columns(self, connection, view_no_columns): + insp = inspect(connection) + eq_(insp.get_columns("empty_v"), []) + multi = insp.get_multi_columns(kind=ObjectKind.VIEW) + eq_(multi, {(None, "empty_v"): []}) + + +class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + Table( + "sa_cc", + metadata, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint( + "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing" + ), + schema=schema, + ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + reflected = sorted( + insp.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), + ) + + # trying to minimize effect of quoting, parenthesis, etc. + # may need to add more to this as new dialects get CHECK + # constraint reflection support + def normalize(sqltext): + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) + + reflected = [ + {"name": item["name"], "sqltext": normalize(item["sqltext"])} + for item in reflected + ] + eq_( + reflected, + [ + {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"}, + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + ], + ) + no_cst = "no_constraints" + eq_(insp.get_check_constraints(no_cst, schema=schema), []) + + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + Column("z", String(30)), + ) + + Index("t_idx", func.lower(t.c.x), t.c.z, func.lower(t.c.y)) + long_str = "long string " * 100 + Index("t_idx_long", func.coalesce(t.c.x, long_str)) + Index("t_idx_2", t.c.x) + + metadata.create_all(connection) + + insp = inspect(connection) + + expected = [ + { + "name": "t_idx_2", + "column_names": ["x"], + "unique": False, + "dialect_options": {}, + } + ] + + def completeIndex(entry): + if testing.requires.index_reflects_included_columns.enabled: + entry["include_columns"] = [] + entry["dialect_options"] = { + f"{connection.engine.name}_include": [] + } + else: + entry.setdefault("dialect_options", {}) + + completeIndex(expected[0]) + + class lower_index_str(str): + def __eq__(self, other): + ol = other.lower() + # test that lower and x or y are in the string + return "lower" in ol and ("x" in ol or "y" in ol) + + class coalesce_index_str(str): + def __eq__(self, other): + # test that coalesce and the string is in other + return "coalesce" in other.lower() and long_str in other + + if testing.requires.reflect_indexes_with_expressions.enabled: + expr_index = { + "name": "t_idx", + "column_names": [None, "z", None], + "expressions": [ + lower_index_str("lower(x)"), + "z", + lower_index_str("lower(y)"), + ], + "unique": False, + } + completeIndex(expr_index) + expected.insert(0, expr_index) + + expr_index_long = { + "name": "t_idx_long", + "column_names": [None], + "expressions": [ + coalesce_index_str(f"coalesce(x, '{long_str}')") + ], + "unique": False, + } + completeIndex(expr_index_long) + expected.append(expr_index_long) + + eq_(insp.get_indexes("t"), expected) + m2 = MetaData() + t2 = Table("t", m2, autoload_with=connection) + else: + with expect_warnings( + "Skipped unsupported reflection of expression-based " + "index t_idx" + ): + eq_(insp.get_indexes("t"), expected) + m2 = MetaData() + t2 = Table("t", m2, autoload_with=connection) + + self.compare_table_index_with_expected( + t2, expected, connection.engine.name + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + get_indexes = insp.get_indexes("t") + eq_( + get_indexes, + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + "dialect_options": mock.ANY, + } + ], + ) + eq_( + get_indexes[0]["dialect_options"][ + "%s_include" % connection.engine.name + ], + ["y"], + ) + + t2 = Table("t", MetaData(), autoload_with=connection) + eq_( + list(t2.indexes)[0].dialect_options[connection.engine.name][ + "include" + ], + ["y"], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)], + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] + + @testing.requires.table_reflection + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) + + @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) + + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + { + col["name"]: col["nullable"] + for col in inspect(connection).get_columns("t") + }, + {"a": True, "b": False}, + ) + + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate + + if expected is None: + expected = options + + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_({k: opts[k] for k in opts if opts[k]}, {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + + +class NormalizedNameTest(fixtures.TablesTest): + __requires__ = ("denormalized_names",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), + ) + Table( + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), + ) + + def test_reflect_lowercase_forced_tables(self): + m2 = MetaData() + t2_ref = Table( + quoted_name("t2", quote=True), m2, autoload_with=config.db + ) + t1_ref = m2.tables["t1"] + assert t2_ref.c.t1id.references(t1_ref.c.id) + + m3 = MetaData() + m3.reflect( + config.db, only=lambda name, m: name.lower() in ("t1", "t2") + ) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) + + def test_get_table_names(self): + tablenames = [ + t + for t in inspect(config.db).get_table_names() + if t.lower() in ("t1", "t2") + ] + + eq_(tablenames[0].upper(), tablenames[0].lower()) + eq_(tablenames[1].upper(), tablenames[1].lower()) + + +class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest): + def test_computed_col_default_not_set(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + col_data = {c["name"]: c for c in cols} + is_true("42" in col_data["with_default"]["default"]) + is_(col_data["normal"]["default"], None) + is_(col_data["computed_col"]["default"], None) + + def test_get_column_returns_computed(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + data = {c["name"]: c for c in cols} + for key in ("id", "normal", "with_default"): + is_true("computed" not in data[key]) + compData = data["computed_col"] + is_true("computed" in compData) + is_true("sqltext" in compData["computed"]) + eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42") + eq_( + "persisted" in compData["computed"], + testing.requires.computed_columns_reflect_persisted.enabled, + ) + if testing.requires.computed_columns_reflect_persisted.enabled: + eq_( + compData["computed"]["persisted"], + testing.requires.computed_columns_default_persisted.enabled, + ) + + def check_column(self, data, column, sqltext, persisted): + is_true("computed" in data[column]) + compData = data[column]["computed"] + eq_(self.normalize(compData["sqltext"]), sqltext) + if testing.requires.computed_columns_reflect_persisted.enabled: + is_true("persisted" in compData) + is_(compData["persisted"], persisted) + + def test_get_column_returns_persisted(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_column_table") + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal+42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal+2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal-42", + True, + ) + + @testing.requires.schemas + def test_get_column_returns_persisted_with_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns( + "computed_column_table", schema=config.test_schema + ) + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal/42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal/2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal*42", + True, + ) + + +class IdentityReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("identity_columns", "table_reflection") + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity()), + ) + Table( + "t2", + metadata, + Column( + "id2", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + ) + if testing.requires.schemas.enabled: + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity(always=True, start=20)), + schema=config.test_schema, + ) + + def check(self, value, exp, approx): + if testing.requires.identity_columns_standard.enabled: + common_keys = ( + "always", + "start", + "increment", + "minvalue", + "maxvalue", + "cycle", + "cache", + ) + for k in list(value): + if k not in common_keys: + value.pop(k) + if approx: + eq_(len(value), len(exp)) + for k in value: + if k == "minvalue": + is_true(value[k] <= exp[k]) + elif k in {"maxvalue", "cache"}: + is_true(value[k] >= exp[k]) + else: + eq_(value[k], exp[k], k) + else: + eq_(value, exp) + else: + eq_(value["start"], exp["start"]) + eq_(value["increment"], exp["increment"]) + + def test_reflect_identity(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1") + insp.get_columns("t2") + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + if "autoincrement" in col: + is_true(col["autoincrement"]) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=False, + start=1, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + elif col["name"] == "id2": + if "autoincrement" in col: + is_true(col["autoincrement"]) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + approx=False, + ) + + @testing.requires.schemas + def test_reflect_identity_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1", schema=config.test_schema) + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + if "autoincrement" in col: + is_true(col["autoincrement"]) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=20, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + + +class CompositeKeyReflectionTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tb1 = Table( + "tb1", + metadata, + Column("id", Integer), + Column("attr", Integer), + Column("name", sql_types.VARCHAR(20)), + sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"), + schema=None, + test_needs_fk=True, + ) + Table( + "tb2", + metadata, + Column("id", Integer, primary_key=True), + Column("pid", Integer), + Column("pattr", Integer), + Column("pname", sql_types.VARCHAR(20)), + sa.ForeignKeyConstraint( + ["pname", "pid", "pattr"], + [tb1.c.name, tb1.c.id, tb1.c.attr], + name="fk_tb1_name_id_attr", + ), + schema=None, + test_needs_fk=True, + ) + + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self, connection): + # test for issue #5661 + insp = inspect(connection) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) + + @testing.requires.foreign_key_constraint_reflection + def test_fk_column_order(self, connection): + # test for issue #5661 + insp = inspect(connection) + foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) + eq_(len(foreign_keys), 1) + fkey1 = foreign_keys[0] + eq_(fkey1.get("referred_columns"), ["name", "id", "attr"]) + eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"]) + + +__all__ = ( + "ComponentReflectionTest", + "ComponentReflectionTestExtra", + "TableNoColumnsTest", + "QuotedNameArgumentTest", + "BizarroCharacterFKResolutionTest", + "HasTableTest", + "HasIndexTest", + "NormalizedNameTest", + "ComputedReflectionTest", + "IdentityReflectionTest", + "CompositeKeyReflectionTest", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py new file mode 100644 index 0000000..b3f432f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py @@ -0,0 +1,468 @@ +# testing/suite/test_results.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import datetime + +from .. import engines +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import DateTime +from ... import func +from ... import Integer +from ... import select +from ... import sql +from ... import String +from ... import testing +from ... import text + + +class RowFetchTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + connection.execute( + cls.tables.has_dates.insert(), + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], + ) + + def test_via_attr(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row.id, 1) + eq_(row.data, "d1") + + def test_via_string(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping["id"], 1) + eq_(row._mapping["data"], "d1") + + def test_via_int(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row[0], 1) + eq_(row[1], "d1") + + def test_via_col_object(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping[self.tables.plain_pk.c.id], 1) + eq_(row._mapping[self.tables.plain_pk.c.data], "d1") + + @requirements.duplicate_names_in_cursor_description + def test_row_with_dupe_names(self, connection): + result = connection.execute( + select( + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ).order_by(self.tables.plain_pk.c.id) + ) + row = result.first() + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) + + def test_row_w_scalar_select(self, connection): + """test that a scalar select as a column is returned as such + and that type conversion works OK. + + (this is half a SQLAlchemy Core test and half to catch database + backends that may have unusual behavior with scalar selects.) + + """ + datetable = self.tables.has_dates + s = select(datetable.alias("x").c.today).scalar_subquery() + s2 = select(datetable.c.id, s.label("somelabel")) + row = connection.execute(s2).first() + + eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0)) + + +class PercentSchemaNamesTest(fixtures.TablesTest): + """tests using percent signs, spaces in table and column names. + + This didn't work for PostgreSQL / MySQL drivers for a long time + but is now supported. + + """ + + __requires__ = ("percent_schema_names",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) + cls.tables.lightweight_percent_table = sql.table( + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), + ) + + def test_single_roundtrip(self, connection): + percent_table = self.tables.percent_table + for params in [ + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ]: + connection.execute(percent_table.insert(), params) + self._assert_table(connection) + + def test_executemany_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + connection.execute( + percent_table.insert(), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + self._assert_table(connection) + + @requirements.insert_executemany_returning + def test_executemany_returning_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + result = connection.execute( + percent_table.insert().returning( + percent_table.c["percent%"], + percent_table.c["spaces % more spaces"], + ), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + eq_(result.all(), [(7, 11), (9, 10), (11, 9)]) + self._assert_table(connection) + + def _assert_table(self, conn): + percent_table = self.tables.percent_table + lightweight_percent_table = self.tables.lightweight_percent_table + + for table in ( + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): + eq_( + list( + conn.execute(table.select().order_by(table.c["percent%"])) + ), + [(5, 12), (7, 11), (9, 10), (11, 9)], + ) + + eq_( + list( + conn.execute( + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) + ) + ), + [(9, 10), (11, 9)], + ) + + row = conn.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row._mapping["percent%"], 5) + eq_(row._mapping["spaces % more spaces"], 12) + + eq_(row._mapping[table.c["percent%"]], 5) + eq_(row._mapping[table.c["spaces % more spaces"]], 12) + + conn.execute( + percent_table.update().values( + {percent_table.c["spaces % more spaces"]: 15} + ) + ) + + eq_( + list( + conn.execute( + percent_table.select().order_by( + percent_table.c["percent%"] + ) + ) + ), + [(5, 15), (7, 15), (9, 15), (11, 15)], + ) + + +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): + __requires__ = ("server_side_cursors",) + + __backend__ = True + + def _is_server_side(self, cursor): + # TODO: this is a huge issue as it prevents these tests from being + # usable by third party dialects. + if self.engine.dialect.driver == "psycopg2": + return bool(cursor.name) + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"): + return cursor.server_side + elif self.engine.dialect.driver == "mysqldb": + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "mariadbconnector": + return not cursor.buffered + elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"): + return cursor.server_side + elif self.engine.dialect.driver == "pg8000": + return getattr(cursor, "server_side", False) + elif self.engine.dialect.driver == "psycopg": + return bool(getattr(cursor, "name", False)) + else: + return False + + def _fixture(self, server_side_cursors): + if server_side_cursors: + with testing.expect_deprecated( + "The create_engine.server_side_cursors parameter is " + "deprecated and will be removed in a future release. " + "Please use the Connection.execution_options.stream_results " + "parameter." + ): + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + else: + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + return self.engine + + @testing.combinations( + ("global_string", True, "select 1", True), + ("global_text", True, text("select 1"), True), + ("global_expr", True, select(1), True), + ("global_off_explicit", False, text("select 1"), False), + ( + "stmt_option", + False, + select(1).execution_options(stream_results=True), + True, + ), + ( + "stmt_option_disabled", + True, + select(1).execution_options(stream_results=False), + False, + ), + ("for_update_expr", True, select(1).with_for_update(), True), + # TODO: need a real requirement for this, or dont use this test + ( + "for_update_string", + True, + "SELECT 1 FOR UPDATE", + True, + testing.skip_if(["sqlite", "mssql"]), + ), + ("text_no_ss", False, text("select 42"), False), + ( + "text_ss_option", + False, + text("select 42").execution_options(stream_results=True), + True, + ), + id_="iaaa", + argnames="engine_ss_arg, statement, cursor_ss_status", + ) + def test_ss_cursor_status( + self, engine_ss_arg, statement, cursor_ss_status + ): + engine = self._fixture(engine_ss_arg) + with engine.begin() as conn: + if isinstance(statement, str): + result = conn.exec_driver_sql(statement) + else: + result = conn.execute(statement) + eq_(self._is_server_side(result.cursor), cursor_ss_status) + result.close() + + def test_conn_option(self): + engine = self._fixture(False) + + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) + + # the connection has autobegun, which means at the end of the + # block, we will roll back, which on MySQL at least will fail + # with "Commands out of sync" if the result set + # is not closed, so we close it first. + # + # fun fact! why did we not have this result.close() in this test + # before 2.0? don't we roll back in the connection pool + # unconditionally? yes! and in fact if you run this test in 1.4 + # with stdout shown, there is in fact "Exception during reset or + # similar" with "Commands out sync" emitted a warning! 2.0's + # architecture finds and fixes what was previously an expensive + # silent error condition. + result.close() + + def test_stmt_enabled_conn_option_disabled(self): + engine = self._fixture(False) + + s = select(1).execution_options(stream_results=True) + + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) + + def test_aliases_and_ss(self): + engine = self._fixture(False) + s1 = ( + select(sql.literal_column("1").label("x")) + .execution_options(stream_results=True) + .subquery() + ) + + # options don't propagate out when subquery is used as a FROM clause + with engine.begin() as conn: + result = conn.execute(s1.select()) + assert not self._is_server_side(result.cursor) + result.close() + + s2 = select(1).select_from(s1) + with engine.begin() as conn: + result = conn.execute(s2) + assert not self._is_server_side(result.cursor) + result.close() + + def test_roundtrip_fetchall(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute(test_table.insert(), dict(data="data1")) + connection.execute(test_table.insert(), dict(data="data2")) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2")], + ) + connection.execute( + test_table.update() + .where(test_table.c.id == 2) + .values(data=test_table.c.data + " updated") + ) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) + connection.execute(test_table.delete()) + eq_( + connection.scalar( + select(func.count("*")).select_from(test_table) + ), + 0, + ) + + def test_roundtrip_fetchmany(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute( + test_table.insert(), + [dict(data="data%d" % i) for i in range(1, 20)], + ) + + result = connection.execute( + test_table.select().order_by(test_table.c.id) + ) + + eq_( + result.fetchmany(5), + [(i, "data%d" % i) for i in range(1, 6)], + ) + eq_( + result.fetchmany(10), + [(i, "data%d" % i) for i in range(6, 16)], + ) + eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)]) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py new file mode 100644 index 0000000..a7dbd36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py @@ -0,0 +1,258 @@ +# testing/suite/test_rowcount.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from sqlalchemy import bindparam +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures + + +class RowCountTest(fixtures.TablesTest): + """test rowcount functionality""" + + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + ) + + @classmethod + def insert_data(cls, connection): + cls.data = data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] + + employees_table = cls.tables.employees + connection.execute( + employees_table.insert(), + [ + {"employee_id": i, "name": n, "department": d} + for i, (n, d) in enumerate(data) + ], + ) + + def test_basic(self, connection): + employees_table = self.tables.employees + s = select( + employees_table.c.name, employees_table.c.department + ).order_by(employees_table.c.employee_id) + rows = connection.execute(s).fetchall() + + eq_(rows, self.data) + + @testing.variation("statement", ["update", "delete", "insert", "select"]) + @testing.variation("close_first", [True, False]) + def test_non_rowcount_scenarios_no_raise( + self, connection, statement, close_first + ): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + + if statement.update: + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "Z"}, + ) + elif statement.delete: + r = connection.execute( + employees_table.delete().where(department == "C"), + {"department": "Z"}, + ) + elif statement.insert: + r = connection.execute( + employees_table.insert(), + [ + {"employee_id": 25, "name": "none 1", "department": "X"}, + {"employee_id": 26, "name": "none 2", "department": "Z"}, + {"employee_id": 27, "name": "none 3", "department": "Z"}, + ], + ) + elif statement.select: + s = select( + employees_table.c.name, employees_table.c.department + ).where(employees_table.c.department == "C") + r = connection.execute(s) + r.all() + else: + statement.fail() + + if close_first: + r.close() + + assert r.rowcount in (-1, 3) + + def test_update_rowcount1(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "Z"}, + ) + assert r.rowcount == 3 + + def test_update_rowcount2(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 0 rows changed + department = employees_table.c.department + + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "C"}, + ) + eq_(r.rowcount, 3) + + @testing.variation("implicit_returning", [True, False]) + @testing.variation( + "dml", + [ + ("update", testing.requires.update_returning), + ("delete", testing.requires.delete_returning), + ], + ) + def test_update_delete_rowcount_return_defaults( + self, connection, implicit_returning, dml + ): + """note this test should succeed for all RETURNING backends + as of 2.0. In + Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use + len(rows) when we have implicit returning + + """ + + if implicit_returning: + employees_table = self.tables.employees + else: + employees_table = Table( + "employees", + MetaData(), + Column( + "employee_id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + implicit_returning=False, + ) + + department = employees_table.c.department + + if dml.update: + stmt = ( + employees_table.update() + .where(department == "C") + .values(name=employees_table.c.department + "Z") + .return_defaults() + ) + elif dml.delete: + stmt = ( + employees_table.delete() + .where(department == "C") + .return_defaults() + ) + else: + dml.fail() + + r = connection.execute(stmt) + eq_(r.rowcount, 3) + + def test_raw_sql_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.exec_driver_sql( + "update employees set department='Z' where department='C'" + ) + eq_(result.rowcount, 3) + + def test_text_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.execute( + text("update employees set department='Z' where department='C'") + ) + eq_(result.rowcount, 3) + + def test_delete_rowcount(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows deleted + department = employees_table.c.department + r = connection.execute( + employees_table.delete().where(department == "C") + ) + eq_(r.rowcount, 3) + + @testing.requires.sane_multi_rowcount + def test_multi_update_rowcount(self, connection): + employees_table = self.tables.employees + stmt = ( + employees_table.update() + .where(employees_table.c.name == bindparam("emp_name")) + .values(department="C") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) + + @testing.requires.sane_multi_rowcount + def test_multi_delete_rowcount(self, connection): + employees_table = self.tables.employees + + stmt = employees_table.delete().where( + employees_table.c.name == bindparam("emp_name") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py new file mode 100644 index 0000000..866bf09 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py @@ -0,0 +1,1888 @@ +# testing/suite/test_select.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import collections.abc as collections_abc +import itertools + +from .. import AssertsCompiledSQL +from .. import AssertsExecutionResults +from .. import config +from .. import fixtures +from ..assertions import assert_raises +from ..assertions import eq_ +from ..assertions import in_ +from ..assertsql import CursorSQL +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import case +from ... import column +from ... import Computed +from ... import exists +from ... import false +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import Integer +from ... import literal +from ... import literal_column +from ... import null +from ... import select +from ... import String +from ... import table +from ... import testing +from ... import text +from ... import true +from ... import tuple_ +from ... import TupleType +from ... import union +from ... import values +from ...exc import DatabaseError +from ...exc import ProgrammingError + + +class CollateTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "collate data1"}, + {"id": 2, "data": "collate data2"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + @testing.requires.order_by_collation + def test_collate_order_by(self): + collation = testing.requires.get_order_by_collation(testing.config) + + self._assert_result( + select(self.tables.some_table).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], + ) + + +class OrderByLabelTest(fixtures.TablesTest): + """Test the dialect sends appropriate ORDER BY expressions when + labels are used. + + This essentially exercises the "supports_simple_order_by_label" + setting. + + """ + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + def test_plain(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)]) + + def test_composed_int(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)]) + + def test_composed_multiple(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") + self._assert_result( + select(lx, ly).order_by(lx, ly.desc()), + [(3, "q1p3"), (5, "q2p2"), (7, "q3p1")], + ) + + def test_plain_desc(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)]) + + def test_composed_int_desc(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)]) + + @testing.requires.group_by_complex_expression + def test_group_by_composed(self): + table = self.tables.some_table + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select(func.count(table.c.id), expr).group_by(expr).order_by(expr) + ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) + + +class ValuesExpressionTest(fixtures.TestBase): + __requires__ = ("table_value_constructor",) + + __backend__ = True + + def test_tuples(self, connection): + value_expr = values( + column("id", Integer), column("name", String), name="my_values" + ).data([(1, "name1"), (2, "name2"), (3, "name3")]) + + eq_( + connection.execute(select(value_expr)).all(), + [(1, "name1"), (2, "name2"), (3, "name3")], + ) + + +class FetchLimitOffsetTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + {"id": 5, "x": 4, "y": 6}, + ], + ) + + def _assert_result( + self, connection, select, result, params=(), set_=False + ): + if set_: + query_res = connection.execute(select, params).fetchall() + eq_(len(query_res), len(result)) + eq_(set(query_res), set(result)) + + else: + eq_(connection.execute(select, params).fetchall(), result) + + def _assert_result_str(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.exec_driver_sql(select, params).fetchall(), result) + + def test_simple_limit(self, connection): + table = self.tables.some_table + stmt = select(table).order_by(table.c.id) + self._assert_result( + connection, + stmt.limit(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + stmt.limit(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).limit(1).scalar_subquery() + + u = union(select(stmt), select(stmt)).subquery().select() + + self._assert_result( + connection, + u, + [ + (1,), + ], + ) + + @testing.requires.fetch_first + def test_simple_fetch(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.offset + def test_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.combinations( + ([(2, 0), (2, 1), (3, 2)]), + ([(2, 1), (2, 0), (3, 2)]), + ([(3, 1), (2, 1), (3, 1)]), + argnames="cases", + ) + @testing.requires.offset + def test_simple_limit_offset(self, connection, cases): + table = self.tables.some_table + connection = connection.execution_options(compiled_cache={}) + + assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)] + + for limit, offset in cases: + expected = assert_data[offset : offset + limit] + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(limit).offset(offset), + expected, + ) + + @testing.requires.fetch_first + def test_simple_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2).offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_no_order_by + def test_fetch_offset_no_order(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).fetch(10), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.offset + def test_simple_offset_zero(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(0), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(1), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.offset + def test_limit_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).limit(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.fetch_first + def test_fetch_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).fetch(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.bound_limit_offset + def test_bound_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3)], + params={"l": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + params={"l": 3}, + ) + + @testing.requires.bound_limit_offset + def test_bound_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 1}, + ) + + @testing.requires.bound_limit_offset + def test_bound_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"l": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"l": 3, "o": 2}, + ) + + @testing.requires.fetch_first + def test_bound_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"f": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"f": 3, "o": 2}, + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .offset(literal_column("1") + literal_column("2")), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("2")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.fetch_first + @testing.requires.fetch_expression + def test_expr_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.sql_expression_limit_offset + def test_simple_limit_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(2) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(3) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(2), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + def test_simple_fetch_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties_exact_number(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(3, with_ties=True) + .offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_percent + def test_simple_fetch_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(20, percent=True), + [(1, 1, 2)], + ) + + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(40, percent=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + def test_simple_fetch_percent_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x.desc()) + .fetch(20, percent=True, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(40, percent=True, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) + + +class SameNamedSchemaTableTest(fixtures.TablesTest): + """tests for #7471""" + + __backend__ = True + + __requires__ = ("schemas",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + schema=config.test_schema, + ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "some_table_id", + Integer, + # ForeignKey("%s.some_table.id" % config.test_schema), + nullable=False, + ), + ) + + @classmethod + def insert_data(cls, connection): + some_table, some_table_schema = cls.tables( + "some_table", "%s.some_table" % config.test_schema + ) + connection.execute(some_table_schema.insert(), {"id": 1}) + connection.execute(some_table.insert(), {"id": 1, "some_table_id": 1}) + + def test_simple_join_both_tables(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table, some_table_schema).join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + ).first(), + (1, 1, 1), + ) + + def test_simple_join_whereclause_only(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1), + ) + + def test_subquery(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + subq = ( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + .subquery() + ) + + eq_( + connection.execute( + select(some_table, subq.c.id) + .join_from( + some_table, + subq, + some_table.c.some_table_id == subq.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1, 1), + ) + + +class JoinTest(fixtures.TablesTest): + __backend__ = True + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", ForeignKey("a.id"), nullable=False), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.a.insert(), + [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}], + ) + + connection.execute( + cls.tables.b.insert(), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + {"id": 4, "a_id": 2}, + {"id": 5, "a_id": 3}, + ], + ) + + def test_inner_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + def test_inner_join_true(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, true())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (a, b, c) + for (a,), (b, c) in itertools.product( + [(1,), (2,), (3,), (4,), (5,)], + [(1, 1), (2, 1), (4, 2), (5, 3)], + ) + ], + ) + + def test_inner_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result(stmt, []) + + def test_outer_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.outerjoin(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (1, None, None), + (2, None, None), + (3, None, None), + (4, None, None), + (5, None, None), + ], + ) + + def test_outer_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + +class CompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_select_from_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_in_unions_from_alias(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + # this necessarily has double parens + u1 = union(s1, s2).alias() + self._assert_result( + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + +class PostCompileParamsTest( + AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest +): + __backend__ = True + + __requires__ = ("standard_cursor_sql",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def test_compile(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table " + "WHERE some_table.x = __[POSTCOMPILE_q]", + {}, + ) + + def test_compile_literal_binds(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", 10, literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table WHERE some_table.x = 10", + {}, + literal_binds=True, + ) + + def test_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=10)) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x = 10", + () if config.db.dialect.positional else {}, + ) + ) + + def test_execute_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x.in_(bindparam("q", expanding=True, literal_execute=True)) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[5, 6, 7])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x IN (5, 6, 7)", + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, 10), (12, 18)])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.y) " + "IN (%s(5, 10), (12, 18))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.z) " + "IN (%s(5, 'z1'), (12, 'z3'))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + +class ExpandingBoundInTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_multiple_empty_sets_bindparam(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .where(table.c.y.in_(bindparam("p"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + + def test_multiple_empty_sets_direct(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.y.in_([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)]) + go([], []) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)]) + go([], []) + + def test_bound_in_scalar_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + + def test_bound_in_scalar_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3, 4])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + def test_nonempty_in_plus_empty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3])) + .where(table.c.id.not_in([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,)]) + + def test_empty_in_plus_notempty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.id.not_in([2, 3])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + def test_typed_str_in(self): + """test related to #7292. + + as a type is given to the bound param, there is no ambiguity + to the type of element. + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", type_=String, expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + def test_untyped_str_in(self): + """test related to #7292. + + for untyped expression, we look at the types of elements. + Test for Sequence to detect tuple in. but not strings or bytes! + as always.... + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where( + tuple_(table.c.x, table.c.z).in_( + [(2, "z2"), (3, "z3"), (4, "z4")] + ) + ) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self): + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams( + bindparam( + "q", type_=TupleType(Integer(), String()), expanding=True + ) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + def test_empty_set_against_integer_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_integer_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_integer_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_integer_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_empty_set_against_string_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_string_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_string_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_string_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_null_in_empty_set_is_false_bindparam(self, connection): + stmt = select( + case( + ( + null().in_(bindparam("foo", value=())), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + def test_null_in_empty_set_is_false_direct(self, connection): + stmt = select( + case( + ( + null().in_([]), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + +class LikeFunctionsTest(fixtures.TablesTest): + __backend__ = True + + run_inserts = "once" + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "abcdefg"}, + {"id": 2, "data": "ab/cdefg"}, + {"id": 3, "data": "ab%cdefg"}, + {"id": 4, "data": "ab_cdefg"}, + {"id": 5, "data": "abcde/fg"}, + {"id": 6, "data": "abcde%fg"}, + {"id": 7, "data": "ab#cdefg"}, + {"id": 8, "data": "ab9cdefg"}, + {"id": 9, "data": "abcde#fg"}, + {"id": 10, "data": "abcd9fg"}, + {"id": 11, "data": None}, + ], + ) + + def _test(self, expr, expected): + some_table = self.tables.some_table + + with config.db.connect() as conn: + rows = { + value + for value, in conn.execute(select(some_table.c.id).where(expr)) + } + + eq_(rows, expected) + + def test_startswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + + def test_startswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True), {3}) + + def test_startswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.startswith(literal_column("'ab%c'")), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) + + def test_startswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab##c", escape="#"), {7}) + + def test_startswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3}) + self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7}) + + def test_endswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_endswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) + + def test_endswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True), {6}) + + def test_endswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e##fg", escape="#"), {9}) + + def test_endswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6}) + self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9}) + + def test_contains_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_contains_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde", autoescape=True), {3}) + + def test_contains_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b##cde", escape="#"), {7}) + + def test_contains_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) + self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + @testing.requires.regexp_match + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10}) + + @testing.requires.regexp_replace + def test_regexp_replace(self): + col = self.tables.some_table.c.data + self._test( + col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9} + ) + + @testing.requires.regexp_match + @testing.combinations( + ("a.cde", {1, 5, 6, 9}), + ("abc", {1, 5, 6, 9, 10}), + ("^abc", {1, 5, 6, 9, 10}), + ("9cde", {8}), + ("^a", set(range(1, 11))), + ("(b|c)", set(range(1, 11))), + ("^(b|c)", set()), + ) + def test_regexp_match(self, text, expected): + col = self.tables.some_table.c.data + self._test(col.regexp_match(text), expected) + + +class ComputedColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.square.insert(), + [{"id": 1, "side": 10}, {"id": 10, "side": 42}], + ) + + def test_select_all(self): + with config.db.connect() as conn: + res = conn.execute( + select(text("*")) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)]) + + def test_select_columns(self): + with config.db.connect() as conn: + res = conn.execute( + select( + self.tables.square.c.area, self.tables.square.c.perimeter + ) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(100, 40), (1764, 168)]) + + +class IdentityColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("identity_columns",) + run_inserts = "once" + run_deletes = "once" + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl_a", + metadata, + Column( + "id", + Integer, + Identity( + always=True, start=42, nominvalue=True, nomaxvalue=True + ), + primary_key=True, + ), + Column("desc", String(100)), + ) + Table( + "tbl_b", + metadata, + Column( + "id", + Integer, + Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0), + primary_key=True, + ), + Column("desc", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.tbl_a.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"id": 42, "desc": "c"}], + ) + + def test_select_all(self, connection): + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_a) + .order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42, "a"), (43, "b")]) + + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_b) + .order_by(self.tables.tbl_b.c.id) + ).fetchall() + eq_(res, [(-5, "b"), (0, "a"), (42, "c")]) + + def test_select_columns(self, connection): + res = connection.execute( + select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42,), (43,)]) + + @testing.requires.identity_columns_standard + def test_insert_always_error(self, connection): + def fn(): + connection.execute( + self.tables.tbl_a.insert(), + [{"id": 200, "desc": "a"}], + ) + + assert_raises((DatabaseError, ProgrammingError), fn) + + +class IdentityAutoincrementTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("autoincrement_without_sequence",) + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl", + metadata, + Column( + "id", + Integer, + Identity(), + primary_key=True, + autoincrement=True, + ), + Column("desc", String(100)), + ) + + def test_autoincrement_with_identity(self, connection): + res = connection.execute(self.tables.tbl.insert(), {"desc": "row"}) + res = connection.execute(self.tables.tbl.select()).first() + eq_(res, (1, "row")) + + +class ExistsTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "stuff", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.stuff.insert(), + [ + {"id": 1, "data": "some data"}, + {"id": 2, "data": "some data"}, + {"id": 3, "data": "some data"}, + {"id": 4, "data": "some other data"}, + ], + ) + + def test_select_exists(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "some data") + ) + ).fetchall(), + [(1,)], + ) + + def test_select_exists_false(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "no data") + ) + ).fetchall(), + [], + ) + + +class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest): + __backend__ = True + + @testing.fails_if(testing.requires.supports_distinct_on) + def test_distinct_on(self): + stm = select("*").distinct(column("q")).select_from(table("foo")) + with testing.expect_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + ): + self.assert_compile(stm, "SELECT DISTINCT * FROM foo") + + +class IsOrIsNotDistinctFromTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("supports_is_distinct_from",) + + @classmethod + def define_tables(cls, metadata): + Table( + "is_distinct_test", + metadata, + Column("id", Integer, primary_key=True), + Column("col_a", Integer, nullable=True), + Column("col_b", Integer, nullable=True), + ) + + @testing.combinations( + ("both_int_different", 0, 1, 1), + ("both_int_same", 1, 1, 0), + ("one_null_first", None, 1, 1), + ("one_null_second", 0, None, 1), + ("both_null", None, None, 0), + id_="iaaa", + argnames="col_a_value, col_b_value, expected_row_count_for_is", + ) + def test_is_or_is_not_distinct_from( + self, col_a_value, col_b_value, expected_row_count_for_is, connection + ): + tbl = self.tables.is_distinct_test + + connection.execute( + tbl.insert(), + [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}], + ) + + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is, + ) + + expected_row_count_for_is_not = ( + 1 if expected_row_count_for_is == 0 else 0 + ) + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is_not, + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py new file mode 100644 index 0000000..138616f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py @@ -0,0 +1,317 @@ +# testing/suite/test_sequence.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import config +from .. import fixtures +from ..assertions import eq_ +from ..assertions import is_true +from ..config import requirements +from ..provision import normalize_sequence +from ..schema import Column +from ..schema import Table +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import Sequence +from ... import String +from ... import testing + + +class SequenceTest(fixtures.TablesTest): + __requires__ = ("sequences",) + __backend__ = True + + run_create_tables = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "seq_pk", + metadata, + Column( + "id", + Integer, + normalize_sequence(config, Sequence("tab_id_seq")), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + normalize_sequence( + config, + Sequence("tab_id_seq", data_type=Integer, optional=True), + ), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_no_returning", + metadata, + Column( + "id", + Integer, + normalize_sequence(config, Sequence("noret_id_seq")), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + ) + + if testing.requires.schemas.enabled: + Table( + "seq_no_returning_sch", + metadata, + Column( + "id", + Integer, + normalize_sequence( + config, + Sequence( + "noret_sch_id_seq", schema=config.test_schema + ), + ), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + schema=config.test_schema, + ) + + def test_insert_roundtrip(self, connection): + connection.execute(self.tables.seq_pk.insert(), dict(data="some data")) + self._assert_round_trip(self.tables.seq_pk, connection) + + def test_insert_lastrowid(self, connection): + r = connection.execute( + self.tables.seq_pk.insert(), dict(data="some data") + ) + eq_( + r.inserted_primary_key, (testing.db.dialect.default_sequence_base,) + ) + + def test_nextval_direct(self, connection): + r = connection.scalar(self.tables.seq_pk.c.id.default) + eq_(r, testing.db.dialect.default_sequence_base) + + @requirements.sequences_optional + def test_optional_seq(self, connection): + r = connection.execute( + self.tables.seq_opt_pk.insert(), dict(data="some data") + ) + eq_(r.inserted_primary_key, (1,)) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_(row, (testing.db.dialect.default_sequence_base, "some data")) + + def test_insert_roundtrip_no_implicit_returning(self, connection): + connection.execute( + self.tables.seq_no_returning.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.seq_no_returning, connection) + + @testing.combinations((True,), (False,), argnames="implicit_returning") + @testing.requires.schemas + def test_insert_roundtrip_translate(self, connection, implicit_returning): + seq_no_returning = Table( + "seq_no_returning_sch", + MetaData(), + Column( + "id", + Integer, + normalize_sequence( + config, Sequence("noret_sch_id_seq", schema="alt_schema") + ), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=implicit_returning, + schema="alt_schema", + ) + + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + connection.execute(seq_no_returning.insert(), dict(data="some data")) + self._assert_round_trip(seq_no_returning, connection) + + @testing.requires.schemas + def test_nextval_direct_schema_translate(self, connection): + seq = normalize_sequence( + config, Sequence("noret_sch_id_seq", schema="alt_schema") + ) + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + + r = connection.scalar(seq) + eq_(r, testing.db.dialect.default_sequence_base) + + +class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_literal_binds_inline_compile(self, connection): + table = Table( + "x", + MetaData(), + Column( + "y", Integer, normalize_sequence(config, Sequence("y_seq")) + ), + Column("q", Integer), + ) + + stmt = table.insert().values(q=5) + + seq_nextval = connection.dialect.statement_compiler( + statement=None, dialect=connection.dialect + ).visit_sequence(normalize_sequence(config, Sequence("y_seq"))) + self.assert_compile( + stmt, + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), + literal_binds=True, + dialect=connection.dialect, + ) + + +class HasSequenceTest(fixtures.TablesTest): + run_deletes = None + + __requires__ = ("sequences",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + normalize_sequence(config, Sequence("user_id_seq", metadata=metadata)) + normalize_sequence( + config, + Sequence( + "other_seq", + metadata=metadata, + nomaxvalue=True, + nominvalue=True, + ), + ) + if testing.requires.schemas.enabled: + normalize_sequence( + config, + Sequence( + "user_id_seq", schema=config.test_schema, metadata=metadata + ), + ) + normalize_sequence( + config, + Sequence( + "schema_seq", schema=config.test_schema, metadata=metadata + ), + ) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + def test_has_sequence(self, connection): + eq_(inspect(connection).has_sequence("user_id_seq"), True) + + def test_has_sequence_cache(self, connection, metadata): + insp = inspect(connection) + eq_(insp.has_sequence("user_id_seq"), True) + ss = normalize_sequence(config, Sequence("new_seq", metadata=metadata)) + eq_(insp.has_sequence("new_seq"), False) + ss.create(connection) + try: + eq_(insp.has_sequence("new_seq"), False) + insp.clear_cache() + eq_(insp.has_sequence("new_seq"), True) + finally: + ss.drop(connection) + + def test_has_sequence_other_object(self, connection): + eq_(inspect(connection).has_sequence("user_id_table"), False) + + @testing.requires.schemas + def test_has_sequence_schema(self, connection): + eq_( + inspect(connection).has_sequence( + "user_id_seq", schema=config.test_schema + ), + True, + ) + + def test_has_sequence_neg(self, connection): + eq_(inspect(connection).has_sequence("some_sequence"), False) + + @testing.requires.schemas + def test_has_sequence_schemas_neg(self, connection): + eq_( + inspect(connection).has_sequence( + "some_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_default_not_in_remote(self, connection): + eq_( + inspect(connection).has_sequence( + "other_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_remote_not_in_default(self, connection): + eq_(inspect(connection).has_sequence("schema_seq"), False) + + def test_get_sequence_names(self, connection): + exp = {"other_seq", "user_id_seq"} + + res = set(inspect(connection).get_sequence_names()) + is_true(res.intersection(exp) == exp) + is_true("schema_seq" not in res) + + @testing.requires.schemas + def test_get_sequence_names_no_sequence_schema(self, connection): + eq_( + inspect(connection).get_sequence_names( + schema=config.test_schema_2 + ), + [], + ) + + @testing.requires.schemas + def test_get_sequence_names_sequences_schema(self, connection): + eq_( + sorted( + inspect(connection).get_sequence_names( + schema=config.test_schema + ) + ), + ["schema_seq", "user_id_seq"], + ) + + +class HasSequenceTestEmpty(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_get_sequence_names_no_sequence(self, connection): + eq_( + inspect(connection).get_sequence_names(), + [], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py new file mode 100644 index 0000000..4a7c1f1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py @@ -0,0 +1,2071 @@ +# testing/suite/test_types.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import datetime +import decimal +import json +import re +import uuid + +from .. import config +from .. import engines +from .. import fixtures +from .. import mock +from ..assertions import eq_ +from ..assertions import is_ +from ..assertions import ne_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import and_ +from ... import ARRAY +from ... import BigInteger +from ... import bindparam +from ... import Boolean +from ... import case +from ... import cast +from ... import Date +from ... import DateTime +from ... import Float +from ... import Integer +from ... import Interval +from ... import JSON +from ... import literal +from ... import literal_column +from ... import MetaData +from ... import null +from ... import Numeric +from ... import select +from ... import String +from ... import testing +from ... import Text +from ... import Time +from ... import TIMESTAMP +from ... import type_coerce +from ... import TypeDecorator +from ... import Unicode +from ... import UnicodeText +from ... import UUID +from ... import Uuid +from ...orm import declarative_base +from ...orm import Session +from ...sql import sqltypes +from ...sql.sqltypes import LargeBinary +from ...sql.sqltypes import PickleType + + +class _LiteralRoundTripFixture: + supports_whereclause = True + + @testing.fixture + def literal_round_trip(self, metadata, connection): + """test literal rendering""" + + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as its + # official type; ideally we'd be able to use CAST here + # but MySQL in particular can't CAST fully + + def run( + type_, + input_, + output, + filter_=None, + compare=None, + support_whereclause=True, + ): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + for value in input_: + ins = t.insert().values( + x=literal(value, type_, literal_execute=True) + ) + connection.execute(ins) + + ins = t.insert().values( + x=literal(None, type_, literal_execute=True) + ) + connection.execute(ins) + + if support_whereclause and self.supports_whereclause: + if compare: + stmt = t.select().where( + t.c.x + == literal( + compare, + type_, + literal_execute=True, + ), + t.c.x + == literal( + input_[0], + type_, + literal_execute=True, + ), + ) + else: + stmt = t.select().where( + t.c.x + == literal( + compare if compare is not None else input_[0], + type_, + literal_execute=True, + ) + ) + else: + stmt = t.select().where(t.c.x.is_not(None)) + + rows = connection.execute(stmt).all() + assert rows, "No rows returned" + for row in rows: + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output + + stmt = t.select().where(t.c.x.is_(None)) + rows = connection.execute(stmt).all() + eq_(rows, [(None,)]) + + return run + + +class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("unicode_data",) + + data = ( + "Alors vous imaginez ma 🐍 surprise, au lever du jour, " + "quand une drôle de petite 🐍 voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »" + ) + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) + + def test_round_trip(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": self.data} + ) + + row = connection.execute(select(unicode_table.c.unicode_data)).first() + + eq_(row, (self.data,)) + assert isinstance(row[0], str) + + def test_round_trip_executemany(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(1, 4)], + ) + + rows = connection.execute( + select(unicode_table.c.unicode_data) + ).fetchall() + eq_(rows, [(self.data,) for i in range(1, 4)]) + for row in rows: + assert isinstance(row[0], str) + + def _test_null_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": None} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, (None,)) + + def _test_empty_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": ""} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, ("",)) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip(self.datatype, ["réve🐍 illé"], ["réve🐍 illé"]) + + +class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = ("unicode_data",) + __backend__ = True + + datatype = Unicode(255) + + @requirements.empty_strings_varchar + def test_empty_strings_varchar(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_varchar(self, connection): + self._test_null_strings(connection) + + +class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = "unicode_data", "text_type" + __backend__ = True + + datatype = UnicodeText() + + @requirements.empty_strings_text + def test_empty_strings_text(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_text(self, connection): + self._test_null_strings(connection) + + +class ArrayTest(_LiteralRoundTripFixture, fixtures.TablesTest): + """Add ARRAY test suite, #8138. + + This only works on PostgreSQL right now. + + """ + + __requires__ = ("array_type",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "array_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("single_dim", ARRAY(Integer)), + Column("multi_dim", ARRAY(String, dimensions=2)), + ) + + def test_array_roundtrip(self, connection): + array_table = self.tables.array_table + + connection.execute( + array_table.insert(), + { + "id": 1, + "single_dim": [1, 2, 3], + "multi_dim": [["one", "two"], ["thr'ee", "réve🐍 illé"]], + }, + ) + row = connection.execute( + select(array_table.c.single_dim, array_table.c.multi_dim) + ).first() + eq_(row, ([1, 2, 3], [["one", "two"], ["thr'ee", "réve🐍 illé"]])) + + def test_literal_simple(self, literal_round_trip): + literal_round_trip( + ARRAY(Integer), + ([1, 2, 3],), + ([1, 2, 3],), + support_whereclause=False, + ) + + def test_literal_complex(self, literal_round_trip): + literal_round_trip( + ARRAY(String, dimensions=2), + ([["one", "two"], ["thr'ee", "réve🐍 illé"]],), + ([["one", "two"], ["thr'ee", "réve🐍 illé"]],), + support_whereclause=False, + ) + + +class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "binary_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("binary_data", LargeBinary), + Column("pickle_data", PickleType), + ) + + @testing.combinations(b"this is binary", b"7\xe7\x9f", argnames="data") + def test_binary_roundtrip(self, connection, data): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), {"id": 1, "binary_data": data} + ) + row = connection.execute(select(binary_table.c.binary_data)).first() + eq_(row, (data,)) + + def test_pickle_roundtrip(self, connection): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), + {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}}, + ) + row = connection.execute(select(binary_table.c.pickle_data)).first() + eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},)) + + +class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("text_type",) + __backend__ = True + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) + + def test_text_roundtrip(self, connection): + text_table = self.tables.text_table + + connection.execute( + text_table.insert(), {"id": 1, "text_data": "some text"} + ) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("some text",)) + + @testing.requires.empty_strings_text + def test_text_empty_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": ""}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("",)) + + def test_text_null_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": None}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, (None,)) + + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) + + @requirements.unicode_data_no_special_types + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip(Text, ["réve🐍 illé"], ["réve🐍 illé"]) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(Text, [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(Text, [data], [data]) + + def test_literal_percentsigns(self, literal_round_trip): + data = r"percent % signs %% percent" + literal_round_trip(Text, [data], [data]) + + +class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @requirements.unbounded_varchar + def test_nolength_string(self): + metadata = MetaData() + foo = Table("foo", metadata, Column("one", String)) + + foo.create(config.db) + foo.drop(config.db) + + def test_literal(self, literal_round_trip): + # note that in Python 3, this invokes the Unicode + # datatype for the literal part because all strings are unicode + literal_round_trip(String(40), ["some text"], ["some text"]) + + @requirements.unicode_data_no_special_types + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip(String(40), ["réve🐍 illé"], ["réve🐍 illé"]) + + @testing.combinations( + ("%B%", ["AB", "BC"]), + ("A%C", ["AC"]), + ("A%C%Z", []), + argnames="expr, expected", + ) + def test_dont_truncate_rightside( + self, metadata, connection, expr, expected + ): + t = Table("t", metadata, Column("x", String(2))) + t.create(connection) + + connection.execute(t.insert(), [{"x": "AB"}, {"x": "BC"}, {"x": "AC"}]) + + eq_( + connection.scalars(select(t.c.x).where(t.c.x.like(expr))).all(), + expected, + ) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(String(40), [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(String(40), [data], [data]) + + def test_concatenate_binary(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and visit_concat_op_clauselist() + in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed + for dialects to override the string concatenation operator. + + """ + eq_(connection.scalar(select(literal("a") + "b")), "ab") + + def test_concatenate_clauselist(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and visit_concat_op_clauselist() + in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed + for dialects to override the string concatenation operator. + + """ + eq_( + connection.scalar(select(literal("a") + "b" + "c" + "d" + "e")), + "abcde", + ) + + +class IntervalTest(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("datetime_interval",) + __backend__ = True + + datatype = Interval + data = datetime.timedelta(days=1, seconds=4) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_select_direct_literal_interval(self, connection): + row = connection.execute(select(literal(self.data))).first() + eq_(row, (self.data,)) + + def test_arithmetic_operation_literal_interval(self, connection): + now = datetime.datetime.now().replace(microsecond=0) + # Able to subtract + row = connection.execute( + select(literal(now) - literal(self.data)) + ).scalar() + eq_(row, now - self.data) + + # Able to Add + row = connection.execute( + select(literal(now) + literal(self.data)) + ).scalar() + eq_(row, now + self.data) + + @testing.fixture + def arithmetic_table_fixture(cls, metadata, connection): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + it = Table( + "interval_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("interval_data", cls.datatype), + Column("date_data", DateTime), + Column("decorated_interval_data", Decorated), + ) + it.create(connection) + return it + + def test_arithmetic_operation_table_interval_and_literal_interval( + self, connection, arithmetic_table_fixture + ): + interval_table = arithmetic_table_fixture + data = datetime.timedelta(days=2, seconds=5) + connection.execute( + interval_table.insert(), {"id": 1, "interval_data": data} + ) + # Subtraction Operation + value = connection.execute( + select(interval_table.c.interval_data - literal(self.data)) + ).scalar() + eq_(value, data - self.data) + + # Addition Operation + value = connection.execute( + select(interval_table.c.interval_data + literal(self.data)) + ).scalar() + eq_(value, data + self.data) + + def test_arithmetic_operation_table_date_and_literal_interval( + self, connection, arithmetic_table_fixture + ): + interval_table = arithmetic_table_fixture + now = datetime.datetime.now().replace(microsecond=0) + connection.execute( + interval_table.insert(), {"id": 1, "date_data": now} + ) + # Subtraction Operation + value = connection.execute( + select(interval_table.c.date_data - literal(self.data)) + ).scalar() + eq_(value, (now - self.data)) + + # Addition Operation + value = connection.execute( + select(interval_table.c.date_data + literal(self.data)) + ).scalar() + eq_(value, (now + self.data)) + + +class PrecisionIntervalTest(IntervalTest): + __requires__ = ("datetime_interval",) + __backend__ = True + + datatype = Interval(day_precision=9, second_precision=9) + data = datetime.timedelta(days=103, seconds=4) + + +class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): + compare = None + + @classmethod + def define_tables(cls, metadata): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), + ) + + def test_round_trip(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + + row = connection.execute(select(date_table.c.date_data)).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_round_trip_decorated(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "decorated_date_data": self.data} + ) + + row = connection.execute( + select(date_table.c.decorated_date_data) + ).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_null(self, connection): + date_table = self.tables.date_table + + connection.execute(date_table.insert(), {"id": 1, "date_data": None}) + + row = connection.execute(select(date_table.c.date_data)).first() + eq_(row, (None,)) + + @testing.requires.datetime_literals + def test_literal(self, literal_round_trip): + compare = self.compare or self.data + + literal_round_trip( + self.datatype, [self.data], [compare], compare=compare + ) + + @testing.requires.standalone_null_binds_whereclause + def test_null_bound_comparison(self): + # this test is based on an Oracle issue observed in #4886. + # passing NULL for an expression that needs to be interpreted as + # a certain type, does the DBAPI have the info it needs to do this. + date_table = self.tables.date_table + with config.db.begin() as conn: + result = conn.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + id_ = result.inserted_primary_key[0] + stmt = select(date_table.c.id).where( + case( + ( + bindparam("foo", type_=self.datatype) != None, + bindparam("foo", type_=self.datatype), + ), + else_=date_table.c.date_data, + ) + == date_table.c.date_data + ) + + row = conn.execute(stmt, {"foo": None}).first() + eq_(row[0], id_) + + +class DateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_timezone",) + __backend__ = True + datatype = DateTime(timezone=True) + data = datetime.datetime( + 2012, 10, 15, 12, 57, 18, tzinfo=datetime.timezone.utc + ) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_microseconds",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 39642) + + +class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("timestamp_microseconds",) + __backend__ = True + datatype = TIMESTAMP + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + @testing.requires.timestamp_microseconds_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18) + + @testing.requires.time_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_timezone",) + __backend__ = True + datatype = Time(timezone=True) + data = datetime.time(12, 57, 18, tzinfo=datetime.timezone.utc) + + @testing.requires.time_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_microseconds",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18, 396) + + @testing.requires.time_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date",) + __backend__ = True + datatype = Date + data = datetime.date(2012, 10, 15) + + @testing.requires.date_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): + """this particular suite is testing that datetime parameters get + coerced to dates, which tends to be something DBAPIs do. + + """ + + __requires__ = "date", "date_coerces_from_datetime" + __backend__ = True + datatype = Date + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + compare = datetime.date(2012, 10, 15) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_historic",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(1850, 11, 10, 11, 52, 35) + + @testing.requires.date_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class DateHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date_historic",) + __backend__ = True + datatype = Date + data = datetime.date(1727, 4, 1) + + @testing.requires.date_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) + + def _huge_ints(): + return testing.combinations( + 2147483649, # 32 bits + 2147483648, # 32 bits + 2147483647, # 31 bits + 2147483646, # 31 bits + -2147483649, # 32 bits + -2147483648, # 32 interestingly, asyncpg accepts this one as int32 + -2147483647, # 31 + -2147483646, # 31 + 0, + 1376537018368127, + -1376537018368127, + argnames="intvalue", + ) + + @_huge_ints() + def test_huge_int_auto_accommodation(self, connection, intvalue): + """test #7909""" + + eq_( + connection.scalar( + select(intvalue).where(literal(intvalue) == intvalue) + ), + intvalue, + ) + + @_huge_ints() + def test_huge_int(self, integer_round_trip, intvalue): + integer_round_trip(BigInteger, intvalue) + + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) + + metadata.create_all(config.db) + + connection.execute( + int_table.insert(), {"id": 1, "integer_data": data} + ) + + row = connection.execute(select(int_table.c.integer_data)).first() + + eq_(row, (data,)) + + assert isinstance(row[0], int) + + return run + + +class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def string_as_int(self): + class StringAsInt(TypeDecorator): + impl = String(50) + cache_ok = True + + def column_expression(self, col): + return cast(col, Integer) + + def bind_expression(self, col): + return cast(type_coerce(col, Integer), String(50)) + + return StringAsInt() + + def test_special_type(self, metadata, connection, string_as_int): + type_ = string_as_int + + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]]) + + result = {row[0] for row in connection.execute(t.select())} + eq_(result, {1, 2, 3}) + + result = { + row[0] for row in connection.execute(t.select().where(t.c.x == 2)) + } + eq_(result, {2}) + + +class TrueDivTest(fixtures.TestBase): + __backend__ = True + + @testing.combinations( + ("15", "10", 1.5), + ("-15", "10", -1.5), + argnames="left, right, expected", + ) + def test_truediv_integer(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Integer()) + / literal_column(right, type_=Integer()) + ) + ), + expected, + ) + + @testing.combinations( + ("15", "10", 1), ("-15", "5", -3), argnames="left, right, expected" + ) + def test_floordiv_integer(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Integer()) + // literal_column(right, type_=Integer()) + ) + ), + expected, + ) + + @testing.combinations( + ("5.52", "2.4", "2.3"), argnames="left, right, expected" + ) + def test_truediv_numeric(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Numeric(10, 2)) + / literal_column(right, type_=Numeric(10, 2)) + ) + ), + decimal.Decimal(expected), + ) + + @testing.combinations( + ("5.52", "2.4", 2.3), argnames="left, right, expected" + ) + def test_truediv_float(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Float()) + / literal_column(right, type_=Float()) + ) + ), + expected, + ) + + @testing.combinations( + ("5.52", "2.4", "2.0"), argnames="left, right, expected" + ) + def test_floordiv_numeric(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Numeric()) + // literal_column(right, type_=Numeric()) + ) + ), + decimal.Decimal(expected), + ) + + def test_truediv_integer_bound(self, connection): + """test #4926""" + + eq_( + connection.scalar(select(literal(15) / literal(10))), + 1.5, + ) + + def test_floordiv_integer_bound(self, connection): + """test #4926""" + + eq_( + connection.scalar(select(literal(15) // literal(10))), + 1, + ) + + +class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def do_numeric_test(self, metadata, connection): + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = {filter_(x) for x in result} + output = {filter_(x) for x in output} + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + connection.execute(t.delete()) + + # test that this is actually a number! + # note we have tiny scale here as we have tests with very + # small scale Numeric types. PostgreSQL will raise an error + # if you use values outside the available scale. + if type_.asdecimal: + test_value = decimal.Decimal("2.9") + add_value = decimal.Decimal("37.12") + else: + test_value = 2.9 + add_value = 37.12 + + connection.execute(t.insert(), {"x": test_value}) + assert_we_are_a_number = connection.scalar( + select(type_coerce(t.c.x + add_value, type_)) + ) + eq_( + round(assert_we_are_a_number, 3), + round(test_value + add_value, 3), + ) + + return run + + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( + Float(), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + support_whereclause=False, + ) + + @testing.requires.precision_generic_float_type + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( + Float(None, decimal_return_scale=7, asdecimal=True), + [15.7563827, decimal.Decimal("15.7563827")], + [decimal.Decimal("15.7563827")], + check_scale=True, + ) + + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + @testing.requires.infinity_floats + def test_infinity_floats(self, do_numeric_test): + """test for #977, #7283""" + + do_numeric_test( + Float(None), + [float("inf")], + [float("inf")], + ) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] + ) + + @testing.requires.floats_to_four_decimals + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( + Float(asdecimal=True), + [15.756, decimal.Decimal("15.756"), None], + [decimal.Decimal("15.756"), None], + filter_=lambda n: n is not None and round(n, 4) or None, + ) + + def test_float_as_float(self, do_numeric_test): + do_numeric_test( + Float(), + [15.756, decimal.Decimal("15.756")], + [15.756], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + @testing.requires.literal_float_coercion + def test_float_coerce_round_trip(self, connection): + expr = 15.7563 + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + # this does not work in MySQL, see #4036, however we choose not + # to render CAST unconditionally since this is kind of an edge case. + + @testing.requires.implicit_decimal_binds + def test_decimal_coerce_round_trip(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + def test_decimal_coerce_round_trip_w_cast(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(cast(expr, Numeric(10, 4)))) + eq_(val, expr) + + @testing.requires.precision_numerics_general + def test_precision_decimal(self, do_numeric_test): + numbers = { + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + } + + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal(self, do_numeric_test): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + """ + + numbers = { + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + } + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal_large(self, do_numeric_test): + """test exceedingly large decimals.""" + + numbers = { + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + } + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) + + @testing.requires.precision_numerics_many_significant_digits + def test_many_significant_digits(self, do_numeric_test): + numbers = { + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + } + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_retains_significant_digits + def test_numeric_no_decimal(self, do_numeric_test): + numbers = {decimal.Decimal("1.000")} + do_numeric_test( + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True + ) + + @testing.combinations(sqltypes.Float, sqltypes.Double, argnames="cls_") + @testing.requires.float_is_numeric + def test_float_is_not_numeric(self, connection, cls_): + target_type = cls_().dialect_impl(connection.dialect) + numeric_type = sqltypes.Numeric().dialect_impl(connection.dialect) + + ne_(target_type.__visit_name__, numeric_type.__visit_name__) + ne_(target_type.__class__, numeric_type.__class__) + + +class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) + + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) + + def test_round_trip(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": True, "unconstrained_value": False}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (True, False)) + assert isinstance(row[0], bool) + + @testing.requires.nullable_booleans + def test_null(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": None, "unconstrained_value": None}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (None, None)) + + def test_whereclause(self): + # testing "WHERE <column>" renders a compatible expression + boolean_table = self.tables.boolean_table + + with config.db.begin() as conn: + conn.execute( + boolean_table.insert(), + [ + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], + ) + + eq_( + conn.scalar( + select(boolean_table.c.id).where(boolean_table.c.value) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + boolean_table.c.unconstrained_value + ) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where(~boolean_table.c.value) + ), + 2, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + ~boolean_table.c.unconstrained_value + ) + ), + 2, + ) + + +class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("json_type",) + __backend__ = True + + datatype = JSON + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype, nullable=False), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def test_round_trip_data1(self, connection): + self._test_round_trip({"key1": "value1", "key2": "value2"}, connection) + + @testing.combinations( + ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia" + ) + @testing.combinations(100, 1999, 3000, 4000, 5000, 9000, argnames="length") + def test_round_trip_pretty_large_data(self, connection, unicode_, length): + if unicode_: + data = "réve🐍illé" * ((length // 9) + 1) + data = data[0 : (length // 2)] + else: + data = "abcdefg" * ((length // 7) + 1) + data = data[0:length] + + self._test_round_trip({"key1": data, "key2": data}, connection) + + def _test_round_trip(self, data_element, connection): + data_table = self.tables.data_table + + connection.execute( + data_table.insert(), + {"id": 1, "name": "row1", "data": data_element}, + ) + + row = connection.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + + def _index_fixtures(include_comparison): + if include_comparison: + # basically SQL Server and MariaDB can kind of do json + # comparison, MySQL, PG and SQLite can't. not worth it. + json_elements = [] + else: + json_elements = [ + ("json", {"foo": "bar"}), + ("json", ["one", "two", "three"]), + (None, {"foo": "bar"}), + (None, ["one", "two", "three"]), + ] + + elements = [ + ("boolean", True), + ("boolean", False), + ("boolean", None), + ("string", "some string"), + ("string", None), + ("string", "réve illé"), + ( + "string", + "réve🐍 illé", + testing.requires.json_index_supplementary_unicode_element, + ), + ("integer", 15), + ("integer", 1), + ("integer", 0), + ("integer", None), + ("float", 28.5), + ("float", None), + ("float", 1234567.89, testing.requires.literal_float_coercion), + ("numeric", 1234567.89), + # this one "works" because the float value you see here is + # lost immediately to floating point stuff + ( + "numeric", + 99998969694839.983485848, + ), + ("numeric", 99939.983485848), + ("_decimal", decimal.Decimal("1234567.89")), + ( + "_decimal", + decimal.Decimal("99998969694839.983485848"), + # fails on SQLite and MySQL (non-mariadb) + requirements.cast_precision_numerics_many_significant_digits, + ), + ( + "_decimal", + decimal.Decimal("99939.983485848"), + ), + ] + json_elements + + def decorate(fn): + fn = testing.combinations(id_="sa", *elements)(fn) + + return fn + + return decorate + + def _json_value_insert(self, connection, datatype, value, data_element): + data_table = self.tables.data_table + if datatype == "_decimal": + # Python's builtin json serializer basically doesn't support + # Decimal objects without implicit float conversion period. + # users can otherwise use simplejson which supports + # precision decimals + + # https://bugs.python.org/issue16535 + + # inserting as strings to avoid a new fixture around the + # dialect which would have idiosyncrasies for different + # backends. + + class DecimalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, decimal.Decimal): + return str(o) + return super().default(o) + + json_data = json.dumps(data_element, cls=DecimalEncoder) + + # take the quotes out. yup, there is *literally* no other + # way to get Python's json.dumps() to put all the digits in + # the string + json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data) + + datatype = "numeric" + + connection.execute( + data_table.insert().values( + name="row1", + # to pass the string directly to every backend, including + # PostgreSQL which needs the value to be CAST as JSON + # both in the SQL as well as at the prepared statement + # level for asyncpg, while at the same time MySQL + # doesn't even support CAST for JSON, here we are + # sending the string embedded in the SQL without using + # a parameter. + data=bindparam(None, json_data, literal_execute=True), + nulldata=bindparam(None, json_data, literal_execute=True), + ), + ) + else: + connection.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + p_s = None + + if datatype: + if datatype == "numeric": + a, b = str(value).split(".") + s = len(b) + p = len(a) + s + + if isinstance(value, decimal.Decimal): + compare_value = value + else: + compare_value = decimal.Decimal(str(value)) + + p_s = (p, s) + else: + compare_value = value + else: + compare_value = value + + return datatype, compare_value, p_s + + @_index_fixtures(False) + def test_index_typed_access(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + roundtrip = conn.scalar(select(expr)) + eq_(roundtrip, compare_value) + is_(type(roundtrip), type(compare_value)) + + @_index_fixtures(True) + def test_index_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @_index_fixtures(True) + def test_path_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": {"subkey1": value}} + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data[("key1", "subkey1")] + + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @testing.combinations( + (True,), + (False,), + (None,), + (15,), + (0,), + (-1,), + (-1.0,), + (15.052,), + ("a string",), + ("réve illé",), + ("réve🐍 illé",), + ) + def test_single_element_round_trip(self, element): + data_table = self.tables.data_table + data_element = element + with config.db.begin() as conn: + conn.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + row = conn.execute( + select(data_table.c.data, data_table.c.nulldata) + ).first() + + eq_(row, (data_element, data_element)) + + def test_round_trip_custom_json(self): + data_table = self.tables.data_table + data_element = {"key1": "data1"} + + js = mock.Mock(side_effect=json.dumps) + jd = mock.Mock(side_effect=json.loads) + engine = engines.testing_engine( + options=dict(json_serializer=js, json_deserializer=jd) + ) + + # support sqlite :memory: database... + data_table.create(engine, checkfirst=True) + with engine.begin() as conn: + conn.execute( + data_table.insert(), {"name": "row1", "data": data_element} + ) + row = conn.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + eq_(js.mock_calls, [mock.call(data_element)]) + if testing.requires.json_deserializer_binary.enabled: + eq_( + jd.mock_calls, + [mock.call(json.dumps(data_element).encode())], + ) + else: + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + ("omit",), + argnames="insert_type", + ) + def test_round_trip_none_as_sql_null(self, connection, insert_type): + col = self.tables.data_table.c["nulldata"] + + conn = connection + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "nulldata": None, + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "nulldata": None, "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values( + name="r1", + nulldata=None, + data=None, + ), + {}, + ) + elif insert_type == "omit": + stmt, params = ( + self.tables.data_table.insert(), + {"name": "r1", "data": None}, + ) + + else: + assert False + + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where(col.is_(null())) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_round_trip_json_null_as_json_null(self, connection): + col = self.tables.data_table.c["data"] + + conn = connection + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": JSON.NULL}, + ) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + argnames="insert_type", + ) + def test_round_trip_none_as_json_null(self, connection, insert_type): + col = self.tables.data_table.c["data"] + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values(name="r1", data=None), + {}, + ) + else: + assert False + + conn = connection + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_unicode_round_trip(self): + # note we include Unicode supplementary characters as well + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + { + "name": "r1", + "data": { + "réve🐍 illé": "réve🐍 illé", + "data": {"k1": "drôl🐍e"}, + }, + }, + ) + + eq_( + conn.scalar(select(self.tables.data_table.c.data)), + { + "réve🐍 illé": "réve🐍 illé", + "data": {"k1": "drôl🐍e"}, + }, + ) + + def test_eval_none_flag_orm(self, connection): + Base = declarative_base() + + class Data(Base): + __table__ = self.tables.data_table + + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() + + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +class JSONLegacyStringCastIndexTest( + _LiteralRoundTripFixture, fixtures.TablesTest +): + """test JSON index access with "cast to string", which we have documented + for a long time as how to compare JSON values, but is ultimately not + reliable in all cases. The "as_XYZ()" comparators should be used + instead. + + """ + + __requires__ = ("json_type", "legacy_unconditional_json_extract") + __backend__ = True + + datatype = JSON + + data1 = {"key1": "value1", "key2": "value2"} + + data2 = { + "Key 'One'": "value1", + "key two": "value2", + "key three": "value ' three '", + } + + data3 = { + "key1": [1, 2, 3], + "key2": ["one", "two", "three"], + "key3": [{"four": "five"}, {"six": "seven"}], + } + + data4 = ["one", "two", "three"] + + data5 = { + "nested": { + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, + } + } + + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def _criteria_fixture(self): + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], + ) + + def _test_index_criteria(self, crit, expected, test_literal=True): + self._criteria_fixture() + with config.db.connect() as conn: + stmt = select(self.tables.data_table.c.name).where(crit) + + eq_(conn.scalar(stmt), expected) + + if test_literal: + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) + + eq_(conn.exec_driver_sql(literal_sql).scalar(), expected) + + def test_string_cast_crit_spaces_in_key(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract field from a non-object", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name.in_(["r1", "r2", "r3"]), + cast(col["key two"], String) == '"value2"', + ), + "r2", + ) + + @config.requirements.json_array_indexes + def test_string_cast_crit_simple_int(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract array element from a non-array", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name == "r4", + cast(col[1], String) == '"two"', + ), + "r4", + ) + + def test_string_cast_crit_mixed_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("key3", 1, "six")], String) == '"seven"', + "r3", + ) + + def test_string_cast_crit_string_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("nested", "elem2", "elem3", "elem4")], String) + == '"elem5"', + "r5", + ) + + def test_string_cast_crit_against_string_basic(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + self._test_index_criteria( + and_( + name == "r6", + cast(col["b"], String) == '"some value"', + ), + "r6", + ) + + +class UuidTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + datatype = Uuid + + @classmethod + def define_tables(cls, metadata): + Table( + "uuid_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("uuid_data", cls.datatype), + Column("uuid_text_data", cls.datatype(as_uuid=False)), + Column("uuid_data_nonnative", Uuid(native_uuid=False)), + Column( + "uuid_text_data_nonnative", + Uuid(as_uuid=False, native_uuid=False), + ), + ) + + def test_uuid_round_trip(self, connection): + data = uuid.uuid4() + uuid_table = self.tables.uuid_table + + connection.execute( + uuid_table.insert(), + {"id": 1, "uuid_data": data, "uuid_data_nonnative": data}, + ) + row = connection.execute( + select( + uuid_table.c.uuid_data, uuid_table.c.uuid_data_nonnative + ).where( + uuid_table.c.uuid_data == data, + uuid_table.c.uuid_data_nonnative == data, + ) + ).first() + eq_(row, (data, data)) + + def test_uuid_text_round_trip(self, connection): + data = str(uuid.uuid4()) + uuid_table = self.tables.uuid_table + + connection.execute( + uuid_table.insert(), + { + "id": 1, + "uuid_text_data": data, + "uuid_text_data_nonnative": data, + }, + ) + row = connection.execute( + select( + uuid_table.c.uuid_text_data, + uuid_table.c.uuid_text_data_nonnative, + ).where( + uuid_table.c.uuid_text_data == data, + uuid_table.c.uuid_text_data_nonnative == data, + ) + ).first() + eq_((row[0].lower(), row[1].lower()), (data, data)) + + def test_literal_uuid(self, literal_round_trip): + data = uuid.uuid4() + literal_round_trip(self.datatype, [data], [data]) + + def test_literal_text(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip( + self.datatype(as_uuid=False), + [data], + [data], + filter_=lambda x: x.lower(), + ) + + def test_literal_nonnative_uuid(self, literal_round_trip): + data = uuid.uuid4() + literal_round_trip(Uuid(native_uuid=False), [data], [data]) + + def test_literal_nonnative_text(self, literal_round_trip): + data = str(uuid.uuid4()) + literal_round_trip( + Uuid(as_uuid=False, native_uuid=False), + [data], + [data], + filter_=lambda x: x.lower(), + ) + + @testing.requires.insert_returning + def test_uuid_returning(self, connection): + data = uuid.uuid4() + str_data = str(data) + uuid_table = self.tables.uuid_table + + result = connection.execute( + uuid_table.insert().returning( + uuid_table.c.uuid_data, + uuid_table.c.uuid_text_data, + uuid_table.c.uuid_data_nonnative, + uuid_table.c.uuid_text_data_nonnative, + ), + { + "id": 1, + "uuid_data": data, + "uuid_text_data": str_data, + "uuid_data_nonnative": data, + "uuid_text_data_nonnative": str_data, + }, + ) + row = result.first() + + eq_(row, (data, str_data, data, str_data)) + + +class NativeUUIDTest(UuidTest): + __requires__ = ("uuid_data_type",) + + datatype = UUID + + +__all__ = ( + "ArrayTest", + "BinaryTest", + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "JSONLegacyStringCastIndexTest", + "DateTest", + "DateTimeTest", + "DateTimeTZTest", + "TextTest", + "NumericTest", + "IntegerTest", + "IntervalTest", + "PrecisionIntervalTest", + "CastTypeDecoratorTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "TimeTZTest", + "TrueDivTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", + "UuidTest", + "NativeUUIDTest", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py new file mode 100644 index 0000000..1f15ab5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py @@ -0,0 +1,189 @@ +# testing/suite/test_unicode_ddl.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from sqlalchemy import desc +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import testing +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + + +class UnicodeSchemaTest(fixtures.TablesTest): + __requires__ = ("unicode_ddl",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + global t1, t2, t3 + + t1 = Table( + "unitable1", + metadata, + Column("méil", Integer, primary_key=True), + Column("\u6e2c\u8a66", Integer), + test_needs_fk=True, + ) + t2 = Table( + "Unitéble2", + metadata, + Column("méil", Integer, primary_key=True, key="a"), + Column( + "\u6e2c\u8a66", + Integer, + ForeignKey("unitable1.méil"), + key="b", + ), + test_needs_fk=True, + ) + + # Few DBs support Unicode foreign keys + if testing.against("sqlite"): + t3 = Table( + "\u6e2c\u8a66", + metadata, + Column( + "\u6e2c\u8a66_id", + Integer, + primary_key=True, + autoincrement=False, + ), + Column( + "unitable1_\u6e2c\u8a66", + Integer, + ForeignKey("unitable1.\u6e2c\u8a66"), + ), + Column("Unitéble2_b", Integer, ForeignKey("Unitéble2.b")), + Column( + "\u6e2c\u8a66_self", + Integer, + ForeignKey("\u6e2c\u8a66.\u6e2c\u8a66_id"), + ), + test_needs_fk=True, + ) + else: + t3 = Table( + "\u6e2c\u8a66", + metadata, + Column( + "\u6e2c\u8a66_id", + Integer, + primary_key=True, + autoincrement=False, + ), + Column("unitable1_\u6e2c\u8a66", Integer), + Column("Unitéble2_b", Integer), + Column("\u6e2c\u8a66_self", Integer), + test_needs_fk=True, + ) + + def test_insert(self, connection): + connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5}) + connection.execute(t2.insert(), {"a": 1, "b": 1}) + connection.execute( + t3.insert(), + { + "\u6e2c\u8a66_id": 1, + "unitable1_\u6e2c\u8a66": 5, + "Unitéble2_b": 1, + "\u6e2c\u8a66_self": 1, + }, + ) + + eq_(connection.execute(t1.select()).fetchall(), [(1, 5)]) + eq_(connection.execute(t2.select()).fetchall(), [(1, 1)]) + eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)]) + + def test_col_targeting(self, connection): + connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5}) + connection.execute(t2.insert(), {"a": 1, "b": 1}) + connection.execute( + t3.insert(), + { + "\u6e2c\u8a66_id": 1, + "unitable1_\u6e2c\u8a66": 5, + "Unitéble2_b": 1, + "\u6e2c\u8a66_self": 1, + }, + ) + + row = connection.execute(t1.select()).first() + eq_(row._mapping[t1.c["méil"]], 1) + eq_(row._mapping[t1.c["\u6e2c\u8a66"]], 5) + + row = connection.execute(t2.select()).first() + eq_(row._mapping[t2.c["a"]], 1) + eq_(row._mapping[t2.c["b"]], 1) + + row = connection.execute(t3.select()).first() + eq_(row._mapping[t3.c["\u6e2c\u8a66_id"]], 1) + eq_(row._mapping[t3.c["unitable1_\u6e2c\u8a66"]], 5) + eq_(row._mapping[t3.c["Unitéble2_b"]], 1) + eq_(row._mapping[t3.c["\u6e2c\u8a66_self"]], 1) + + def test_reflect(self, connection): + connection.execute(t1.insert(), {"méil": 2, "\u6e2c\u8a66": 7}) + connection.execute(t2.insert(), {"a": 2, "b": 2}) + connection.execute( + t3.insert(), + { + "\u6e2c\u8a66_id": 2, + "unitable1_\u6e2c\u8a66": 7, + "Unitéble2_b": 2, + "\u6e2c\u8a66_self": 2, + }, + ) + + meta = MetaData() + tt1 = Table(t1.name, meta, autoload_with=connection) + tt2 = Table(t2.name, meta, autoload_with=connection) + tt3 = Table(t3.name, meta, autoload_with=connection) + + connection.execute(tt1.insert(), {"méil": 1, "\u6e2c\u8a66": 5}) + connection.execute(tt2.insert(), {"méil": 1, "\u6e2c\u8a66": 1}) + connection.execute( + tt3.insert(), + { + "\u6e2c\u8a66_id": 1, + "unitable1_\u6e2c\u8a66": 5, + "Unitéble2_b": 1, + "\u6e2c\u8a66_self": 1, + }, + ) + + eq_( + connection.execute(tt1.select().order_by(desc("méil"))).fetchall(), + [(2, 7), (1, 5)], + ) + eq_( + connection.execute(tt2.select().order_by(desc("méil"))).fetchall(), + [(2, 2), (1, 1)], + ) + eq_( + connection.execute( + tt3.select().order_by(desc("\u6e2c\u8a66_id")) + ).fetchall(), + [(2, 7, 2, 2), (1, 5, 1, 1)], + ) + + def test_repr(self): + meta = MetaData() + t = Table("\u6e2c\u8a66", meta, Column("\u6e2c\u8a66_id", Integer)) + eq_( + repr(t), + ( + "Table('測試', MetaData(), " + "Column('測試_id', Integer(), " + "table=<測試>), " + "schema=None)" + ), + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py new file mode 100644 index 0000000..fd4757f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py @@ -0,0 +1,139 @@ +# testing/suite/test_update_delete.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import String +from ... import testing + + +class SimpleUpdateDeleteTest(fixtures.TablesTest): + run_deletes = "each" + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute( + t.update().where(t.c.id == 2), dict(data="d2_new") + ) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + @testing.variation("criteria", ["rows", "norows", "emptyin"]) + @testing.requires.update_returning + def test_update_returning(self, connection, criteria): + t = self.tables.plain_pk + + stmt = t.update().returning(t.c.id, t.c.data) + + if criteria.norows: + stmt = stmt.where(t.c.id == 10) + elif criteria.rows: + stmt = stmt.where(t.c.id == 2) + elif criteria.emptyin: + stmt = stmt.where(t.c.id.in_([])) + else: + criteria.fail() + + r = connection.execute(stmt, dict(data="d2_new")) + assert not r.is_insert + assert r.returns_rows + eq_(r.keys(), ["id", "data"]) + + if criteria.rows: + eq_(r.all(), [(2, "d2_new")]) + else: + eq_(r.all(), []) + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + ( + [(1, "d1"), (2, "d2_new"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), + ) + + @testing.variation("criteria", ["rows", "norows", "emptyin"]) + @testing.requires.delete_returning + def test_delete_returning(self, connection, criteria): + t = self.tables.plain_pk + + stmt = t.delete().returning(t.c.id, t.c.data) + + if criteria.norows: + stmt = stmt.where(t.c.id == 10) + elif criteria.rows: + stmt = stmt.where(t.c.id == 2) + elif criteria.emptyin: + stmt = stmt.where(t.c.id.in_([])) + else: + criteria.fail() + + r = connection.execute(stmt) + assert not r.is_insert + assert r.returns_rows + eq_(r.keys(), ["id", "data"]) + + if criteria.rows: + eq_(r.all(), [(2, "d2")]) + else: + eq_(r.all(), []) + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + ( + [(1, "d1"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), + ) + + +__all__ = ("SimpleUpdateDeleteTest",) |